概述#
onnxruntime-training 的 ORTModule 為使用 PyTorch 前端定義的模型提供了高效能訓練引擎。ORTModule 旨在加速大型模型的訓練,而無需更改模型定義或訓練程式碼。
ORTModule 的目標是為使用者 PyTorch 程式中的一個或多個 torch.nn.Module 物件提供即插即用的替代方案,並使用 ORT 執行這些模組的前向和後向傳遞。
因此,使用者將能夠使用 ORT 加速他們的訓練指令碼,而無需修改他們的訓練迴圈。
使用者將能夠使用標準的 PyTorch 除錯技術來解決收斂問題,例如,透過探測模型引數上計算出的梯度。
以下程式碼示例說明了如何在使用者的訓練指令碼中使用 ORTModule,在整個模型可以解除安裝到 ONNX Runtime 的簡單情況下。
from onnxruntime.training import ORTModule
# Original PyTorch model
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
...
def forward(self, x):
...
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model) # The only change to the original PyTorch script
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Training Loop is unchanged
for data, target in data_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()