在裝置上訓練模型#

生成訓練工件後,可以使用 onnxruntime 訓練 Python API 在裝置上訓練模型。

預期的訓練工件包括

  1. 訓練 ONNX 模型

  2. 檢查點狀態

  3. 最佳化器 ONNX 模型

  4. 評估 ONNX 模型(可選)

示例用法

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[source]#

基類:object

表示模型引數的類

此類表示模型引數,並提供對其資料、梯度及其他屬性的訪問。此類不應直接例項化。相反,它由 CheckpointState 物件返回。

引數:
  • parameter – 持有底層引數資料的 C.Parameter 物件。

  • state – 持有底層會話狀態的 C.CheckpointState 物件。

property name: str#

引數的名稱

property data: ndarray#

引數的資料

property grad: ndarray#

引數的梯度

property requires_grad: bool#

引數是否需要計算其梯度

__repr__() str[source]#

返回引數的字串表示

class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[source]#

基類:object

包含所有模型引數的類

此類包含所有模型引數並提供對其的訪問。此類不應直接例項化。相反,它由 CheckpointState 的 parameters 屬性返回。此類的行為類似於字典,並按名稱提供對引數的訪問。

引數:

state – 持有底層會話狀態的 C.CheckpointState 物件。

__getitem__(name: str) Parameter[source]#

獲取與給定名稱關聯的引數

在檢查點狀態的引數中搜索該名稱。

引數:

name – 引數的名稱

返回:

引數的值

丟擲:

KeyError – 如果未找到引數

__setitem__(name: str, value: ndarray) None[source]#

設定給定名稱的引數值

在檢查點狀態的引數中搜索該名稱。如果找到該名稱,則更新其值。

引數:
  • name – 引數的名稱

  • value – 作為 numpy 陣列的引數值

丟擲:

KeyError – 如果未找到引數

__contains__(name: str) bool[source]#

檢查引數是否存在於狀態中

引數:

name – 引數的名稱

返回:

如果名稱是引數,則為 True,否則為 False

__iter__()[source]#

返回屬性的迭代器

__repr__() str[source]#

返回引數的字串表示

__len__() int[source]#

返回引數的數量

class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[source]#

基類:object

__getitem__(name: str) int | float | str[source]#

獲取與給定名稱關聯的屬性

在檢查點狀態的屬性中搜索該名稱。

引數:

name – 屬性的名稱

返回:

屬性的值

丟擲:

KeyError – 如果未找到屬性

__setitem__(name: str, value: int | float | str) None[source]#

設定給定名稱的屬性值

在檢查點狀態的屬性中搜索該名稱。該值將新增到屬性中或在屬性中更新。

引數:
  • name – 屬性的名稱

  • value – 屬性值。屬性僅支援 int、float 和 str 值。

__contains__(name: str) bool[source]#

檢查屬性是否存在於狀態中

引數:

name – 屬性的名稱

返回:

如果名稱是屬性,則為 True,否則為 False

__iter__()[source]#

返回屬性的迭代器

__repr__() str[source]#

返回屬性的字串表示

__len__() int[source]#

返回屬性的數量

class onnxruntime.training.api.CheckpointState(state: CheckpointState)[source]#

基類:object

包含訓練會話狀態的類

此類包含訓練會話的所有狀態資訊,例如模型引數、其梯度、最佳化器狀態和使用者定義的屬性。

要建立 CheckpointState,請使用 CheckpointState.load_checkpoint 方法。

引數:

state – 包含底層會話狀態的 C.Checkpoint state 物件。

classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState[source]#

從檢查點檔案載入檢查點狀態

檢查點檔案可以是完整的檢查點或名義檢查點。

引數:

checkpoint_uri – 檢查點檔案的路徑。

返回:

檢查點狀態物件。

返回型別:

CheckpointState

classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None[source]#

將檢查點狀態儲存到檢查點檔案

引數:
  • state – 檢查點狀態物件。

  • checkpoint_uri – 檢查點檔案的路徑。

  • include_optimizer_state – 如果為 True,最佳化器狀態也將儲存到檢查點檔案。

property parameters: Parameters#

從檢查點狀態返回模型引數

