使用機器學習超解析度在移動裝置上提高影像解析度

瞭解如何使用 ONNX Runtime Mobile 構建一個應用程式,透過包含預處理和後處理的模型來提高影像解析度。

您可以使用本教程為 Android 或 iOS 構建應用程式。

該應用程式接收影像輸入,在點選按鈕時執行超解析度操作,並在下方顯示解析度提高後的影像,如下圖所示。

Super resolution on a cat

目錄

準備模型

本教程中使用的機器學習模型基於本頁底部引用的 PyTorch 教程中使用的模型。

我們提供了一個方便的 Python 指令碼,用於將 PyTorch 模型匯出為 ONNX 格式並新增預處理和後處理。

  1. 在執行此指令碼之前,請安裝以下 Python 包

     pip install torch
     pip install pillow
     pip install onnx
     pip install onnxruntime
     pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions
    

    版本說明:最佳的超解析度結果是在 ONNX opset 18 下實現的(它支援帶抗鋸齒的 Resize 運算子),該版本受 onnx 1.13.0 及更高版本和 onnxruntime 1.14.0 及更高版本支援。onnxruntime-extensions 包是預釋出版本。釋出版本將很快可用。

  2. 然後從 onnxruntime-extensions GitHub 儲存庫下載指令碼和測試影像(如果您尚未克隆此儲存庫)

     curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py
     curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png
    
  3. 執行指令碼以匯出核心模型併為其新增預處理和後處理

     python superresolution_e2e.py 
    

指令碼執行後,您應該在執行指令碼的資料夾中看到兩個 ONNX 檔案

pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx

如果您將這兩個模型載入到 netron 中,您可以看到兩者之間輸入和輸出的區別。下面前兩張圖顯示了原始模型,其輸入是批次通道資料,後兩張圖顯示了輸入和輸出是影像位元組。

ONNX model without pre and post processing

ONNX model inputs and outputs without pre and post processing

ONNX model with pre and post processing

ONNX model inputs and outputs with pre and post processing

現在是編寫應用程式程式碼的時候了。

Android 應用

先決條件

  • Android Studio Dolphin 2021.3.1 Patch + (安裝在 Mac/Windows/Linux 上)
  • Android SDK 29+
  • Android NDK r22+
  • Android 裝置或 Android 模擬器

示例程式碼

您可以在 GitHub 上找到 Android 超解析度應用的完整原始碼

要從原始碼執行應用程式,請克隆上述倉庫並將 build.gradle 檔案載入到 Android Studio 中,然後構建並執行!

要一步步構建應用程式,請按照以下章節操作。

從零開始編寫程式碼

設定專案

在 Android Studio 中為手機和平板電腦建立一個新專案,並選擇空白模板。將應用程式命名為 super_resolution 或類似名稱。

依賴項

將以下依賴項新增到應用的 build.gradle

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'

專案資源

  1. 將模型檔案新增為原始資源

    src/main/res 資料夾中建立一個名為 raw 的資料夾,並將 ONNX 模型移動或複製到該 raw 資料夾中。

  2. 將測試影像新增為資產

    在主專案資料夾中建立一個名為 assets 的資料夾,並將您想要執行超解析度的影像複製到該資料夾中,檔名為 test_superresolution.png

主應用程式類程式碼

