Java ORT 入門
ONNX Runtime 提供 Java 繫結,用於在 JVM 上執行 ONNX 模型推理。
目錄
支援的版本
Java 8 或更高版本
構建版本
釋出工件釋出到 Maven Central,可作為大多數 Java 構建工具的依賴項使用。這些工件支援一些流行平臺。
| 工件 | 描述 | 支援的平臺 |
|---|---|---|
| com.microsoft.onnxruntime:onnxruntime | CPU | Windows x64, Linux x64, macOS x64 |
| com.microsoft.onnxruntime:onnxruntime_gpu | GPU (CUDA) | Windows x64, Linux x64 |
有關本地構建的更多詳細資訊,請參閱Java API 開發文件。
有關共享庫載入機制的自定義,請參閱高階載入說明。
API 參考
Javadoc 可在此處獲取:here。
示例
示例實現位於 src/test/java/sample/ScoreMNIST.java。
編譯後,示例程式碼需要以下引數 ScoreMNIST [path-to-mnist-model] [path-to-mnist] [scikit-learn-flag]。MNIST 預計為 libsvm 格式。如果提供了可選的 scikit-learn 標誌,則模型應由 skl2onnx 生成(因此需要一個扁平的特徵向量,併產生結構化輸出);否則,模型應為 PyTorch 中的 CNN(需要 [1][1][28][28] 輸入,併產生機率向量)。在 testdata 中提供了兩個示例模型:cnn_mnist_pytorch.onnx 和 lr_mnist_scikit.onnx。第一個是使用 PyTorch 訓練的 LeNet5 風格的 CNN,第二個是使用 scikit-learn 訓練的邏輯迴歸。
單元測試包含載入模型、檢查輸入/輸出節點形狀和型別以及構造用於評分的張量的幾個示例。
入門
這是一個簡單的入門教程,用於在給定輸入資料的情況下,對現有 ONNX 模型執行推理。模型通常使用任何知名的訓練框架進行訓練並匯出為 ONNX 格式。
請注意,下面提供的程式碼使用 Java 10 及更高版本可用的語法。Java 8 語法類似但更冗長。
要開始評分會話,首先建立 OrtEnvironment,然後使用 OrtSession 類開啟一個會話,將模型的檔案路徑作為引數傳入。
var env = OrtEnvironment.getEnvironment();
var session = env.createSession("model.onnx",new OrtSession.SessionOptions());
一旦建立了會話,就可以使用 OrtSession 物件的 run 方法執行查詢。目前我們支援 OnnxTensor 輸入,模型可以生成 OnnxTensor、OnnxSequence 或 OnnxMap 輸出。後兩種情況在對 scikit-learn 等框架生成的模型進行評分時更常見。
run 呼叫需要一個 Map<String,OnnxTensor>,其中鍵與模型中儲存的輸入節點名稱匹配。這些可以透過在例項化會話上呼叫 session.getInputNames() 或 session.getInputInfo() 來檢視。run 呼叫會產生一個 Result 物件,其中包含一個表示輸出的 Map<String,OnnxValue>。 Result 物件是 AutoCloseable 的,可以在 try-with-resources 語句中使用,以防止引用洩漏。一旦 Result 物件關閉,其所有子 OnnxValue 也會關閉。
OnnxTensor t1,t2;
var inputs = Map.of("name1",t1,"name2",t2);
try (var results = session.run(inputs)) {
// manipulate the results
}
您可以透過多種方式將輸入資料載入到 OnnxTensor 物件中。最有效的方法是使用 java.nio.Buffer,但也可以使用多維陣列。如果使用陣列構造,陣列不得是不規則的。
FloatBuffer sourceData; // assume your data is loaded into a FloatBuffer
long[] dimensions; // and the dimensions of the input are stored here
var tensorFromBuffer = OnnxTensor.createTensor(env,sourceData,dimensions);
float[][] sourceArray = new float[28][28]; // assume your data is loaded into a float array
var tensorFromArray = OnnxTensor.createTensor(env,sourceArray);
這是一個完整的示例程式,它在一個預訓練的 MNIST 模型上執行推理。
在 GPU 或使用其他提供程式上執行 (可選)
要啟用 GPU 等其他執行提供程式,只需在建立 OrtSession 時在 SessionOptions 上開啟相應的標誌即可。
int gpuDeviceId = 0; // The GPU device ID to execute on
var sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(gpuDeviceId);
var session = environment.createSession("model.onnx", sessionOptions);
執行提供程式按其啟用順序確定優先順序。