類 OrtTrainingSession

  • 所有已實現的介面
    java.lang.AutoCloseable

    public final class OrtTrainingSession
    extends java.lang.Object
    implements java.lang.AutoCloseable
    封裝 ONNX 訓練模型並允許訓練和推理呼叫。

    允許檢查模型的輸入和輸出節點。由 OrtEnvironment 生成。

    如果會話已關閉並呼叫了方法,大多數例項方法會丟擲 IllegalStateException

    • 方法詳情

      • getTrainInputNames

        public java.util.Set<java.lang.String> getTrainInputNames()
        返回訓練模型的輸入名稱的有序集合。
        返回值
        訓練輸入。
      • getTrainOutputNames

        public java.util.Set<java.lang.String> getTrainOutputNames()
        返回訓練模型的輸出名稱的有序集合。
        返回值
        訓練輸出。
      • getEvalInputNames

        public java.util.Set<java.lang.String> getEvalInputNames()
        返回評估模型的輸入名稱的有序集合。
        返回值
        評估輸入。
      • getEvalOutputNames

        public java.util.Set<java.lang.String> getEvalOutputNames()
        返回評估模型的輸出名稱的有序集合。
        返回值
        評估輸出。
      • addProperty

        public void addProperty​(java.lang.String name,
                                float value)
                         throws OrtException
        向此訓練會話檢查點新增一個浮點屬性。
        引數
        name - 屬性名稱。
        value - 屬性值。
        丟擲
        OrtException - 如果呼叫失敗。
      • addProperty

        public void addProperty​(java.lang.String name,
                                int value)
                         throws OrtException
        向此訓練會話檢查點新增一個整型屬性。
        引數
        name - 屬性名稱。
        value - 屬性值。
        丟擲
        OrtException - 如果呼叫失敗。
      • addProperty

        public void addProperty​(java.lang.String name,
                                java.lang.String value)
                         throws OrtException
        向此訓練會話檢查點新增一個字串屬性。
        引數
        name - 屬性名稱。
        value - 屬性值。
        丟擲
        OrtException - 如果呼叫失敗。
      • getFloatProperty

        public float getFloatProperty​(java.lang.String name)
                               throws OrtException
        從此訓練會話檢查點獲取一個浮點屬性。
        引數
        name - 屬性名稱。
        返回值
        屬性值。
        丟擲
        OrtException - 如果屬性不存在或型別錯誤。
      • getIntProperty

        public int getIntProperty​(java.lang.String name)
                           throws OrtException
        從此訓練會話檢查點獲取一個整型屬性。
        引數
        name - 屬性名稱。
        返回值
        屬性值。
        丟擲
        OrtException - 如果屬性不存在或型別錯誤。
      • getStringProperty

        public java.lang.String getStringProperty​(java.lang.String name)
                                           throws OrtException
        從此訓練會話檢查點獲取一個字串屬性。
        引數
        name - 屬性名稱。
        返回值
        屬性值。
        丟擲
        OrtException - 如果屬性不存在或型別錯誤。
      • close

        public void close()
        指定者
        close 在介面 java.lang.AutoCloseable 中
      • saveCheckpoint

        public void saveCheckpoint​(java.nio.file.Path outputPath,
                                   boolean saveOptimizer)
                            throws OrtException
        將訓練會話狀態儲存到提供的檢查點目錄中。
        引數
        outputPath - 檢查點目錄的路徑。
        saveOptimizer - 是否應儲存最佳化器狀態。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • setSeed

        public static void setSeed​(long seed)
                            throws OrtException
        設定 ONNX Runtime 使用的 RNG 種子。

        注意,此設定在所有 OrtTrainingSession 例項中是全域性的。

        引數
        seed - RNG 種子。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
                                    throws OrtException
        執行單個訓練步驟,累積梯度。
        引數
        inputs - 輸入(必須同時包含特徵和目標)。
        返回值
        訓練步驟產生的所有輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        執行單個訓練步驟,累積梯度。
        引數
        inputs - 輸入(必須同時包含特徵和目標)。
        runOptions - 控制此特定呼叫的執行選項。
        返回值
        訓練步驟產生的所有輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs)
                                    throws OrtException
        執行單個訓練步驟,累積梯度。
        引數
        inputs - 輸入(必須同時包含特徵和目標)。
        requestedOutputs - 請求的輸出。
        返回值
        訓練步驟產生的請求輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
                                    throws OrtException
        執行單個訓練步驟,累積梯度。

        輸出根據提供的對映遍歷順序排序。

        注意:固定的輸出不屬於 OrtSession.Result 物件,並且在結果物件關閉時不會關閉。

        引數
        inputs - 輸入(必須同時包含特徵和目標)。
        pinnedOutputs - 使用者已分配的請求輸出。
        返回值
        訓練步驟產生的請求輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs,
                                           java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        執行單個訓練步驟,累積梯度。

        輸出根據提供的集合遍歷順序排序,固定的輸出在前,然後是請求的輸出。如果請求的輸出和固定的輸出中出現相同的輸出名稱,則丟擲 IllegalArgumentException

        注意:固定的輸出不屬於 OrtSession.Result 物件,並且在結果物件關閉時不會關閉。

        引數
        inputs - 輸入(必須同時包含特徵和目標)。
        requestedOutputs - ORT 將分配的請求輸出。
        pinnedOutputs - 使用者已分配的請求輸出。
        runOptions - 控制此特定呼叫的執行選項。
        返回值
        訓練步驟產生的請求輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
                                   throws OrtException
        使用提供的輸入執行單個評估步驟。
        引數
        inputs - 模型輸入。
        返回值
        所有模型輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        使用提供的輸入執行單個評估步驟。
        引數
        inputs - 模型輸入。
        runOptions - 控制此特定呼叫的執行選項。
        返回值
        所有模型輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs)
                                   throws OrtException
        使用提供的輸入執行單個評估步驟。
        引數
        inputs - 模型輸入。
        requestedOutputs - 請求的輸出名稱。
        返回值
        請求的輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
                                   throws OrtException
        使用提供的輸入執行單個評估步驟。

        輸出根據提供的對映遍歷順序排序。

        注意:固定的輸出不屬於 OrtSession.Result 物件,並且在結果物件關閉時不會關閉。

        引數
        inputs - 用於評分的輸入。
        pinnedOutputs - 使用者已分配的請求輸出。
        返回值
        請求的輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs,
                                          java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        使用提供的輸入執行單個評估步驟。

        輸出根據提供的集合遍歷順序排序,固定的輸出在前,然後是請求的輸出。如果請求的輸出和固定的輸出中出現相同的輸出名稱,則丟擲 IllegalArgumentException

        注意:固定的輸出不屬於 OrtSession.Result 物件,並且在結果物件關閉時不會關閉。

        引數
        inputs - 用於評分的輸入。
        requestedOutputs - ORT 將分配的請求輸出。
        pinnedOutputs - 使用者已分配的請求輸出。
        runOptions - 控制此特定呼叫的執行選項。
        返回值
        請求的輸出。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • setLearningRate

        public void setLearningRate​(float learningRate)
                             throws OrtException
        設定訓練會話的學習率。

        僅當會話中沒有學習率排程器時才應使用。不用於設定學習率排程器的初始學習率。

        引數
        learningRate - 學習率。
        丟擲
        OrtException - 如果呼叫失敗。
      • getLearningRate

        public float getLearningRate()
                              throws OrtException
        獲取此訓練會話的當前學習率。
        返回值
        當前學習率。
        丟擲
        OrtException - 如果呼叫失敗。
      • optimizerStep

        public void optimizerStep()
                           throws OrtException
        使用最佳化器模型將梯度更新應用於可訓練引數。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • optimizerStep

        public void optimizerStep​(OrtSession.RunOptions runOptions)
                           throws OrtException
        使用最佳化器模型將梯度更新應用於可訓練引數。

        執行選項可用於控制日誌記錄和提前終止呼叫。

        引數
        runOptions - 控制模型執行的選項。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • registerLinearLRScheduler

        public void registerLinearLRScheduler​(long warmupSteps,
                                              long totalSteps,
                                              float initialLearningRate)
                                       throws OrtException
        註冊一個帶有線性預熱的線性學習率排程器。
        引數
        warmupSteps - 將學習率從零增加到 initialLearningRate 所需的步數。
        totalSteps - 此排程器操作的總步數。
        initialLearningRate - 最大學習率。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • schedulerStep

        public void schedulerStep()
                           throws OrtException
        根據註冊的學習率排程器更新學習率。
        丟擲
        OrtException - 如果原生呼叫失敗。
      • exportModelForInference

        public void exportModelForInference​(java.nio.file.Path outputPath,
                                            java.lang.String[] outputNames)
                                     throws OrtException
        將評估模型匯出為適用於推理的模型,並將所需節點設定為輸出節點。

        注意,此方法從提供給訓練會話的路徑重新載入評估模型,並且此路徑必須仍然有效。

        引數
        outputPath - 寫入推理模型的路徑。
        outputNames - 輸出節點的名稱。
        丟擲
        OrtException - 如果原生呼叫失敗。