使用機器學習超解析度在移動裝置上提高影像解析度
瞭解如何使用 ONNX Runtime Mobile 構建一個應用程式,透過包含預處理和後處理的模型來提高影像解析度。
您可以使用本教程為 Android 或 iOS 構建應用程式。
該應用程式接收影像輸入,在點選按鈕時執行超解析度操作,並在下方顯示解析度提高後的影像,如下圖所示。

目錄
準備模型
本教程中使用的機器學習模型基於本頁底部引用的 PyTorch 教程中使用的模型。
我們提供了一個方便的 Python 指令碼,用於將 PyTorch 模型匯出為 ONNX 格式並新增預處理和後處理。
-
在執行此指令碼之前,請安裝以下 Python 包
pip install torch pip install pillow pip install onnx pip install onnxruntime pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions版本說明:最佳的超解析度結果是在 ONNX opset 18 下實現的(它支援帶抗鋸齒的 Resize 運算子),該版本受 onnx 1.13.0 及更高版本和 onnxruntime 1.14.0 及更高版本支援。onnxruntime-extensions 包是預釋出版本。釋出版本將很快可用。
-
然後從 onnxruntime-extensions GitHub 儲存庫下載指令碼和測試影像(如果您尚未克隆此儲存庫)
curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png -
執行指令碼以匯出核心模型併為其新增預處理和後處理
python superresolution_e2e.py
指令碼執行後,您應該在執行指令碼的資料夾中看到兩個 ONNX 檔案
pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx
如果您將這兩個模型載入到 netron 中,您可以看到兩者之間輸入和輸出的區別。下面前兩張圖顯示了原始模型,其輸入是批次通道資料,後兩張圖顯示了輸入和輸出是影像位元組。




