ORTTrainingSession

Objective-C

@interface ORTTrainingSession : NSObject

Swift

class ORTTrainingSession : NSObject

訓練器類,提供用於訓練、評估和最佳化 ONNX 模型的方法。

訓練會話需要四種訓練工件

  1. 訓練 ONNX 模型
  2. 評估 ONNX 模型(可選)
  3. 最佳化器 ONNX 模型
  4. 檢查點目錄

onnxruntime-training python 工具可用於生成上述訓練工件。

自 1.16 版本可用。

注意

此類別僅在啟用訓練 API 時可用。
  • 不可用

    宣告

    Objective-C

    - (instancetype)init NS_UNAVAILABLE;
  • 從訓練工件建立訓練會話,可用於開始或恢復訓練。

    此初始化器根據提供的環境和會話選項例項化訓練會話,可用於從給定的檢查點狀態開始或恢復訓練。檢查點狀態表示訓練會話的引數,如果需要,這些引數將被移動到會話選項中指定的裝置。

    注意

    請注意,使用檢查點狀態建立的訓練會話將此狀態用於儲存整個訓練狀態(包括模型引數、其梯度、最佳化器狀態和屬性)。訓練會話會持有檢查點狀態的強(擁有)指標。

    宣告

    Objective-C

    - (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env
                          sessionOptions:
                              (nullable ORTSessionOptions *)sessionOptions
                              checkpoint:(nonnull ORTCheckpoint *)checkpoint
                          trainModelPath:(nonnull NSString *)trainModelPath
                           evalModelPath:(nullable NSString *)evalModelPath
                      optimizerModelPath:(nullable NSString *)optimizerModelPath
                                   error:(NSError *_Nullable *_Nullable)error;

    Swift

    init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws

    引數

    env

    用於訓練會話的 ORTEnv 例項。

    sessionOptions

    用於訓練會話的可選 ORTSessionOptions

    checkpoint

    用作訓練起點的訓練狀態。

    trainModelPath

    訓練 ONNX 模型的路徑。

    evalModelPath

    評估 ONNX 模型的路徑。

    optimizerModelPath

    用於執行梯度下降的最佳化器 ONNX 模型的路徑。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    例項,如果發生錯誤則為 nil。

  • 執行一個訓練步驟,相當於一個步驟中的前向和後向傳播。

    訓練步驟計算訓練模型的輸出和給定輸入值下可訓練引數的梯度。訓練步驟是根據提供給訓練會話的訓練模型執行的。它等同於在一個步驟中執行前向和後向傳播。計算出的梯度儲存在訓練會話狀態中,以便後續可由 optimizerStep 消耗。可以透過呼叫 lazyResetGrad 方法延遲重置梯度。

    宣告

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        trainStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                           error:(NSError *_Nullable *_Nullable)error;

    Swift

    func trainStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    引數

    inputs

    訓練模型的輸入值。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    訓練模型的輸出值。

  • 執行一個評估步驟,計算給定輸入下評估模型的輸出。評估步驟是根據提供給訓練會話的評估模型執行的。

    宣告

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        evalStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func evalStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    引數

    inputs

    評估模型的輸入值。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    評估模型的輸出值。

  • 延遲將所有可訓練引數的梯度重置為零。

    呼叫此方法會設定訓練會話的內部狀態,以便在下次呼叫 trainStep 方法計算新梯度之前,將 ORTCheckpoint 中可訓練引數的梯度安排重置。

    宣告

    Objective-C

    - (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func lazyResetGrad() throws

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果梯度成功重置則為 YES,否則為 NO。

  • 使用最佳化器模型對可訓練引數執行權重更新。最佳化器步驟是根據提供給訓練會話的最佳化器模型執行的。更新後的引數儲存在訓練狀態中,以便下次呼叫 trainStep 方法時使用。

    宣告

    Objective-C

    - (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func optimizerStep() throws

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果最佳化器步驟成功執行則為 YES,否則為 NO。

  • 返回訓練模型的使用者輸入名稱,這些名稱可與提供給 trainStepORTValue 相關聯。

    宣告

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainInputNames() throws -> [String]

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    訓練模型的使用者輸入名稱。

  • 返回評估模型的使用者輸入名稱,這些名稱可與提供給 evalStepORTValue 相關聯。

    宣告

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalInputNames() throws -> [String]

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    評估模型的使用者輸入名稱。

  • 返回訓練模型的使用者輸出名稱,這些名稱可與 trainStep 返回的 ORTValue 相關聯。

    宣告

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainOutputNames() throws -> [String]

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    訓練模型的使用者輸出名稱。

  • 返回評估模型的使用者輸出名稱,這些名稱可與 evalStep 返回的 ORTValue 相關聯。

    宣告

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalOutputNames() throws -> [String]

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    評估模型的使用者輸出名稱。

  • 為訓練會話註冊一個線性學習率排程器。

    排程器在訓練過程中將學習率從初始值逐漸降低到零。降低是透過將當前學習率乘以一個線性更新因子來執行的。在降低之前,學習率在預熱階段從零逐漸增加到初始值。

    宣告

    Objective-C

    - (BOOL)
        registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError *_Nullable *_Nullable)
                                                         error;

    Swift

    func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws

    引數

    warmupStepCount

    執行線性預熱的步數。

    totalStepCount

    執行線性衰減的總步數。

    initialLr

    初始學習率。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果排程器成功註冊則為 YES,否則為 NO。

  • 根據已註冊的學習率排程器更新學習率。

    執行一個排程器步驟,更新訓練會話正在使用的學習率。此函式通常應在每輪呼叫最佳化器步驟之前呼叫,或根據需要更新訓練會話正在使用的學習率。

    注意

    必須首先註冊一個有效的預定義學習率排程器才能呼叫此方法。

    宣告

    Objective-C

    - (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func schedulerStep() throws

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果排程器步驟成功執行則為 YES,否則為 NO。

  • 返回訓練會話當前使用的學習率。

    宣告

    Objective-C

    - (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func getLearningRate() throws -> Float

    引數

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    當前學習率,如果發生錯誤則為 0.0f。

  • 設定訓練會話正在使用的學習率。

    當前學習率由訓練會話維護,並可透過呼叫此方法並傳入所需學習率來覆蓋。當註冊了有效的學習率排程器時,不應使用此函式。它應僅用於設定自定義學習率排程器派生的學習率,或設定在整個訓練會話中使用的恆定學習率。

    注意

    它不設定預定義學習率排程器可能需要的初始學習率。要為學習率排程器設定初始學習率,請使用 registerLinearLRScheduler 方法。

    宣告

    Objective-C

    - (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;

    Swift

    func setLearningRate(_ lr: Float) throws

    引數

    lr

    訓練會話將使用的學習率。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果學習率成功設定則為 YES,否則為 NO。

  • 從連續緩衝區載入訓練會話模型引數。

    宣告

    Objective-C

    - (BOOL)fromBufferWithValue:(nonnull ORTValue *)buffer
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func fromBuffer(with buffer: ORTValue) throws

    引數

    buffer

    用於載入引數的連續緩衝區。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果引數成功載入則為 YES,否則為 NO。

  • 返回一個包含所有訓練狀態引數副本的連續緩衝區。

    宣告

    Objective-C

    - (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable
                                           error:
                                               (NSError *_Nullable *_Nullable)error;

    Swift

    func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue

    引數

    onlyTrainable

    如果為 YES,則返回一個僅包含可訓練引數的緩衝區;否則,返回一個包含所有引數的緩衝區。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    包含所有訓練狀態引數副本的連續緩衝區。

  • 匯出可用於推理的訓練會話模型。

    如果訓練會話提供了評估模型,並且已知推理圖輸出,則訓練會話可以生成推理模型。輸入的推理圖輸出用於修剪評估模型,以便推理模型的輸出與提供的輸出對齊。匯出的模型儲存在提供的路徑中,並可與 ORTSession 一起用於推理。

    注意

    此方法從提供給初始化器的路徑重新載入評估模型,並要求此路徑有效。

    宣告

    Objective-C

    - (BOOL)
        exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath
                             graphOutputNames:
                                 (nonnull NSArray<NSString *> *)graphOutputNames
                                        error:(NSError *_Nullable *_Nullable)error;

    Swift

    func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws

    引數

    inferenceModelPath

    推理模型的序列化路徑。

    graphOutputNames

    推理模型中所需的輸出名稱。

    error

    如果發生錯誤,設定可選的錯誤資訊。

    返回值

    如果推理模型成功匯出則為 YES,否則為 NO。