在 Java 中使用 MNIST 進行字元識別

這是一個簡單的教程,介紹如何開始使用 ONNX 模型對給定輸入資料執行推理。模型通常使用任何知名的訓練框架進行訓練並匯出為 ONNX 格式。

請注意,下面提供的程式碼使用 Java 10 及更高版本可用的語法。Java 8 語法類似但更冗長。要開始評分會話,首先建立 OrtEnvironment,然後使用 OrtSession 類開啟一個會話,將模型的檔案路徑作為引數傳入。

    var env = OrtEnvironment.getEnvironment();
    var session = env.createSession("model.onnx",new OrtSession.SessionOptions());

建立會話後,您可以使用 OrtSession 物件的 run 方法執行查詢。目前我們支援 OnnxTensor 輸入,模型可以生成 OnnxTensorOnnxSequenceOnnxMap 輸出。後兩種情況在對 Scikit-learn 等框架生成的模型進行評分時更常見。run 呼叫需要一個 Map<String,OnnxTensor>,其中鍵與模型中儲存的輸入節點名稱匹配。這些可以透過在例項化會話上呼叫 session.getInputNames()session.getInputInfo() 來檢視。run 呼叫會生成一個 Result 物件,其中包含一個表示輸出的 Map<String,OnnxValue>Result 物件是 AutoCloseable 的,可以在 try-with-resources 語句中使用,以防止引用洩漏。一旦 Result 物件關閉,它的所有子 OnnxValue 也會關閉。

    OnnxTensor t1,t2;
    var inputs = Map.of("name1",t1,"name2",t2);
    try (var results = session.run(inputs)) {
    // manipulate the results
   }

您可以透過多種方式將輸入資料載入到 OnnxTensor 物件中。最有效的方法是使用 java.nio.Buffer,但也可以使用多維陣列。如果使用陣列構造,陣列不得是不規則的。

    FloatBuffer sourceData;  // assume your data is loaded into a FloatBuffer
    long[] dimensions;       // and the dimensions of the input are stored here
    var tensorFromBuffer = OnnxTensor.createTensor(env,sourceData,dimensions);

    float[][] sourceArray = new float[28][28];  // assume your data is loaded into a float array 
    var tensorFromArray = OnnxTensor.createTensor(env,sourceArray);

這是一個完整的示例程式,它在預訓練的 MNIST 模型上執行推理。

在 GPU 或使用其他提供者上執行(可選)

要啟用 GPU 等其他執行提供者,只需在建立 OrtSession 時開啟 SessionOptions 上的相應標誌即可。

    int gpuDeviceId = 0; // The GPU device ID to execute on
    var sessionOptions = new OrtSession.SessionOptions();
    sessionOptions.addCUDA(gpuDeviceId);
    var session = environment.createSession("model.onnx", sessionOptions);

執行提供者將按其啟用的順序優先選擇。