裝置端訓練:構建 Android 應用程式

在本教程中,我們將探討如何構建一個整合 ONNX Runtime 裝置端訓練解決方案的 Android 應用程式。裝置端訓練是指直接在邊緣裝置上訓練機器學習模型的過程,不依賴雲服務或外部伺服器。

本教程結束時,應用程式將如下圖所示

an image classification app with Tom Cruise in the middle.

簡介

我們將指導您完成建立 Android 應用程式的步驟,該應用程式可以使用裝置端訓練技術訓練一個簡單的影像分類模型。本教程展示了遷移學習技術,即利用從一項任務中訓練模型獲得的知識來提高模型在不同但相關任務上的效能。遷移學習允許我們將預訓練模型學到的知識或特徵遷移到新任務,而不是從頭開始學習過程。

在本教程中,我們將利用 MobileNetV2 模型,該模型已在 ImageNet 等大型影像資料集(包含 1,000 個類別)上進行訓練。我們將使用此模型將自定義資料分類到四個類別中的一個。MobileNetV2 的初始層作為特徵提取器,捕獲適用於各種任務的通用視覺特徵,而只有最終的分類器層將針對手頭的任務進行訓練。

在本教程中,我們將使用資料來學習

  • 使用預打包的動物資料集將動物分類到四個類別中的一個。
  • 使用自定義名人資料集將名人分類到四個類別中的一個。

目錄

先決條件

要學習本教程,您應該對使用 Java 或 Kotlin 進行 Android 應用程式開發有基本的瞭解。熟悉 C++ 以及機器學習概念(如神經網路和影像分類)也將有所幫助。

  • 用於準備訓練工件的 Python 開發環境
  • Android Studio 4.1+
  • Android SDK 29+
  • Android NDK r21+
  • 一臺已啟用 開發者模式 並開啟 USB 除錯的 Android 裝置,帶攝像頭

注意 整個 Android 應用程式也已在 onnxruntime-training-examples GitHub 倉庫中提供。

