使用自定義 ONNX 運算元匯出 PyTorch 模型
本文件解釋了使用自定義 ONNX Runtime 運算元匯出 PyTorch 模型的過程。目的是匯出包含 ONNX 不支援的運算元的 PyTorch 模型,並擴充套件 ONNX Runtime 以支援這些自定義運算元。
目錄
匯出內建 Contrib 運算元
“Contrib 運算元”指的是內置於大多數 ORT 軟體包中的一組自定義運算元。所有 Contrib 運算元的符號函式應在 pytorch_export_contrib_ops.py 中定義。
要使用這些 Contrib 運算元進行匯出,請在呼叫 torch.onnx.export() 之前呼叫 pytorch_export_contrib_ops.register()。例如:
from onnxruntime.tools import pytorch_export_contrib_ops
import torch
pytorch_export_contrib_ops.register()
torch.onnx.export(...)
匯出自定義運算元
要匯出非 Contrib 運算元或未包含在 pytorch_export_contrib_ops 中的自定義運算元,需要編寫並註冊自定義運算元的符號函式。
我們以 Inverse 運算元為例:
from torch.onnx import register_custom_op_symbolic
def my_inverse(g, self):
return g.op("com.microsoft::Inverse", self)
# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)
<namespace> 是 torch 運算元名稱的一部分。對於標準 torch 運算元,名稱空間可以省略。
com.microsoft 應作為 ONNX Runtime 運算元的自定義 opset 域使用。您可以在運算元註冊期間選擇自定義 opset 版本。
有關編寫符號函式的更多資訊,請參閱 torch.onnx 文件。
使用自定義運算元擴充套件 ONNX Runtime
下一步是在 ONNX Runtime 中新增運算元 Schema 和核心實現。詳見自定義運算元。
端到端測試:匯出和執行
一旦自定義運算元在匯出器中註冊並在 ONNX Runtime 中實現,您應該能夠匯出它並使用 ONNX Runtime 執行它。
下面是一個示例指令碼,用於匯出 Inverse 運算元並作為模型的一部分執行。
匯出的模型包括 ONNX 標準運算元和自定義運算元的組合。
此測試還比較了 PyTorch 模型與 ONNX Runtime 輸出的結果,以測試運算元匯出和實現。
import io
import numpy
import onnxruntime
import torch
class CustomInverse(torch.nn.Module):
def forward(self, x):
return torch.inverse(x) + x
x = torch.randn(3, 3)
# Export model to ONNX
f = io.BytesIO()
torch.onnx.export(CustomInverse(), (x,), f)
model = CustomInverse()
pt_outputs = model(x)
# Run the exported model with ONNX Runtime
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))
ort_outputs = ort_sess.run(None, ort_inputs)
# Validate PyTorch and ONNX Runtime results
numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
預設情況下,自定義 opset 的 opset 版本將設定為 1。如果您想將自定義運算元匯出到更高的 opset 版本,可以在呼叫匯出 API 時使用 custom_opsets 引數指定自定義 opset 域和版本。請注意,這與預設 ONNX 域關聯的 opset 版本不同。
torch.onnx.export(CustomInverse(), (x,), f, custom_opsets={"com.microsoft": 5})
請注意,您可以將自定義運算元匯出到任何大於等於註冊時使用的 opset 版本。