使用 ORTModule 開始大型模型訓練

ONNX Runtime TrainingORTModule 為使用 PyTorch 前端定義的模型提供高效能訓練引擎。ORTModule 旨在加速大型模型的訓練,無需更改模型定義,只需對整個訓練指令碼進行一行程式碼更改(即 ORTModule 封裝)。

使用 ORTModule 類封裝器,ONNX Runtime 透過最佳化過的自動匯出的 ONNX 計算圖執行訓練指令碼的前向和後向傳播。

ORT 訓練示例

在此示例中,我們將介紹如何使用 ORT 訓練 PyTorch 模型。

# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure

注意: 這會安裝預設版本的 torch-ortonnxruntime-training 包,這些包對映到特定版本的 CUDA 庫。請參閱 onnxruntime.ai 中的安裝選項。

  • train.py 中新增 ORTModule
+  from torch_ort import ORTModule
   .
   .
   .
-  model = build_model() # Users PyTorch model
+  model = ORTModule(build_model())

示例

ONNX Runtime 訓練示例