property properties: Properties#

從檢查點狀態返回屬性

class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[source]#

基類:object

提供 ONNX 模型訓練和評估方法的訓練器類。

在例項化 Module 類之前,應已使用 onnxruntime.training.artifacts.generate_artifacts 工具生成了訓練工件。

訓練工件包括
  • 訓練模型

  • 評估模型(可選)

  • 最佳化器模型(可選)

  • 檢查點檔案

training#

如果模型處於訓練模式則為 True,如果處於評估模式則為 False。

型別:

bool

引數:
  • train_model_uri – 訓練模型的路徑。

  • state – 檢查點狀態物件。

  • eval_model_uri – 評估模型的路徑。

  • device – 執行模型的裝置。預設為“cpu”。

  • session_options – 模型使用的會話選項。

__call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue[source]#

呼叫模型的訓練或評估步驟。

引數:

*user_inputs – 模型的輸入。使用者輸入可以是 numpy 陣列或 OrtValue。

返回:

模型的輸出。

train(mode: bool = True) Module[source]#

將模組設定為訓練模式。

引數:

mode – 是否將模型設定為訓練模式 (True) 或評估模式 (False)。預設值:True。

返回:

self

eval() Module[source]#

將模組設定為評估模式。

返回:

self

lazy_reset_grad()[source]#

惰性重置訓練梯度。

此函式設定模組的內部狀態,以便模組梯度將在下次呼叫 train() 計算新梯度之前被排程重置。

get_contiguous_parameters(trainable_only: bool = False) OrtValue[source]#

建立訓練會話引數的連續緩衝區

引數:

trainable_only – 如果為 True,則只考慮可訓練引數。否則,考慮所有引數。

返回:

訓練會話引數的連續緩衝區。

get_parameters_size(trainable_only: bool = True) int[source]#

返回引數的大小。

引數:

trainable_only – 如果為 True,則只考慮可訓練引數。否則,考慮所有引數。

返回:

引數中原始(例如浮點)元素的數量。

copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None[source]#

將 OrtValue 緩衝區複製到訓練會話引數。

如果模組是從名義檢查點載入的,則需要呼叫此函式將更新的引數載入到檢查點以完成它。

引數:

buffer – 要複製到訓練會話引數的 OrtValue 緩衝區。

export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) None[source]#

匯出模型用於推理。

訓練完成後,此函式可用於刪除 ONNX 模型中訓練特定的節點。具體來說,此函式執行以下操作

  • 解析訓練圖並識別生成給定輸出名稱的節點。

  • 刪除圖中所有後續節點,因為它們與推理圖無關。

引數:
  • inference_model_uri – 推理模型的路徑。

  • graph_output_names – 推理所需的輸出名稱列表。

input_names() list[str][source]#

返回訓練模型或評估模型的輸入名稱。

output_names() list[str][source]#

返回訓練模型或評估模型的輸出名稱。

class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#

基類:object

提供根據計算出的梯度更新模型引數方法的類。

引數:
  • optimizer_uri – 最佳化器模型的路徑。

  • model – 要訓練的模組。

step() None[source]#

根據計算出的梯度更新模型引數。

此方法透過在計算梯度的方向上邁出一步來更新模型引數。所使用的最佳化器取決於所提供的最佳化器模型。

set_learning_rate(learning_rate: float) None[source]#

設定最佳化器的學習率。

引數:

learning_rate – 要設定的學習率。

get_learning_rate() float[source]#

獲取最佳化器當前的學習率。

返回:

當前的學習率。

class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#

基類:object

線性更新最佳化器中的學習率

線性學習率排程器透過線性更新的乘法因子將訓練會話中設定的初始學習率衰減到 0。衰減在初始預熱階段之後執行,在該階段中學習率從 0 線性增加到提供的初始學習率。

引數:
  • optimizer – 使用者的 onnxruntime 訓練最佳化器

  • warmup_step_count – 預熱階段的步數。

  • total_step_count – 總訓練步數。

  • initial_lr – 初始學習率。

step() None[source]#

線性更新最佳化器的學習率。

在訓練的每一步都應呼叫此方法,以確保正確調整學習率。