建立一個名為 MainActivity.kt 的檔案,並向其中新增以下程式碼片段。

  1. 新增匯入語句

    import ai.onnxruntime.*
    import ai.onnxruntime.extensions.OrtxPackage
    import android.annotation.SuppressLint
    import android.os.Bundle
    import android.widget.Button
    import android.widget.ImageView
    import android.widget.Toast
    import androidx.activity.*
    import androidx.appcompat.app.AppCompatActivity
    import kotlinx.android.synthetic.main.activity_main.*
    import kotlinx.coroutines.*
    import java.io.InputStream
    import java.util.*
    import java.util.concurrent.ExecutorService
    import java.util.concurrent.Executors
    
  2. 建立主活動類並新增類變數

    class MainActivity : AppCompatActivity() {
        private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
        private lateinit var ortSession: OrtSession
        private var inputImage: ImageView? = null
        private var outputImage: ImageView? = null
        private var superResolutionButton: Button? = null
    
        ...
    }
    
  3. 新增 onCreate() 方法

    在這裡我們初始化 ONNX Runtime 會話。一個會話儲存對應用程式中用於執行推理的模型引用。它還接受一個會話選項引數,您可以在其中指定不同的執行提供程式(如 NNAPI 等硬體加速器)。在這種情況下,我們預設在 CPU 上執行。但是,我們確實註冊了自定義操作庫,其中包含模型輸入和輸出處的影像編碼和解碼運算子。

     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
         setContentView(R.layout.activity_main)
    
         inputImage = findViewById(R.id.imageView1)
         outputImage = findViewById(R.id.imageView2);
         superResolutionButton = findViewById(R.id.super_resolution_button)
         inputImage?.setImageBitmap(
             BitmapFactory.decodeStream(readInputImage())
         );
    
         // Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators.
         // Note: These are used to decode the input image into the format the original model requires,
         // and to encode the model output into png format
         val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()
         sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
         ortSession = ortEnv.createSession(readModel(), sessionOptions)
    
         superResolutionButton?.setOnClickListener {
             try {
                 performSuperResolution(ortSession)
                 Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT)
                     .show()
             } catch (e: Exception) {
                 Log.e(TAG, "Exception caught when perform super resolution", e)
                 Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT)
                     .show()
             }
         }
     }
    
  4. 新增 onDestroy 方法

     override fun onDestroy() {
         super.onDestroy()
         ortEnv.close()
         ortSession.close()
     }
    
    
  5. 新增 updateUI 方法

    private fun updateUI(result: Result) {
        outputImage?.setImageBitmap(result.outputBitmap)
    }
    
  6. 新增 readModel 方法

    此方法從資原始檔夾中讀取 ONNX 模型。

    private fun readModel(): ByteArray {
        val modelID = R.pytorch_superresolution_with_pre_post_processing_op18
        return resources.openRawResource(modelID).readBytes()
    }   
    
  7. 新增一個讀取輸入影像的方法

    此方法從 assets 資料夾讀取測試影像。目前它讀取應用程式中內建的固定影像。示例將很快擴充套件,以直接從相機或相機膠捲讀取影像。

    private fun readInputImage(): InputStream {
        return assets.open("test_superresolution.png")
    }   
    
  8. 新增執行推理的方法

    此方法呼叫應用程式核心方法:SuperResPerformer.upscale(),這是在模型上執行推理的方法。其程式碼在下一節中顯示。

     private fun performSuperResolution(ortSession: OrtSession) {
         var superResPerformer = SuperResPerformer()
         var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession)
         updateUI(result);
     }   
    
  9. 新增 TAG 物件

    companion object {
        const val TAG = "ORTSuperResolution"
    }
    

模型推理類程式碼

