開始將 TensorFlow 轉換為 ONNX
TensorFlow 模型(包括 Keras 和 TFLite 模型)可以使用 tf2onnx 工具轉換為 ONNX。
本教程的完整程式碼可在此處檢視。
安裝
首先,在已安裝 TensorFlow 的 Python 環境中安裝 tf2onnx。
pip install tf2onnx (穩定版)
或者
pip install git+https://github.com/onnx/tensorflow-onnx (GitHub 最新版)
轉換模型
Keras 模型和 tf 函式
Keras 模型和 tf 函式可以直接在 Python 中進行轉換
import tensorflow as tf
import tf2onnx
import onnx
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, activation="relu"))
input_signature = [tf.TensorSpec([3, 3], tf.float32, name='x')]
# Use from_function for tf functions
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)
onnx.save(onnx_model, "dst/path/model.onnx")
有關完整文件,請參閱 Python API 參考。
SavedModel
使用以下命令轉換 TensorFlow 儲存的模型:
python -m tf2onnx.convert --saved-model path/to/savedmodel --output dst/path/model.onnx --opset 13
path/to/savedmodel 應該是 **包含** saved_model.pb 的 **目錄路徑**
有關完整文件,請參閱 CLI 參考。
TFLite
tf2onnx 支援轉換 tflite 模型。
python -m tf2onnx.convert --tflite path/to/model.tflite --output dst/path/model.onnx --opset 13
注意:運算元集版本號
如果使用的 ONNX 運算元集版本過低,某些 TensorFlow 運算元將無法轉換。**請使用與您的應用程式相容的最新運算元集版本。** 有關完整的轉換說明,請參閱 tf2onnx README。
驗證轉換後的模型
使用以下命令安裝 onnxruntime:
pip install onnxruntime
使用以下模板在 Python 中測試您的模型
import onnxruntime as ort
import numpy as np
# Change shapes and types to match model
input1 = np.zeros((1, 100, 100, 3), np.float32)
# Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
# other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
# based on the build flags) when instantiating InferenceSession.
# Following code assumes NVIDIA GPU is available, you can specify other execution providers or don't include providers parameter
# to use default CPU provider.
sess = ort.InferenceSession("dst/path/model.onnx", providers=["CUDAExecutionProvider"])
# Set first argument of sess.run to None to use all model outputs in default order
# Input/output names are printed by the CLI and can be set with --rename-inputs and --rename-outputs
# If using the python API, names are determined from function arg names or TensorSpec names.
results_ort = sess.run(["output1", "output2"], {"input1": input1})
import tensorflow as tf
model = tf.saved_model.load("path/to/savedmodel")
results_tf = model(input1)
for ort_res, tf_res in zip(results_ort, results_tf):
np.testing.assert_allclose(ort_res, tf_res, rtol=1e-5, atol=1e-5)
print("Results match")
轉換失敗
如果您的模型轉換失敗,請閱讀我們的 README 和 故障排除指南。如果仍然失敗,請隨時在 GitHub 上提出問題。歡迎為 tf2onnx 做出貢獻!