離線階段 - 構建訓練工件

  1. 將模型匯出為 ONNX。

    我們從預訓練的 PyTorch 模型開始,並將其匯出為 ONNX。MobileNetV2 模型已在 Imagenet 資料集上進行了預訓練,該資料集包含 1000 個類別的資料。對於我們的影像分類任務,我們只想對 4 個類別中的影像進行分類。因此,我們將模型的最後一層更改為輸出 4 個 logits,而不是 1,000 個。

    有關如何將 PyTorch 模型匯出為 ONNX 的更多詳細資訊,請參見此處

    import torch
    import torchvision
    
    model = torchvision.models.mobilenet_v2(
       weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
    
    # The original model is trained on imagenet which has 1000 classes.
    # For our image classification scenario, we need to classify among 4 categories.
    # So we need to change the last layer of the model to have 4 outputs.
    model.classifier[1] = torch.nn.Linear(1280, 4)
    
    # Export the model to ONNX.
    model_name = "mobilenetv2"
    torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                      f"training_artifacts/{model_name}.onnx",
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. 定義可訓練和不可訓練的引數

    import onnx
    
    # Load the onnx model.
    onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
    
    # Define the parameters that require their gradients to be computed
    # (trainable parameters) and those that do not (frozen/non trainable parameters).
    requires_grad = ["classifier.1.weight", "classifier.1.bias"]
    frozen_params = [
       param.name
       for param in onnx_model.graph.initializer
       if param.name not in requires_grad
    ]
    
  3. 生成訓練工件。

    在本教程中,我們將使用 CrossEntropyLoss 損失函式和 AdamW 最佳化器。有關工件生成的更多詳細資訊,請參見此處

    from onnxruntime.training import artifacts
    
    # Generate the training artifacts.
    artifacts.generate_artifacts(
       onnx_model,
       requires_grad=requires_grad,
       frozen_params=frozen_params,
       loss=artifacts.LossType.CrossEntropyLoss,
       optimizer=artifacts.OptimType.AdamW,
       artifact_directory="training_artifacts"
    )
    

    就這樣!訓練工件已生成到 training_artifacts 資料夾中。這標誌著離線階段的結束。這些工件已準備好部署到 Android 裝置進行訓練。

訓練階段 - Android 應用程式開發

  1. 在 Android Studio 中設定專案

    a. 開啟 Android Studio,然後單擊 New Project Android Studio 設定 - 新建專案

    b. 單擊 Native C++ -> Next。填寫 New Project 詳細資訊如下

    • 名稱 - ORT Personalize
    • 包名 - com.example.ortpersonalize
    • 語言 - Kotlin

    單擊 Next

    Android Studio Setup - Project Name

    c. 選擇 C++17 工具鏈 -> Finish

    Android Studio Setup - Project C++ ToolChain

    d. 完成!Android Studio 專案已設定完畢。您現在應該能夠看到帶有樣板程式碼的 Android Studio 編輯器。

  2. 新增 ONNX Runtime 依賴項

    a. 在 Android Studio 專案的 cpp 目錄下建立兩個新資料夾:libinclude\onnxruntime

    lib and include folder

    b. 訪問 Maven Central。前往 Versions->Browse-> 並下載 onnxruntime-training-android 歸檔包 (aar 檔案)。

    c. 將 aar 副檔名重新命名為 zip。因此,onnxruntime-training-android-1.15.0.aar 變為 onnxruntime-training-android-1.15.0.zip

    d. 解壓 zip 檔案的內容。

    e. 將 jni\arm64-v8a 資料夾中的 libonnxruntime.so 共享庫複製到新建立的 lib 資料夾下的 Android 專案中。

    f. 將 headers 資料夾的內容複製到新建立的 include\onnxruntime 資料夾中。

    g. 在 native-lib.cpp 檔案中,包含訓練 cxx 標頭檔案。

    #include "onnxruntime_training_cxx_api.h"
    

    h. 將 abiFilters 新增到 build.gradle (Module) 檔案中,以選擇 arm64-v8a。此設定必須新增到 build.gradle 中的 defaultConfig 下。

    ndk {
       abiFilters 'arm64-v8a'
    }
    

    請注意,build.gradle 檔案的 defaultConfig 部分應如下所示

     defaultConfig {
        applicationId "com.example.ortpersonalize"
        minSdk 29
        targetSdk 33
        versionCode 1
        versionName "1.0"
    
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
           cmake {
                 cppFlags '-std=c++17'
           }
        }
    +   ndk {
    +       abiFilters 'arm64-v8a'
    +   }
       
     }
    

    i. 將 onnxruntime 共享庫新增到 CMakeLists.txt,以便 cmake 可以找到並針對該共享庫進行構建。為此,請在 CMakeLists.txt 中新增 ortpersonalize 庫之後新增以下行

    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    

    透過在上述兩行之後新增以下行,讓 CMake 知道 ONNX Runtime 標頭檔案在哪裡可以找到

    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    

    透過將 onnxruntime 庫新增到 target_link_libraries 來將 Android C++ 專案與 onnxruntime 庫連結

    target_link_libraries( # Specifies the target library.
         ortpersonalize
    
         # Links the target library to the log library
         # included in the NDK.
         ${log-lib}
    
         onnxruntime)
    

    請注意,CMakeLists.txt 檔案應如下所示

    project("ortpersonalize")
    
    add_library( # Sets the name of the library.
          ortpersonalize
    
          # Sets the library as a shared library.
          SHARED
    
          # Provides a relative path to your source file(s).
          native-lib.cpp
    +     utils.cpp
    +     inference.cpp
    +     train.cpp)
    + add_library(onnxruntime SHARED IMPORTED)
    + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    
    find_library( # Sets the name of the path variable.
          log-lib
    
          # Specifies the name of the NDK library that
          # you want CMake to locate.
          log)
    
    target_link_libraries( # Specifies the target library.
          ortpersonalize
    
          # Links the target library to the log library
          # included in the NDK.
          ${log-lib}
    +     onnxruntime)
    
    

    j. 構建應用程式並等待成功,以確認應用程式已包含 ONNX Runtime 標頭檔案並能成功連結到共享的 onnxruntime 庫。

  3. 打包預構建的訓練工件和資料集

    a. 在 Android Studio 專案的左側面板中,右鍵單擊 app -> New -> Folder -> Assets Folder,並在 main 下建立一個名為 assets 的新資料夾。

    b. 將步驟 2 中生成的訓練工件複製到此資料夾。

    c. 現在,前往 onnxruntime-training-examples 倉庫,將資料集 (images.zip) 下載到您的機器並解壓。此資料集是在 Kaggle 上由 Corrado Alessio 建立的原始 animals-10 資料集的基礎上修改的。

    d. 將下載的 images 資料夾複製到 Android Studio 的 assets/images 目錄中。

    專案的左側面板應如下所示

    Project Assets

  4. 與 ONNX Runtime 互動 - C++ 程式碼

    a. 我們將在 C++ 中實現以下四個函式,這些函式將從應用程式中呼叫

    • createSession:將在應用程式啟動時呼叫。它將建立新的 CheckpointStateTrainingSession 物件。
    • releaseSession:將在應用程式即將關閉時呼叫。此函式將釋放應用程式啟動時分配的資源。
    • performTraining:將在使用者單擊 UI 上的 Train 按鈕時呼叫。
    • performInference:將在使用者單擊 UI 上的 Infer 按鈕時呼叫。

    b. 建立會話

    此函式在應用程式啟動時呼叫。它將使用訓練工件資源建立 C++ CheckpointState 和 TrainingSession 物件。這些物件將用於在裝置上訓練模型。

    createSession 的引數是

    • checkpoint_path:檢查點工件的快取路徑。
    • train_model_path:訓練模型工件的快取路徑。
    • eval_model_path:評估模型工件的快取路徑。
    • optimizer_model_path:最佳化器模型工件的快取路徑。
    • cache_dir_path:Android 裝置上快取目錄的路徑。快取目錄用作從 C++ 程式碼訪問訓練工件的一種方式。

    該函式返回一個 long,它表示指向 session_cache 物件的指標。每當我們想訪問訓練會話時,這個 long 都可以強制轉換為 SessionCache

    extern "C" JNIEXPORT jlong JNICALL
    Java_com_example_ortpersonalize_MainActivity_createSession(
          JNIEnv *env, jobject /* this */,
          jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,
          jstring optimizer_model_path, jstring cache_dir_path)
    {
       std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(
                utils::JString2String(env, checkpoint_path),
                utils::JString2String(env, train_model_path),
                utils::JString2String(env, eval_model_path),
                utils::JString2String(env, optimizer_model_path),
                utils::JString2String(env, cache_dir_path));
       return reinterpret_cast<long>(session_cache.release());
    }
    

    從上面的函式體可以看出,這個函式建立了一個 SessionCache 類的物件的唯一指標。SessionCache 的定義如下。

    struct SessionCache {
       ArtifactPaths artifact_paths;
       Ort::Env ort_env;
       Ort::SessionOptions session_options;
       Ort::CheckpointState checkpoint_state;
       Ort::TrainingSession training_session;
       Ort::Session* inference_session;
    
       SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,
                   const std::string &eval_model_path, const std::string &optimizer_model_path,
                   const std::string& cache_dir_path) :
       artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),
       ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),
       checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),
       training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),
                         artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),
       inference_session(nullptr) {}
    };
    

    ArtifactPaths 的定義是

    struct ArtifactPaths {
       std::string checkpoint_path;
       std::string training_model_path;
       std::string eval_model_path;
       std::string optimizer_model_path;
       std::string cache_dir_path;
       std::string inference_model_path;
    
       ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,
                      const std::string &eval_model_path, const std::string &optimizer_model_path,
                      const std::string& cache_dir_path) :
       checkpoint_path(checkpoint_path), training_model_path(training_model_path),
       eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),
       cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}
    };
    

    c. 釋放會話

    此函式在應用程式即將關閉時呼叫。它釋放應用程式啟動時建立的資源,主要是 CheckpointState 和 TrainingSession。

    releaseSession 的引數是

    • session:代表 SessionCache 物件的 long 值。
    extern "C" JNIEXPORT void JNICALL
    Java_com_example_ortpersonalize_MainActivity_releaseSession(
          JNIEnv *env, jobject /* this */,
          jlong session) {
       auto *session_cache = reinterpret_cast<SessionCache *>(session);
       delete session_cache->inference_session;
       delete session_cache;
    }
    

    d. 執行訓練

    此函式針對每個需要訓練的批次呼叫。訓練迴圈用 Kotlin 編寫在應用程式端,並且在訓練迴圈中,每次批次都會呼叫 performTraining 函式。

    performTraining 的引數是

    • session:代表 SessionCache 物件的 long 值。
    • batch:作為浮點陣列的輸入影像,用於訓練。
    • labels:與提供用於訓練的輸入影像關聯的整數陣列形式的標籤。
    • batch_size:每個 TrainStep 要處理的影像數量。
    • channels:影像中的通道數。在我們的示例中,此值始終為 3
    • frame_rows:影像中的行數。在我們的示例中,此值始終為 224
    • frame_cols:影像中的列數。在我們的示例中,此值始終為 224

    該函式返回一個 float,表示該批次的訓練損失。

    extern "C"
    JNIEXPORT float JNICALL
    Java_com_example_ortpersonalize_MainActivity_performTraining(
          JNIEnv *env, jobject /* this */,
          jlong session, jfloatArray batch, jintArray labels, jint batch_size,
          jint channels, jint frame_rows, jint frame_cols) {
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
    
       if (session_cache->inference_session) {
          // Invalidate the inference session since we will be updating the model parameters
          // in train_step.
          // The next call to inference session will need to recreate the inference session.
          delete session_cache->inference_session;
          session_cache->inference_session = nullptr;
       }
    
       // Update the model parameters using this batch of inputs.
       return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),
                                  env->GetIntArrayElements(labels, nullptr), batch_size,
                                  channels, frame_rows, frame_cols);
    }
    

    上述函式利用了 train_step 函式。train_step 函式的定義如下

    namespace training {
    
       float train_step(SessionCache* session_cache, float *batches, int32_t *labels,
                         int64_t batch_size, int64_t image_channels, int64_t image_rows,
                         int64_t image_cols) {
          const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
          const std::vector<int64_t> labels_shape({batch_size});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> user_inputs; // {inputs, labels}
          // Inputs batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
          // Labels batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,
                                                             batch_size * sizeof(int32_t),
                                                             labels_shape.data(), labels_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
    
          // Run the train step and execute the forward + loss + backward.
          float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());
    
          // Update the model parameters by taking a step in the direction of the gradients computed above.
          session_cache->training_session.OptimizerStep();
    
          // Reset the gradients now that the parameters have been updated.
          // New set of gradients can then be computed for the next round of inputs.
          session_cache->training_session.LazyResetGrad();
    
          return loss;
       }
    
    } // namespace training
    

    e. 執行推理

    當用戶想要執行推理時,將呼叫此函式。

    performInference 的引數是

    • session:代表 SessionCache 物件的 long 值。
    • image_buffer:作為浮點陣列的輸入影像,用於訓練。
    • batch_size:每次推理要處理的影像數量。在我們的示例中,此值始終為 1
    • image_channels:影像中的通道數。在我們的示例中,此值始終為 3
    • image_rows:影像中的行數。在我們的示例中,此值始終為 224
    • image_cols:影像中的列數。在我們的示例中,此值始終為 224
    • classes:表示所有四個自定義類別的字串列表。

    該函式返回一個 string,表示所提供的四個自定義類別之一。這是模型的預測結果。

    extern "C"
    JNIEXPORT jstring JNICALL
    Java_com_example_ortpersonalize_MainActivity_performInference(
          JNIEnv *env, jobject  /* this */,
          jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,
          jint image_cols, jobjectArray classes) {
    
       std::vector<std::string> classes_str;
       for (int i = 0; i < env->GetArrayLength(classes); ++i) {
          // Access the current string element
          jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));
          classes_str.push_back(utils::JString2String(env, elem));
       }
    
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
       if (!session_cache->inference_session) {
          // The inference session does not exist, so create a new one.
          session_cache->training_session.ExportModelForInferencing(
                   session_cache->artifact_paths.inference_model_path.c_str(), {"output"});
          session_cache->inference_session = std::make_unique<Ort::Session>(
                   session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),
                   session_cache->session_options).release();
       }
    
       auto prediction = inference::classify(
                session_cache, env->GetFloatArrayElements(image_buffer, nullptr),
                batch_size, image_channels, image_rows, image_cols, classes_str);
    
       return env->NewStringUTF(prediction.first.c_str());
    }
    

    上述函式呼叫了 classify。classify 的定義是

    namespace inference {
    
       std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,
                                              int64_t batch_size, int64_t image_channels,
                                              int64_t image_rows, int64_t image_cols,
                                              const std::vector<std::string>& classes) {
          std::vector<const char *> input_names = {"input"};
          size_t input_count = 1;
    
          std::vector<const char *> output_names = {"output"};
          size_t output_count = 1;
    
          std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> input_values; // {input images}
          input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
    
          std::vector<Ort::Value> output_values;
          output_values.emplace_back(nullptr);
    
          // get the logits
          session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),
                                                 input_count, output_names.data(), output_values.data(), output_count);
    
          float *output = output_values.front().GetTensorMutableData<float>();
    
          // run softmax and get the probabilities of each class
          std::vector<float> probabilities = Softmax(output, classes.size());
          size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));
    
          return {classes[best_index], probabilities[best_index]};
       }
    
    } // namespace inference
    

    classify 函式呼叫另一個名為 Softmax 的函式。Softmax 的定義是

    std::vector<float> Softmax(float *logits, size_t num_logits) {
       std::vector<float> probabilities(num_logits, 0);
       float sum = 0;
       for (size_t i = 0; i < num_logits; ++i) {
             probabilities[i] = exp(logits[i]);
             sum += probabilities[i];
       }
    
       if (sum != 0.0f) {
             for (size_t i = 0; i < num_logits; ++i) {
                probabilities[i] /= sum;
             }
       }
    
       return probabilities;
    }
    
  5. 影像預處理

    a. MobileNetV2 模型要求輸入的影像是

    • 大小為 3 x 224 x 224
    • 一個標準化影像,減去平均值 (0.485, 0.456, 0.406) 併除以標準差 (0.229, 0.224, 0.225)

    此預處理在 Java/Kotlin 中使用 Android 提供的庫完成。

    讓我們在 app/src/main/java/com/example/ortpersonalize 目錄下建立一個名為 ImageProcessingUtil.kt 的新檔案。我們將在該檔案中新增用於裁剪、調整大小和規範化影像的實用方法。

    b. 裁剪和調整影像大小。

    fun processBitmap(bitmap: Bitmap) : Bitmap {
       // This function processes the given bitmap by
       //   - cropping along the longer dimension to get a square bitmap
       //     If the width is larger than the height
       //     ___+_________________+___
       //     |  +                 +  |
       //     |  +                 +  |
       //     |  +        +        +  |
       //     |  +                 +  |
       //     |__+_________________+__|
       //     <-------- width -------->
       //        <----- height ---->
       //     <-->      cropped    <-->
       //
       //     If the height is larger than the width
       //     _________________________   ʌ            ʌ
       //     |                       |   |         cropped
       //     |+++++++++++++++++++++++|   |      ʌ     v
       //     |                       |   |      |
       //     |                       |   |      |
       //     |           +           | height width
       //     |                       |   |      |
       //     |                       |   |      |
       //     |+++++++++++++++++++++++|   |      v     ʌ
       //     |                       |   |         cropped
       //     |_______________________|   v            v
       //
       //
       //
       //   - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the
       //     mobilenetv2 model.
       lateinit var bitmapCropped: Bitmap
       if (bitmap.getWidth() >= bitmap.getHeight()) {
          // Since height is smaller than the width, we crop a square whose length is the height
          // So cropping happens along the width dimesion
          val width: Int = bitmap.getHeight()
          val height: Int = bitmap.getHeight()
    
          // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)
          // so that the cropped width contains equal portion of the width on either side of center
          // top side of the cropped image must begin at 0 since we are not cropping along the height
          // dimension
          val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2
          val y: Int = 0
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       } else {
          // Since width is smaller than the height, we crop a square whose length is the width
          // So cropping happens along the height dimesion
          val width: Int = bitmap.getWidth()
          val height: Int = bitmap.getWidth()
    
          // left side of the cropped image must begin at 0 since we are not cropping along the width
          // dimension
          // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)
          // so that the cropped height contains equal portion of the height on either side of center
          val x: Int = 0
          val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       }
    
       // Resize the image to be channels x width x height as needed by the mobilenetv2 model
       val width: Int = 224
       val height: Int = 224
       val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)
    
       return bitmapResized
    }
    

    c. 規範化影像。

    fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {
       // This function iterates over the image and performs the following
       // on the image pixels
       //   - normalizes the pixel values to be between 0 and 1
       //   - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration)
       //     from the pixel values
       //   - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the
       //     mobilenetv2 model configuration)
       // Values are written to the given buffer starting at the provided offset.
       // Values are written as follows
       // |____|____________________|__________________| <--- buffer
       //      ʌ                                         <--- offset
       //                           ʌ                    <--- offset + width * height * channels
       // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order
       // |____|______|gggggg|______|__________________| <--- green channel read in column major order
       // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order
    
       val width: Int = bitmap.getWidth()
       val height: Int = bitmap.getHeight()
       val stride: Int = width * height
    
       for (x in 0 until width) {
          for (y in 0 until height) {
                val color: Int = bitmap.getPixel(y, x)
                val index = offset + (x * height + y)
    
                // Subtract the mean and divide by the standard deviation
                // Values for mean and standard deviation used for
                // the movilenetv2 model.
                buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)
                buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)
                buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)
          }
       }
    }
    

    d. 從 Uri 獲取 Bitmap

    fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {
       // This function reads the image file at the given uri and decodes it to a bitmap
       val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)
       return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)
    }
    
  6. 應用程式前端

    a. 對於本教程,我們將使用以下使用者介面元素

    • 訓練和推理按鈕
    • 類別按鈕
    • 狀態訊息文字
    • 影像顯示
    • 進度對話方塊

    b. 本教程不打算展示如何建立圖形使用者介面。因此,我們只需重複使用 GitHub 上提供的檔案。

    c. 將 strings.xml 中的所有字串定義複製到您的 Android Studio 本地的 strings.xml 檔案中。

    d. 將 activity_main.xml 的內容複製到您的 Android Studio 本地的 activity_main.xml 檔案中。

    e. 在 layout 資料夾下建立一個名為 dialog.xml 的新檔案。將 dialog.xml 的內容複製到您 Android Studio 本地新建的 dialog.xml 中。

    f. 本節中剩餘的更改需要在 MainActivity.kt 檔案中進行。

    g. 啟動應用程式

    當應用程式啟動時,將呼叫 onCreate 函式。此函式負責設定會話快取和使用者介面處理程式。

    有關程式碼,請參閱 MainActivity.kt 檔案中的 onCreate 函式。

    h. 自定義類別按鈕處理程式 - 我們希望使用類別按鈕讓使用者選擇他們的自定義影像進行訓練。我們需要為這些按鈕新增監聽器以實現此目的。這些監聽器將精確地完成這項工作。

    請參考 MainActivity.kt 中的這些按鈕處理程式

    • onClassAClickedListener
    • onClassBClickedListener
    • onClassXClickedListener
    • onClassYClickedListener

    i. 個性化自定義類別標籤

    預設情況下,自定義類別標籤是 [A, B, X, Y]。但是,讓我們允許使用者為清晰起見重新命名這些標籤。這透過長按監聽器實現(在 MainActivity.kt 中定義)

    • onClassALongClickedListener
    • onClassBLongClickedListener
    • onClassXLongClickedListener
    • onClassYLongClickedListener

    j. 切換自定義類別。

    當自定義類別開關關閉時,執行預打包的動物資料集。當開關開啟時,使用者需要提供自己的資料集進行訓練。為了處理這種轉換,onCustomClassSettingChangedListener 開關處理程式在 MainActivity.kt 中實現。

    k. 訓練處理程式

    當每個類別至少有 1 張影像時,可以啟用 Train 按鈕。單擊 Train 按鈕時,將針對選定的影像啟動訓練。訓練處理程式負責

    • 將訓練影像收集到一個容器中。
    • 打亂影像的順序。
    • 裁剪和調整影像大小。
    • 歸一化影像。
    • 對影像進行批處理。
    • 執行訓練迴圈(迴圈呼叫 C++ performTraining 函式)。

    MainActivity.kt 中定義的 onTrainButtonClickedListener 函式完成上述操作。

    l. 推理處理程式

    訓練完成後,使用者可以點選 Infer 按鈕來推理任何影像。推理處理程式負責

    • 收集推理影像。
    • 裁剪和調整影像大小。
    • 歸一化影像。
    • 呼叫 C++ performInference 函式。
    • 向用戶介面報告推理輸出。

    這透過 MainActivity.kt 中的 onInferenceButtonClickedListener 函式實現。

    m. 上述所有活動的處理程式

    一旦為推理或自定義類別選擇了影像,就需要對其進行處理。MainActivity.kt 中定義的 onActivityResult 函式執行此操作。

    n. 最後一件事。在 AndroidManifest.xml 檔案中新增以下內容以使用攝像頭

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />
    

