類 OrtTrainingSession
- java.lang.Object
-
- ai.onnxruntime.OrtTrainingSession
-
- 所有已實現的介面
java.lang.AutoCloseable
public final class OrtTrainingSession extends java.lang.Object implements java.lang.AutoCloseable封裝 ONNX 訓練模型並允許訓練和推理呼叫。允許檢查模型的輸入和輸出節點。由
OrtEnvironment生成。如果會話已關閉並呼叫了方法,大多數例項方法會丟擲
IllegalStateException。
-
-
方法摘要
所有方法 靜態方法 例項方法 具體方法 修飾符和型別 方法 描述 voidaddProperty(java.lang.String name, float value)向此訓練會話檢查點新增一個浮點屬性。voidaddProperty(java.lang.String name, int value)向此訓練會話檢查點新增一個整型屬性。voidaddProperty(java.lang.String name, java.lang.String value)向此訓練會話檢查點新增一個字串屬性。voidclose()OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)使用提供的輸入執行單個評估步驟。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)使用提供的輸入執行單個評估步驟。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)使用提供的輸入執行單個評估步驟。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)使用提供的輸入執行單個評估步驟。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)使用提供的輸入執行單個評估步驟。voidexportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames)將評估模型匯出為適用於推理的模型,並將所需節點設定為輸出節點。java.util.Set<java.lang.String>getEvalInputNames()返回評估模型的輸入名稱的有序集合。java.util.Set<java.lang.String>getEvalOutputNames()返回評估模型的輸出名稱的有序集合。floatgetFloatProperty(java.lang.String name)從此訓練會話檢查點獲取一個浮點屬性。intgetIntProperty(java.lang.String name)從此訓練會話檢查點獲取一個整型屬性。floatgetLearningRate()獲取此訓練會話的當前學習率。java.lang.StringgetStringProperty(java.lang.String name)從此訓練會話檢查點獲取一個字串屬性。java.util.Set<java.lang.String>getTrainInputNames()返回訓練模型的輸入名稱的有序集合。java.util.Set<java.lang.String>getTrainOutputNames()返回訓練模型的輸出名稱的有序集合。voidlazyResetGrad()voidoptimizerStep()使用最佳化器模型將梯度更新應用於可訓練引數。voidoptimizerStep(OrtSession.RunOptions runOptions)使用最佳化器模型將梯度更新應用於可訓練引數。voidregisterLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate)註冊一個帶有線性預熱的線性學習率排程器。voidsaveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer)將訓練會話狀態儲存到提供的檢查點目錄中。voidschedulerStep()根據註冊的學習率排程器更新學習率。voidsetLearningRate(float learningRate)設定訓練會話的學習率。static voidsetSeed(long seed)設定 ONNX Runtime 使用的 RNG 種子。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)執行單個訓練步驟,累積梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)執行單個訓練步驟,累積梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)執行單個訓練步驟,累積梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)執行單個訓練步驟,累積梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)執行單個訓練步驟,累積梯度。
-
-
-
方法詳情
-
getTrainInputNames
public java.util.Set<java.lang.String> getTrainInputNames()
返回訓練模型的輸入名稱的有序集合。- 返回值
- 訓練輸入。
-
getTrainOutputNames
public java.util.Set<java.lang.String> getTrainOutputNames()
返回訓練模型的輸出名稱的有序集合。- 返回值
- 訓練輸出。
-
getEvalInputNames
public java.util.Set<java.lang.String> getEvalInputNames()
返回評估模型的輸入名稱的有序集合。- 返回值
- 評估輸入。
-
getEvalOutputNames
public java.util.Set<java.lang.String> getEvalOutputNames()
返回評估模型的輸出名稱的有序集合。- 返回值
- 評估輸出。
-
addProperty
public void addProperty(java.lang.String name, float value) throws OrtException向此訓練會話檢查點新增一個浮點屬性。- 引數
name- 屬性名稱。value- 屬性值。- 丟擲
OrtException- 如果呼叫失敗。
-
addProperty
public void addProperty(java.lang.String name, int value) throws OrtException向此訓練會話檢查點新增一個整型屬性。- 引數
name- 屬性名稱。value- 屬性值。- 丟擲
OrtException- 如果呼叫失敗。
-
addProperty
public void addProperty(java.lang.String name, java.lang.String value) throws OrtException向此訓練會話檢查點新增一個字串屬性。- 引數
name- 屬性名稱。value- 屬性值。- 丟擲
OrtException- 如果呼叫失敗。
-
getFloatProperty
public float getFloatProperty(java.lang.String name) throws OrtException從此訓練會話檢查點獲取一個浮點屬性。- 引數
name- 屬性名稱。- 返回值
- 屬性值。
- 丟擲
OrtException- 如果屬性不存在或型別錯誤。
-
getIntProperty
public int getIntProperty(java.lang.String name) throws OrtException從此訓練會話檢查點獲取一個整型屬性。- 引數
name- 屬性名稱。- 返回值
- 屬性值。
- 丟擲
OrtException- 如果屬性不存在或型別錯誤。
-
getStringProperty
public java.lang.String getStringProperty(java.lang.String name) throws OrtException從此訓練會話檢查點獲取一個字串屬性。- 引數
name- 屬性名稱。- 返回值
- 屬性值。
- 丟擲
OrtException- 如果屬性不存在或型別錯誤。
-
close
public void close()
- 指定者
close在介面java.lang.AutoCloseable中
-
saveCheckpoint
public void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) throws OrtException將訓練會話狀態儲存到提供的檢查點目錄中。- 引數
outputPath- 檢查點目錄的路徑。saveOptimizer- 是否應儲存最佳化器狀態。- 丟擲
OrtException- 如果原生呼叫失敗。
-
lazyResetGrad
public void lazyResetGrad() throws OrtException確保在下次呼叫trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)之前,梯度被重置為零。注意,這是一個延遲呼叫,梯度是在執行下一個
trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)時清除的,而不是在此之前。- 丟擲
OrtException- 如果原生呼叫失敗。
-
setSeed
public static void setSeed(long seed) throws OrtException設定 ONNX Runtime 使用的 RNG 種子。注意,此設定在所有 OrtTrainingSession 例項中是全域性的。
- 引數
seed- RNG 種子。- 丟擲
OrtException- 如果原生呼叫失敗。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
執行單個訓練步驟,累積梯度。- 引數
inputs- 輸入(必須同時包含特徵和目標)。- 返回值
- 訓練步驟產生的所有輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
執行單個訓練步驟,累積梯度。- 引數
inputs- 輸入(必須同時包含特徵和目標)。runOptions- 控制此特定呼叫的執行選項。- 返回值
- 訓練步驟產生的所有輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
執行單個訓練步驟,累積梯度。- 引數
inputs- 輸入(必須同時包含特徵和目標)。requestedOutputs- 請求的輸出。- 返回值
- 訓練步驟產生的請求輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
執行單個訓練步驟,累積梯度。輸出根據提供的對映遍歷順序排序。
注意:固定的輸出不屬於
OrtSession.Result物件,並且在結果物件關閉時不會關閉。- 引數
inputs- 輸入(必須同時包含特徵和目標)。pinnedOutputs- 使用者已分配的請求輸出。- 返回值
- 訓練步驟產生的請求輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
執行單個訓練步驟,累積梯度。輸出根據提供的集合遍歷順序排序,固定的輸出在前,然後是請求的輸出。如果請求的輸出和固定的輸出中出現相同的輸出名稱,則丟擲
IllegalArgumentException。注意:固定的輸出不屬於
OrtSession.Result物件,並且在結果物件關閉時不會關閉。- 引數
inputs- 輸入(必須同時包含特徵和目標)。requestedOutputs- ORT 將分配的請求輸出。pinnedOutputs- 使用者已分配的請求輸出。runOptions- 控制此特定呼叫的執行選項。- 返回值
- 訓練步驟產生的請求輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
使用提供的輸入執行單個評估步驟。- 引數
inputs- 模型輸入。- 返回值
- 所有模型輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
使用提供的輸入執行單個評估步驟。- 引數
inputs- 模型輸入。runOptions- 控制此特定呼叫的執行選項。- 返回值
- 所有模型輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
使用提供的輸入執行單個評估步驟。- 引數
inputs- 模型輸入。requestedOutputs- 請求的輸出名稱。- 返回值
- 請求的輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
使用提供的輸入執行單個評估步驟。輸出根據提供的對映遍歷順序排序。
注意:固定的輸出不屬於
OrtSession.Result物件,並且在結果物件關閉時不會關閉。- 引數
inputs- 用於評分的輸入。pinnedOutputs- 使用者已分配的請求輸出。- 返回值
- 請求的輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
使用提供的輸入執行單個評估步驟。輸出根據提供的集合遍歷順序排序,固定的輸出在前,然後是請求的輸出。如果請求的輸出和固定的輸出中出現相同的輸出名稱,則丟擲
IllegalArgumentException。注意:固定的輸出不屬於
OrtSession.Result物件,並且在結果物件關閉時不會關閉。- 引數
inputs- 用於評分的輸入。requestedOutputs- ORT 將分配的請求輸出。pinnedOutputs- 使用者已分配的請求輸出。runOptions- 控制此特定呼叫的執行選項。- 返回值
- 請求的輸出。
- 丟擲
OrtException- 如果原生呼叫失敗。
-
setLearningRate
public void setLearningRate(float learningRate) throws OrtException設定訓練會話的學習率。僅當會話中沒有學習率排程器時才應使用。不用於設定學習率排程器的初始學習率。
- 引數
learningRate- 學習率。- 丟擲
OrtException- 如果呼叫失敗。
-
getLearningRate
public float getLearningRate() throws OrtException獲取此訓練會話的當前學習率。- 返回值
- 當前學習率。
- 丟擲
OrtException- 如果呼叫失敗。
-
optimizerStep
public void optimizerStep() throws OrtException使用最佳化器模型將梯度更新應用於可訓練引數。- 丟擲
OrtException- 如果原生呼叫失敗。
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
使用最佳化器模型將梯度更新應用於可訓練引數。執行選項可用於控制日誌記錄和提前終止呼叫。
- 引數
runOptions- 控制模型執行的選項。- 丟擲
OrtException- 如果原生呼叫失敗。
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException註冊一個帶有線性預熱的線性學習率排程器。- 引數
warmupSteps- 將學習率從零增加到initialLearningRate所需的步數。totalSteps- 此排程器操作的總步數。initialLearningRate- 最大學習率。- 丟擲
OrtException- 如果原生呼叫失敗。
-
schedulerStep
public void schedulerStep() throws OrtException根據註冊的學習率排程器更新學習率。- 丟擲
OrtException- 如果原生呼叫失敗。
-
exportModelForInference
public void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) throws OrtException將評估模型匯出為適用於推理的模型,並將所需節點設定為輸出節點。注意,此方法從提供給訓練會話的路徑重新載入評估模型,並且此路徑必須仍然有效。
- 引數
outputPath- 寫入推理模型的路徑。outputNames- 輸出節點的名稱。- 丟擲
OrtException- 如果原生呼叫失敗。
-
-