在裝置上訓練模型#
生成訓練工件後,可以使用 onnxruntime 訓練 Python API 在裝置上訓練模型。
預期的訓練工件包括
訓練 ONNX 模型
檢查點狀態
最佳化器 ONNX 模型
評估 ONNX 模型(可選)
示例用法
from onnxruntime.training.api import CheckpointState, Module, Optimizer
# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)
# Create the module
module = Module(path_to_the_training_model,
state,
path_to_the_eval_model,
device="cpu")
optimizer = Optimizer(path_to_the_optimizer_model, module)
# Training loop
for ...:
module.train()
training_loss = module(...)
optimizer.step()
module.lazy_reset_grad()
# Eval
module.eval()
eval_loss = module(...)
# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
- class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[source]#
基類:
object表示模型引數的類
此類表示模型引數,並提供對其資料、梯度及其他屬性的訪問。此類不應直接例項化。相反,它由 CheckpointState 物件返回。
- 引數:
parameter – 持有底層引數資料的 C.Parameter 物件。
state – 持有底層會話狀態的 C.CheckpointState 物件。
- class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[source]#
基類:
object包含所有模型引數的類
此類包含所有模型引數並提供對其的訪問。此類不應直接例項化。相反,它由 CheckpointState 的 parameters 屬性返回。此類的行為類似於字典,並按名稱提供對引數的訪問。
- 引數:
state – 持有底層會話狀態的 C.CheckpointState 物件。
- __getitem__(name: str) Parameter[source]#
獲取與給定名稱關聯的引數
在檢查點狀態的引數中搜索該名稱。
- 引數:
name – 引數的名稱
- 返回:
引數的值
- 丟擲:
KeyError – 如果未找到引數
- class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[source]#
基類:
object- __getitem__(name: str) int | float | str[source]#
獲取與給定名稱關聯的屬性
在檢查點狀態的屬性中搜索該名稱。
- 引數:
name – 屬性的名稱
- 返回:
屬性的值
- 丟擲:
KeyError – 如果未找到屬性
- class onnxruntime.training.api.CheckpointState(state: CheckpointState)[source]#
基類:
object包含訓練會話狀態的類
此類包含訓練會話的所有狀態資訊,例如模型引數、其梯度、最佳化器狀態和使用者定義的屬性。
要建立 CheckpointState,請使用 CheckpointState.load_checkpoint 方法。
- 引數:
state – 包含底層會話狀態的 C.Checkpoint state 物件。
- classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState[source]#
從檢查點檔案載入檢查點狀態
檢查點檔案可以是完整的檢查點或名義檢查點。
- 引數:
checkpoint_uri – 檢查點檔案的路徑。
- 返回:
檢查點狀態物件。
- 返回型別:
- classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None[source]#
將檢查點狀態儲存到檢查點檔案
- 引數:
state – 檢查點狀態物件。
checkpoint_uri – 檢查點檔案的路徑。
include_optimizer_state – 如果為 True,最佳化器狀態也將儲存到檢查點檔案。
- property parameters: Parameters#
從檢查點狀態返回模型引數
- property properties: Properties#
從檢查點狀態返回屬性
- class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[source]#
基類:
object提供 ONNX 模型訓練和評估方法的訓練器類。
在例項化 Module 類之前,應已使用 onnxruntime.training.artifacts.generate_artifacts 工具生成了訓練工件。
- 訓練工件包括
訓練模型
評估模型(可選)
最佳化器模型(可選)
檢查點檔案
- 引數:
train_model_uri – 訓練模型的路徑。
state – 檢查點狀態物件。
eval_model_uri – 評估模型的路徑。
device – 執行模型的裝置。預設為“cpu”。
session_options – 模型使用的會話選項。
- __call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue[source]#
呼叫模型的訓練或評估步驟。
- 引數:
*user_inputs – 模型的輸入。使用者輸入可以是 numpy 陣列或 OrtValue。
- 返回:
模型的輸出。
- train(mode: bool = True) Module[source]#
將模組設定為訓練模式。
- 引數:
mode – 是否將模型設定為訓練模式 (True) 或評估模式 (False)。預設值:True。
- 返回:
self
- get_contiguous_parameters(trainable_only: bool = False) OrtValue[source]#
建立訓練會話引數的連續緩衝區
- 引數:
trainable_only – 如果為 True,則只考慮可訓練引數。否則,考慮所有引數。
- 返回:
訓練會話引數的連續緩衝區。
- get_parameters_size(trainable_only: bool = True) int[source]#
返回引數的大小。
- 引數:
trainable_only – 如果為 True,則只考慮可訓練引數。否則,考慮所有引數。
- 返回:
引數中原始(例如浮點)元素的數量。
- copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None[source]#
將 OrtValue 緩衝區複製到訓練會話引數。
如果模組是從名義檢查點載入的,則需要呼叫此函式將更新的引數載入到檢查點以完成它。
- 引數:
buffer – 要複製到訓練會話引數的 OrtValue 緩衝區。
- class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#
基類:
object提供根據計算出的梯度更新模型引數方法的類。
- 引數:
optimizer_uri – 最佳化器模型的路徑。
model – 要訓練的模組。
- class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#
基類:
object線性更新最佳化器中的學習率
線性學習率排程器透過線性更新的乘法因子將訓練會話中設定的初始學習率衰減到 0。衰減在初始預熱階段之後執行,在該階段中學習率從 0 線性增加到提供的初始學習率。
- 引數:
optimizer – 使用者的 onnxruntime 訓練最佳化器
warmup_step_count – 預熱階段的步數。
total_step_count – 總訓練步數。
initial_lr – 初始學習率。