訓練階段 - 在裝置上執行應用程式

  1. 在裝置上執行應用程式

    a. 讓我們將 Android 裝置連線到機器並在裝置上執行應用程式。

    b. 在裝置上啟動應用程式應該看起來像這樣

    Barebones ORT Personalize app

  2. 使用預載入資料集訓練 - 動物

    a. 讓我們透過在裝置上啟動應用程式,開始使用預載入的動物資料集進行訓練。

    b. 切換底部的 Custom classes 開關。

    c. 類別標籤將更改為 DogCatElephantCow

    d. 執行 Training 並等待進度對話方塊消失(訓練完成後)。

    e. 現在使用您相簿中的任何動物影像進行推理。

    ORT Personalize app with an image of a cow

    從上圖可以看出,模型正確預測為 Cow

  3. 使用自定義資料集訓練 - 名人

    a. 從網上下載湯姆·克魯斯 (Tom Cruise)、萊昂納多·迪卡普里奧 (Leonardo DiCaprio)、瑞安·雷諾茲 (Ryan Reynolds) 和布拉德·皮特 (Brad Pitt) 的影像。

    b. 確保透過關閉並重新啟動應用程式來啟動一個新的應用程式會話。

    c. 應用程式啟動後,透過長按將四個類別分別重新命名為 TomLeoRyanBrad

    d. 點選每個類別的按鈕,並選擇與該名人相關的影像。每個類別我們可以使用大約 10~15 張影像。

    e. 點選 Train 按鈕,讓應用程式從提供的資料中學習。

    f. 訓練完成後,我們可以點選 Infer 按鈕並提供應用程式尚未見過的影像。

    g. 就這樣!希望應用程式正確分類了影像。

    an image classification app with Tom Cruise in the middle.

結論

恭喜!您已成功構建了一個 Android 應用程式,該應用程式使用 ONNX Runtime 在裝置上學習影像分類。該應用程式也已在 GitHub 上提供,地址為 onnxruntime-training-examples