建立一個名為 SuperResPerformer.kt 的檔案,並向其中新增以下程式碼片段。

  1. 新增匯入

    import ai.onnxruntime.OnnxJavaType
    import ai.onnxruntime.OrtSession
    import ai.onnxruntime.OnnxTensor
    import ai.onnxruntime.OrtEnvironment
    import android.graphics.Bitmap
    import android.graphics.BitmapFactory
    import java.io.InputStream
    import java.nio.ByteBuffer
    import java.util.*
    
  2. 建立結果類

    internal data class Result(
        var outputBitmap: Bitmap? = null
    ) {}
    
  3. 建立超解析度執行器類

    該類及其主要函式 upscale 是大部分 ONNX Runtime 呼叫所在的地方。

    internal class SuperResPerformer(
    ) {
    
        fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {
            var result = Result()
    
            // Step 1: convert image into byte array (raw image bytes)
            val rawImageBytes = inputStream.readBytes()
    
            // Step 2: get the shape of the byte array and make ort tensor
            val shape = longArrayOf(rawImageBytes.size.toLong())
    
            val inputTensor = OnnxTensor.createTensor(
                ortEnv,
                ByteBuffer.wrap(rawImageBytes),
                shape,
                OnnxJavaType.UINT8
            )
            inputTensor.use {
                // Step 3: call ort inferenceSession run
                val output = ortSession.run(Collections.singletonMap("image", inputTensor))
    
                // Step 4: output analysis
                output.use {
                    val rawOutput = (output?.get(0)?.value) as ByteArray
                    val outputImageBitmap =
                        byteArrayToBitmap(rawOutput)
    
                    // Step 5: set output result
                    result.outputBitmap = outputImageBitmap
                }
            }
            return result
        }
    

構建並執行應用

在 Android Studio 中

  • 選擇 Build -> Make Project
  • Run -> app

應用程式在裝置模擬器中執行。連線到您的 Android 裝置以在裝置上執行應用程式。

iOS 應用

先決條件

  • 安裝 Xcode 13.0 及更高版本(最好是最新版本)
  • iOS 裝置或 iOS 模擬器
  • Xcode 命令列工具 xcode-select --install
  • CocoaPods sudo gem install cocoapods
  • 有效的 Apple 開發者 ID(如果您計劃在裝置上執行)

示例程式碼

您可以在 GitHub 上找到 iOS 超解析度應用的完整原始碼

從原始碼執行應用程式

  1. 克隆 onnxruntime-inference-examples 倉庫

    git clone https://github.com/microsoft/onnxruntime-inference-examples
    cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios
    
  2. 安裝所需的 Pod 檔案

    pod install
    
  3. 在 Xcode 中開啟生成的 ORTSuperResolution.xcworkspace 檔案

    (可選:僅當您在裝置上執行時才需要)選擇您的開發團隊

  4. 執行應用程式

    連線您的 iOS 裝置或模擬器,構建並執行應用程式

    點選 Perform Super Resolution 按鈕檢視應用執行情況

要逐步開發應用程式,請按照以下章節操作。

從零開始編寫程式碼

建立專案

使用 APP 模板在 Xcode 中建立一個新專案

依賴項

安裝以下 Pod

  # Pods for OrtSuperResolution
  pod 'onnxruntime-c'
  
  # Pre-release version pods
  pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'

專案資源

  1. 將模型檔案新增到專案中

    將本教程開頭生成的模型檔案複製到專案根資料夾中。

  2. 將測試影像新增為資產

    將您想要執行超解析度的影像複製到專案根資料夾中。

主應用

開啟名為 ORTSuperResolutionApp.swift 的檔案並新增以下程式碼

import SwiftUI

@main
struct ORTSuperResolutionApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}

內容檢視

開啟名為 ContentView.swift 的檔案並新增以下程式碼

import SwiftUI

struct ContentView: View {
    @State private var performSuperRes = false
    
    func runOrtSuperResolution() -> UIImage? {
        do {
            let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
            return outputImage
        } catch let error as NSError {
            print("Error: \(error.localizedDescription)")
            return nil
        }
    }
    
    var body: some View {
        ScrollView {
            VStack {
                VStack {
                    Text("ORTSuperResolution").font(.title).bold()
                        .frame(width: 400, height: 80)
                        .border(Color.purple, width: 4)
                        .background(Color.purple)
                    
                    Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
                    
                    Image("cat_224x224").frame(width: 250, height: 250)
                    
                    Button("Perform Super Resolution") {
                        performSuperRes.toggle()
                    }
                    
                    if performSuperRes {
                        Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
                        
                        if let outputImage = runOrtSuperResolution() {
                            Image(uiImage: outputImage)
                        } else {
                            Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
                        }
                    }
                    Spacer()
                }
            }
            .padding()
        }
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}

Swift / Objective C 橋接標頭檔案

建立一個名為 ORTSuperResolution-Bridging-Header.h 的檔案,並新增以下匯入語句

#import "ORTSuperResolutionPerformer.h"

超解析度程式碼

