使用 ORTModule 開始大型模型訓練
ONNX Runtime Training 的 ORTModule 為使用 PyTorch 前端定義的模型提供高效能訓練引擎。ORTModule 旨在加速大型模型的訓練,無需更改模型定義,只需對整個訓練指令碼進行一行程式碼更改(即 ORTModule 封裝)。
使用 ORTModule 類封裝器,ONNX Runtime 透過最佳化過的自動匯出的 ONNX 計算圖執行訓練指令碼的前向和後向傳播。
ORT 訓練示例
在此示例中,我們將介紹如何使用 ORT 訓練 PyTorch 模型。
# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure
注意: 這會安裝預設版本的 torch-ort 和 onnxruntime-training 包,這些包對映到特定版本的 CUDA 庫。請參閱 onnxruntime.ai 中的安裝選項。
- 在
train.py中新增 ORTModule
+ from torch_ort import ORTModule
.
.
.
- model = build_model() # Users PyTorch model
+ model = ORTModule(build_model())