ORTTrainingSession
Objective-C
@interface ORTTrainingSession : NSObject
Swift
class ORTTrainingSession : NSObject
訓練器類,提供用於訓練、評估和最佳化 ONNX 模型的方法。
訓練會話需要四種訓練工件
- 訓練 ONNX 模型
- 評估 ONNX 模型(可選)
- 最佳化器 ONNX 模型
- 檢查點目錄
onnxruntime-training python 工具可用於生成上述訓練工件。
自 1.16 版本可用。
注意
此類別僅在啟用訓練 API 時可用。-
不可用
宣告
Objective-C
- (instancetype)init NS_UNAVAILABLE; -
從訓練工件建立訓練會話,可用於開始或恢復訓練。
此初始化器根據提供的環境和會話選項例項化訓練會話,可用於從給定的檢查點狀態開始或恢復訓練。檢查點狀態表示訓練會話的引數,如果需要,這些引數將被移動到會話選項中指定的裝置。
注意
請注意,使用檢查點狀態建立的訓練會話將此狀態用於儲存整個訓練狀態(包括模型引數、其梯度、最佳化器狀態和屬性)。訓練會話會持有檢查點狀態的強(擁有)指標。
宣告
Objective-C
- (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env sessionOptions: (nullable ORTSessionOptions *)sessionOptions checkpoint:(nonnull ORTCheckpoint *)checkpoint trainModelPath:(nonnull NSString *)trainModelPath evalModelPath:(nullable NSString *)evalModelPath optimizerModelPath:(nullable NSString *)optimizerModelPath error:(NSError *_Nullable *_Nullable)error;Swift
init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws引數
env用於訓練會話的
ORTEnv例項。sessionOptions用於訓練會話的可選
ORTSessionOptions。checkpoint用作訓練起點的訓練狀態。
trainModelPath訓練 ONNX 模型的路徑。
evalModelPath評估 ONNX 模型的路徑。
optimizerModelPath用於執行梯度下降的最佳化器 ONNX 模型的路徑。
error如果發生錯誤,設定可選的錯誤資訊。
返回值
例項,如果發生錯誤則為 nil。
-
執行一個訓練步驟,相當於一個步驟中的前向和後向傳播。
訓練步驟計算訓練模型的輸出和給定輸入值下可訓練引數的梯度。訓練步驟是根據提供給訓練會話的訓練模型執行的。它等同於在一個步驟中執行前向和後向傳播。計算出的梯度儲存在訓練會話狀態中,以便後續可由
optimizerStep消耗。可以透過呼叫lazyResetGrad方法延遲重置梯度。宣告
引數
inputs訓練模型的輸入值。
error如果發生錯誤,設定可選的錯誤資訊。
返回值
訓練模型的輸出值。
-
延遲將所有可訓練引數的梯度重置為零。
呼叫此方法會設定訓練會話的內部狀態,以便在下次呼叫
trainStep方法計算新梯度之前,將 ORTCheckpoint 中可訓練引數的梯度安排重置。宣告
Objective-C
- (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;Swift
func lazyResetGrad() throws引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
如果梯度成功重置則為 YES,否則為 NO。
-
使用最佳化器模型對可訓練引數執行權重更新。最佳化器步驟是根據提供給訓練會話的最佳化器模型執行的。更新後的引數儲存在訓練狀態中,以便下次呼叫
trainStep方法時使用。宣告
Objective-C
- (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;Swift
func optimizerStep() throws引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
如果最佳化器步驟成功執行則為 YES,否則為 NO。
-
返回訓練模型的使用者輸入名稱,這些名稱可與提供給
trainStep的ORTValue相關聯。宣告
Objective-C
- (nullable NSArray<NSString *> *)getTrainInputNamesWithError: (NSError *_Nullable *_Nullable)error;Swift
func getTrainInputNames() throws -> [String]引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
訓練模型的使用者輸入名稱。
-
返回評估模型的使用者輸入名稱,這些名稱可與提供給
evalStep的ORTValue相關聯。宣告
Objective-C
- (nullable NSArray<NSString *> *)getEvalInputNamesWithError: (NSError *_Nullable *_Nullable)error;Swift
func getEvalInputNames() throws -> [String]引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
評估模型的使用者輸入名稱。
-
返回訓練模型的使用者輸出名稱,這些名稱可與
trainStep返回的ORTValue相關聯。宣告
Objective-C
- (nullable NSArray<NSString *> *)getTrainOutputNamesWithError: (NSError *_Nullable *_Nullable)error;Swift
func getTrainOutputNames() throws -> [String]引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
訓練模型的使用者輸出名稱。
-
返回評估模型的使用者輸出名稱,這些名稱可與
evalStep返回的ORTValue相關聯。宣告
Objective-C
- (nullable NSArray<NSString *> *)getEvalOutputNamesWithError: (NSError *_Nullable *_Nullable)error;Swift
func getEvalOutputNames() throws -> [String]引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
評估模型的使用者輸出名稱。
-
為訓練會話註冊一個線性學習率排程器。
排程器在訓練過程中將學習率從初始值逐漸降低到零。降低是透過將當前學習率乘以一個線性更新因子來執行的。在降低之前,學習率在預熱階段從零逐漸增加到初始值。
宣告
Objective-C
- (BOOL) registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount totalStepCount:(int64_t)totalStepCount initialLr:(float)initialLr error:(NSError *_Nullable *_Nullable) error;Swift
func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws引數
warmupStepCount執行線性預熱的步數。
totalStepCount執行線性衰減的總步數。
initialLr初始學習率。
error如果發生錯誤,設定可選的錯誤資訊。
返回值
如果排程器成功註冊則為 YES,否則為 NO。
-
根據已註冊的學習率排程器更新學習率。
執行一個排程器步驟,更新訓練會話正在使用的學習率。此函式通常應在每輪呼叫最佳化器步驟之前呼叫,或根據需要更新訓練會話正在使用的學習率。
注意
必須首先註冊一個有效的預定義學習率排程器才能呼叫此方法。
宣告
Objective-C
- (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;Swift
func schedulerStep() throws引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
如果排程器步驟成功執行則為 YES,否則為 NO。
-
返回訓練會話當前使用的學習率。
宣告
Objective-C
- (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;Swift
func getLearningRate() throws -> Float引數
error如果發生錯誤,設定可選的錯誤資訊。
返回值
當前學習率,如果發生錯誤則為 0.0f。
-
設定訓練會話正在使用的學習率。
當前學習率由訓練會話維護,並可透過呼叫此方法並傳入所需學習率來覆蓋。當註冊了有效的學習率排程器時,不應使用此函式。它應僅用於設定自定義學習率排程器派生的學習率,或設定在整個訓練會話中使用的恆定學習率。
注意
它不設定預定義學習率排程器可能需要的初始學習率。要為學習率排程器設定初始學習率,請使用
registerLinearLRScheduler方法。宣告
Objective-C
- (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;Swift
func setLearningRate(_ lr: Float) throws引數
lr訓練會話將使用的學習率。
error如果發生錯誤,設定可選的錯誤資訊。
返回值
如果學習率成功設定則為 YES,否則為 NO。
-
返回一個包含所有訓練狀態引數副本的連續緩衝區。
宣告
Objective-C
- (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable error: (NSError *_Nullable *_Nullable)error;Swift
func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue引數
onlyTrainable如果為 YES,則返回一個僅包含可訓練引數的緩衝區;否則,返回一個包含所有引數的緩衝區。
error如果發生錯誤,設定可選的錯誤資訊。
返回值
包含所有訓練狀態引數副本的連續緩衝區。
-
匯出可用於推理的訓練會話模型。
如果訓練會話提供了評估模型,並且已知推理圖輸出,則訓練會話可以生成推理模型。輸入的推理圖輸出用於修剪評估模型,以便推理模型的輸出與提供的輸出對齊。匯出的模型儲存在提供的路徑中,並可與
ORTSession一起用於推理。注意
此方法從提供給初始化器的路徑重新載入評估模型,並要求此路徑有效。
宣告
Objective-C
- (BOOL) exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath graphOutputNames: (nonnull NSArray<NSString *> *)graphOutputNames error:(NSError *_Nullable *_Nullable)error;Swift
func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws引數
inferenceModelPath推理模型的序列化路徑。
graphOutputNames推理模型中所需的輸出名稱。
error如果發生錯誤,設定可選的錯誤資訊。
返回值
如果推理模型成功匯出則為 YES,否則為 NO。
在 GitHub 上檢視
ORTTrainingSession 類參考