現在是編寫應用程式程式碼的時候了。
Android 應用
先決條件
- Android Studio Dolphin 2021.3.1 Patch + (安裝在 Mac/Windows/Linux 上)
- Android SDK 29+
- Android NDK r22+
- Android 裝置或 Android 模擬器
示例程式碼
您可以在 GitHub 上找到 Android 超解析度應用的完整原始碼。
要從原始碼執行應用程式,請克隆上述倉庫並將 build.gradle 檔案載入到 Android Studio 中,然後構建並執行!
要一步步構建應用程式,請按照以下章節操作。
從零開始編寫程式碼
設定專案
在 Android Studio 中為手機和平板電腦建立一個新專案,並選擇空白模板。將應用程式命名為 super_resolution 或類似名稱。
依賴項
將以下依賴項新增到應用的 build.gradle
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
專案資源
-
將模型檔案新增為原始資源
在
src/main/res資料夾中建立一個名為raw的資料夾,並將 ONNX 模型移動或複製到該 raw 資料夾中。 -
將測試影像新增為資產
在主專案資料夾中建立一個名為
assets的資料夾,並將您想要執行超解析度的影像複製到該資料夾中,檔名為test_superresolution.png
主應用程式類程式碼
建立一個名為 MainActivity.kt 的檔案,並向其中新增以下程式碼片段。
-
新增匯入語句
import ai.onnxruntime.* import ai.onnxruntime.extensions.OrtxPackage import android.annotation.SuppressLint import android.os.Bundle import android.widget.Button import android.widget.ImageView import android.widget.Toast import androidx.activity.* import androidx.appcompat.app.AppCompatActivity import kotlinx.android.synthetic.main.activity_main.* import kotlinx.coroutines.* import java.io.InputStream import java.util.* import java.util.concurrent.ExecutorService import java.util.concurrent.Executors -
建立主活動類並新增類變數
class MainActivity : AppCompatActivity() { private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment() private lateinit var ortSession: OrtSession private var inputImage: ImageView? = null private var outputImage: ImageView? = null private var superResolutionButton: Button? = null ... } -
新增
onCreate()方法在這裡我們初始化 ONNX Runtime 會話。一個會話儲存對應用程式中用於執行推理的模型引用。它還接受一個會話選項引數,您可以在其中指定不同的執行提供程式(如 NNAPI 等硬體加速器)。在這種情況下,我們預設在 CPU 上執行。但是,我們確實註冊了自定義操作庫,其中包含模型輸入和輸出處的影像編碼和解碼運算子。
override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) setContentView(R.layout.activity_main) inputImage = findViewById(R.id.imageView1) outputImage = findViewById(R.id.imageView2); superResolutionButton = findViewById(R.id.super_resolution_button) inputImage?.setImageBitmap( BitmapFactory.decodeStream(readInputImage()) ); // Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators. // Note: These are used to decode the input image into the format the original model requires, // and to encode the model output into png format val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions() sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath()) ortSession = ortEnv.createSession(readModel(), sessionOptions) superResolutionButton?.setOnClickListener { try { performSuperResolution(ortSession) Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT) .show() } catch (e: Exception) { Log.e(TAG, "Exception caught when perform super resolution", e) Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT) .show() } } } -
新增 onDestroy 方法
override fun onDestroy() { super.onDestroy() ortEnv.close() ortSession.close() } -
新增 updateUI 方法
private fun updateUI(result: Result) { outputImage?.setImageBitmap(result.outputBitmap) } -
新增 readModel 方法
此方法從資原始檔夾中讀取 ONNX 模型。
private fun readModel(): ByteArray { val modelID = R.pytorch_superresolution_with_pre_post_processing_op18 return resources.openRawResource(modelID).readBytes() } -
新增一個讀取輸入影像的方法
此方法從 assets 資料夾讀取測試影像。目前它讀取應用程式中內建的固定影像。示例將很快擴充套件,以直接從相機或相機膠捲讀取影像。
private fun readInputImage(): InputStream { return assets.open("test_superresolution.png") } -
新增執行推理的方法
此方法呼叫應用程式核心方法:
SuperResPerformer.upscale(),這是在模型上執行推理的方法。其程式碼在下一節中顯示。private fun performSuperResolution(ortSession: OrtSession) { var superResPerformer = SuperResPerformer() var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession) updateUI(result); } -
新增 TAG 物件
companion object { const val TAG = "ORTSuperResolution" }
模型推理類程式碼
建立一個名為 SuperResPerformer.kt 的檔案,並向其中新增以下程式碼片段。
-
新增匯入
import ai.onnxruntime.OnnxJavaType import ai.onnxruntime.OrtSession import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtEnvironment import android.graphics.Bitmap import android.graphics.BitmapFactory import java.io.InputStream import java.nio.ByteBuffer import java.util.* -
建立結果類
internal data class Result( var outputBitmap: Bitmap? = null ) {} -
建立超解析度執行器類
該類及其主要函式
upscale是大部分 ONNX Runtime 呼叫所在的地方。- OrtEnvironment 單例維護環境屬性和配置的日誌級別
- OnnxTensor.createTensor() 用於建立由輸入影像位元組組成的張量,適合作為模型輸入
- OnnxJavaType.UINT8 是輸入張量 ByteBuffer 的資料型別
- OrtSession.run() 在模型上執行推理(預測),以獲取輸出的放大影像
internal class SuperResPerformer( ) { fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result { var result = Result() // Step 1: convert image into byte array (raw image bytes) val rawImageBytes = inputStream.readBytes() // Step 2: get the shape of the byte array and make ort tensor val shape = longArrayOf(rawImageBytes.size.toLong()) val inputTensor = OnnxTensor.createTensor( ortEnv, ByteBuffer.wrap(rawImageBytes), shape, OnnxJavaType.UINT8 ) inputTensor.use { // Step 3: call ort inferenceSession run val output = ortSession.run(Collections.singletonMap("image", inputTensor)) // Step 4: output analysis output.use { val rawOutput = (output?.get(0)?.value) as ByteArray val outputImageBitmap = byteArrayToBitmap(rawOutput) // Step 5: set output result result.outputBitmap = outputImageBitmap } } return result }
構建並執行應用
在 Android Studio 中
- 選擇 Build -> Make Project
- Run -> app
應用程式在裝置模擬器中執行。連線到您的 Android 裝置以在裝置上執行應用程式。
iOS 應用
先決條件
- 安裝 Xcode 13.0 及更高版本(最好是最新版本)
- iOS 裝置或 iOS 模擬器
- Xcode 命令列工具
xcode-select --install - CocoaPods
sudo gem install cocoapods - 有效的 Apple 開發者 ID(如果您計劃在裝置上執行)
示例程式碼
您可以在 GitHub 上找到 iOS 超解析度應用的完整原始碼。
從原始碼執行應用程式
-
克隆 onnxruntime-inference-examples 倉庫
git clone https://github.com/microsoft/onnxruntime-inference-examples cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios -
安裝所需的 Pod 檔案
pod install -
在 Xcode 中開啟生成的
ORTSuperResolution.xcworkspace檔案(可選:僅當您在裝置上執行時才需要)選擇您的開發團隊
-
執行應用程式
連線您的 iOS 裝置或模擬器,構建並執行應用程式
點選
Perform Super Resolution按鈕檢視應用執行情況
要逐步開發應用程式,請按照以下章節操作。
從零開始編寫程式碼
建立專案
使用 APP 模板在 Xcode 中建立一個新專案
依賴項
安裝以下 Pod
# Pods for OrtSuperResolution
pod 'onnxruntime-c'
# Pre-release version pods
pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'
專案資源
-
將模型檔案新增到專案中
將本教程開頭生成的模型檔案複製到專案根資料夾中。
-
將測試影像新增為資產
將您想要執行超解析度的影像複製到專案根資料夾中。
主應用
開啟名為 ORTSuperResolutionApp.swift 的檔案並新增以下程式碼
import SwiftUI
@main
struct ORTSuperResolutionApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}
內容檢視
開啟名為 ContentView.swift 的檔案並新增以下程式碼
import SwiftUI
struct ContentView: View {
@State private var performSuperRes = false
func runOrtSuperResolution() -> UIImage? {
do {
let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
return outputImage
} catch let error as NSError {
print("Error: \(error.localizedDescription)")
return nil
}
}
var body: some View {
ScrollView {
VStack {
VStack {
Text("ORTSuperResolution").font(.title).bold()
.frame(width: 400, height: 80)
.border(Color.purple, width: 4)
.background(Color.purple)
Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
Image("cat_224x224").frame(width: 250, height: 250)
Button("Perform Super Resolution") {
performSuperRes.toggle()
}
if performSuperRes {
Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
if let outputImage = runOrtSuperResolution() {
Image(uiImage: outputImage)
} else {
Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
}
}
Spacer()
}
}
.padding()
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}
Swift / Objective C 橋接標頭檔案
建立一個名為 ORTSuperResolution-Bridging-Header.h 的檔案,並新增以下匯入語句
#import "ORTSuperResolutionPerformer.h"
超解析度程式碼
-
建立一個名為
ORTSuperResolutionPerformer.h的檔案,並新增以下程式碼#ifndef ORTSuperResolutionPerformer_h #define ORTSuperResolutionPerformer_h #import <Foundation/Foundation.h> #import <UIKit/UIKit.h> NS_ASSUME_NONNULL_BEGIN @interface ORTSuperResolutionPerformer : NSObject + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error; @end NS_ASSUME_NONNULL_END #endif -
建立一個名為
ORTSuperResolutionPerformer.mm的檔案,並新增以下程式碼#import "ORTSuperResolutionPerformer.h" #import <Foundation/Foundation.h> #import <UIKit/UIKit.h> #include <array> #include <cstdint> #include <stdexcept> #include <string> #include <vector> #include <onnxruntime_cxx_api.h> #include <onnxruntime_extensions.h> @implementation ORTSuperResolutionPerformer + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error { UIImage* output_image = nil; try { // Register custom ops const auto ort_log_level = ORT_LOGGING_LEVEL_INFO; auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution"); auto session_options = Ort::SessionOptions(); if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) { throw std::runtime_error("RegisterCustomOps failed"); } // Step 1: Load model NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16" ofType:@"onnx"]; if (model_path == nullptr) { throw std::runtime_error("Failed to get model path"); } // Step 2: Create Ort Inference Session auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options); // Read input image // note: need to set Xcode settings to prevent it from messing with PNG files: // in "Build Settings": // - set "Compress PNG Files" to "No" // - set "Remove Text Metadata From PNG Files" to "No" NSString *input_image_path = [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"]; if (input_image_path == nullptr) { throw std::runtime_error("Failed to get image path"); } // Step 3: Prepare input tensors and input/output names NSMutableData *input_data = [NSMutableData dataWithContentsOfFile:input_image_path]; const int64_t input_data_length = input_data.length; const auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length, &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); constexpr auto input_names = std::array{"image"}; constexpr auto output_names = std::array{"image_out"}; // Step 4: Call inference session run const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(), &input_tensor, 1, output_names.data(), 1); if (outputs.size() != 1) { throw std::runtime_error("Unexpected number of outputs"); } // Step 5: Analyze model outputs const auto &output_tensor = outputs.front(); const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo(); const auto output_shape = output_type_and_shape_info.GetShape(); if (const auto output_element_type = output_type_and_shape_info.GetElementType(); output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { throw std::runtime_error("Unexpected output element type"); } const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>(); // Step 6: Convert raw bytes into NSData and return as displayable UIImage NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])]; output_image = [UIImage imageWithData:output_data]; } catch (std::exception &e) { NSLog(@"%s error: %s", __FUNCTION__, e.what()); static NSString *const kErrorDomain = @"ORTSuperResolution"; constexpr NSInteger kErrorCode = 0; if (error) { NSString *description = [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding]; *error = [NSError errorWithDomain:kErrorDomain code:kErrorCode userInfo:@{NSLocalizedDescriptionKey : description}]; } return nullptr; } if (error) { *error = nullptr; } return output_image; } @end
構建並執行應用
在 Xcode 中,選擇三角形構建圖示來構建並執行應用程式!