  1. 建立一個名為 ORTSuperResolutionPerformer.h 的檔案,並新增以下程式碼

    #ifndef ORTSuperResolutionPerformer_h
    #define ORTSuperResolutionPerformer_h
    
    #import <Foundation/Foundation.h>
    #import <UIKit/UIKit.h>
    
    NS_ASSUME_NONNULL_BEGIN
    
    @interface ORTSuperResolutionPerformer : NSObject
    
    + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error;
    
    @end
    
    NS_ASSUME_NONNULL_END
    
    #endif
    
  2. 建立一個名為 ORTSuperResolutionPerformer.mm 的檔案,並新增以下程式碼

     #import "ORTSuperResolutionPerformer.h"
     #import <Foundation/Foundation.h>
     #import <UIKit/UIKit.h>
    
     #include <array>
     #include <cstdint>
     #include <stdexcept>
     #include <string>
     #include <vector>
    
     #include <onnxruntime_cxx_api.h>
     #include <onnxruntime_extensions.h>
    
    
     @implementation ORTSuperResolutionPerformer
    
     + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error {
            
         UIImage* output_image = nil;
            
         try {
                
             // Register custom ops
                
             const auto ort_log_level = ORT_LOGGING_LEVEL_INFO;
             auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution");
             auto session_options = Ort::SessionOptions();
                
             if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) {
                 throw std::runtime_error("RegisterCustomOps failed");
             }
                
             // Step 1: Load model
                
             NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16"
                                                                 ofType:@"onnx"];
             if (model_path == nullptr) {
                 throw std::runtime_error("Failed to get model path");
             }
                
             // Step 2: Create Ort Inference Session
                
             auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options);
                
             // Read input image
                
             // note: need to set Xcode settings to prevent it from messing with PNG files:
             // in "Build Settings":
             // - set "Compress PNG Files" to "No"
             // - set "Remove Text Metadata From PNG Files" to "No"
             NSString *input_image_path =
             [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"];
             if (input_image_path == nullptr) {
                 throw std::runtime_error("Failed to get image path");
             }
                
             // Step 3: Prepare input tensors and input/output names
                
             NSMutableData *input_data =
             [NSMutableData dataWithContentsOfFile:input_image_path];
             const int64_t input_data_length = input_data.length;
             const auto memoryInfo =
             Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
                
             const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length,
                                                             &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
                
             constexpr auto input_names = std::array{"image"};
             constexpr auto output_names = std::array{"image_out"};
                
             // Step 4: Call inference session run
                
             const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(),
                                         &input_tensor, 1, output_names.data(), 1);
             if (outputs.size() != 1) {
                 throw std::runtime_error("Unexpected number of outputs");
             }
                
             // Step 5: Analyze model outputs
                
             const auto &output_tensor = outputs.front();
             const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo();
             const auto output_shape = output_type_and_shape_info.GetShape();
                
             if (const auto output_element_type =
                 output_type_and_shape_info.GetElementType();
                 output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
                 throw std::runtime_error("Unexpected output element type");
             }
                
             const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>();
                
             // Step 6: Convert raw bytes into NSData and return as displayable UIImage
                
             NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])];
             output_image = [UIImage imageWithData:output_data];
                
         } catch (std::exception &e) {
             NSLog(@"%s error: %s", __FUNCTION__, e.what());
                
             static NSString *const kErrorDomain = @"ORTSuperResolution";
             constexpr NSInteger kErrorCode = 0;
             if (error) {
                 NSString *description =
                 [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding];
                 *error =
                 [NSError errorWithDomain:kErrorDomain
                                     code:kErrorCode
                                 userInfo:@{NSLocalizedDescriptionKey : description}];
             }
             return nullptr;
         }
            
         if (error) {
             *error = nullptr;
         }
         return output_image;
     }
    
     @end
    

構建並執行應用

在 Xcode 中,選擇三角形構建圖示來構建並執行應用程式!

資源

原始 PyTorch 教程