ONNX Runtime generate() Java API
注意:此 API 處於預覽階段,可能會有更改。
- 安裝和匯入
- SimpleGenAI 類
- 異常類
- 模型類
- 分詞器類
- TokenizerStream 類
- 張量類
- GeneratorParams 類
- Generator 類
- Sequences 類
- 介面卡類
安裝和匯入
Java API 由 ai.onnxruntime.genai Java 包提供。包的釋出正在進行中。要從原始碼構建包,請參閱從原始碼構建指南。
import ai.onnxruntime.genai.*;
SimpleGenAI 類
SimpleGenAI 類提供了 GenAI API 的一個簡單使用示例。它使用一個基於提示生成文字的模型,一次處理一個提示。用法
使用模型路徑建立該類的例項。該路徑也應包含 GenAI 配置檔案。
SimpleGenAI genAI = new SimpleGenAI(folderPath);
使用提示文字呼叫 createGeneratorParams。根據需要使用 setSearchOption 透過 GeneratorParams 物件設定任何其他搜尋選項。
GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()
使用 GeneratorParams 物件和可選的監聽器呼叫 generate。
String fullResponse = genAI.generate(generatorParams, listener);
監聽器用作回撥機制,以便可以在生成令牌時使用它們。建立一個實現 Consumer<String> 介面的類,並提供該類的一個例項作為 listener 引數。
建構函式
public SimpleGenAI(String modelPath) throws GenAIException
丟擲
GenAIException - 失敗時。
生成方法
根據 GeneratorParams 中的提示和設定生成文字。
注意:這隻處理單個輸入序列(即單個提示,相當於批處理大小為 1)。
public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException
引數
generatorParams:用於執行模型的提示和設定。listener:可選的回撥,用於在令牌生成時提供它們。
注意:令牌生成將被阻塞,直到監聽器的 accept 方法返回。
丟擲
GenAIException - 失敗時。
返回
生成的文字。
示例
SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);
logger.info("Result: " + result);
createGenerateParams 方法
建立生成器引數並新增提示文字。使用者可以在執行 generate 之前透過 GeneratorParams 物件設定其他搜尋選項。
public GeneratorParams createGeneratorParams(String prompt) throws GenAIException
引數
prompt:要編碼的提示文字。
丟擲
GenAIException - 失敗時。
返回
生成器引數。
異常類
一個異常,包含由原生層生成的錯誤訊息和程式碼。
建構函式
public GenAIException(String message)
示例
catch (GenAIException e) {
throw new GenAIException("Token generation loop failed.", e);
}
模型類
建構函式
Model(String modelPath)
建立分詞器方法
為此模型建立 Tokenizer 例項。模型包含確定要使用的分詞器的配置資訊。
public Tokenizer createTokenizer() throws GenAIException
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗
返回
分詞器例項。
生成方法
public Sequences generate(GeneratorParams generatorParams) throws GenAIException
引數
generatorParams:生成器引數。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
生成的序列。
示例
Sequences output = model.generate(generatorParams);
createGeneratorParams 方法
建立用於執行模型的 GeneratorParams 例項。
注意:GeneratorParams 內部使用 Model,因此 Model 例項必須保持有效。
public GeneratorParams createGeneratorParams() throws GenAIException
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
GeneratorParams 例項。
示例
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
分詞器類
編碼方法
將字串編碼為令牌 ID 序列。
public Sequences encode(String string) throws GenAIException
引數
string:要編碼為令牌 ID 的文字。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
一個包含單個序列的 Sequences 物件。
示例
Sequences encodedPrompt = tokenizer.encode(prompt);
解碼方法
將令牌 ID 序列解碼為文字。
public String decode(int[] sequence) throws GenAIException
引數
sequence:要解碼為文字的令牌 ID 集合
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
序列的文字表示。
示例
String result = tokenizer.decode(output_ids);
encodeBatch 方法
將字串陣列編碼為每個輸入的令牌 ID 序列。
public Sequences encodeBatch(String[] strings) throws GenAIException
引數
strings:要編碼為令牌 ID 的字串集合。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
一個 Sequences 物件,每個輸入字串包含一個序列。
示例
Sequences encoded = tokenizer.encodeBatch(inputs);
decodeBatch 方法
將一批令牌 ID 序列解碼為文字。
public String[] decodeBatch(Sequences sequences) throws GenAIException
引數
sequences:一個 Sequences 物件,包含一個或多個令牌 ID 序列。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
一個字串陣列,其中包含每個序列的文字表示。
示例
String[] decoded = tokenizer.decodeBatch(encoded);
createStream 方法
建立 TokenizerStream 物件用於流式分詞。這與 Generator 類一起使用,以便在生成每個令牌時提供它們。
public TokenizerStream createStream() throws GenAIException
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
新的 TokenizerStream 例項。
TokenizerStream 類
此類用於在使用 Generator.generateNextToken 時轉換單個令牌。
解碼方法
public String decode(int token) throws GenAIException
引數
token:令牌的 int 值
丟擲
GenAIException
張量類
使用給定資料、形狀和元素型別構造一個張量。
public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException
引數
data:張量的資料。必須是直接 ByteBuffer。shape:張量的形狀。elementType:張量中元素的型別。
丟擲
GenAIException
示例
建立一個具有 32 位浮點資料的 2x2 張量。
long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);
GeneratorParams 類
GeneratorParams 類表示用於使用模型生成序列的引數。使用 setInput 設定提示,並使用 setSearchOption 設定任何其他搜尋選項。
建立一個 Generator Params 物件
GeneratorParams params = new GeneratorParams(model);
setSearchOption 方法
public void setSearchOption(String optionName, double value) throws GenAIException
丟擲
GenAIException
示例
設定搜尋選項以限制模型生成長度。
generatorParams.setSearchOption("max_length", 10);
setSearchOption 方法
public void setSearchOption(String optionName, boolean value) throws GenAIException
丟擲
GenAIException
示例
generatorParams.setSearchOption("early_stopping", true);
setInput 方法
設定模型執行的提示。 sequences 是透過使用 Tokenizer.Encode 或 EncodeBatch 建立的。
public void setInput(Sequences sequences) throws GenAIException
引數
sequences:包含編碼提示的序列。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
示例
generatorParams.setInput(encodedPrompt);
setInput 方法
設定模型執行的提示/令牌 ID。 tokenIds 是編碼後的引數。
public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
throws GenAIException
引數
tokenIds:編碼提示的令牌 IDsequenceLength:每個序列的長度。batchSize:批次大小。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
注意:批次中的所有序列必須具有相同的長度。
示例
generatorParams.setInput(tokenIds, sequenceLength, batchSize);
Generator 類
Generator 類使用模型和生成器引數生成輸出。預期的用法是迴圈直到 isDone 返回 false。在迴圈中,先呼叫 computeLogits,然後呼叫 generateNextToken。
新生成的令牌可以使用 getLastTokenInSequence 獲取,並使用 TokenizerStream.Decode 解碼。
生成過程完成後,如果需要,可以使用 GetSequence 檢索完整的生成序列。
建立一個 Generator
使用給定模型和生成器引數構造一個 Generator 物件。
Generator(Model model, GeneratorParams generatorParams)
引數
model:模型。params:生成器引數。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
isDone 方法
檢查生成過程是否完成。
public boolean isDone()
返回
如果生成過程完成,則返回 true,否則返回 false。
computeLogits 方法
計算序列中下一個令牌的 logits。
public void computeLogits() throws GenAIException
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
getSequence 方法
檢索指定序列索引的令牌 ID 序列。
public int[] getSequence(long sequenceIndex) throws GenAIException
引數
sequenceIndex:序列的索引。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
一個包含令牌 ID 序列的整數陣列。
示例
int[] outputIds = output.getSequence(i);
generateNextToken 方法
生成序列中的下一個令牌。
public void generateNextToken() throws GenAIException
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
getLastTokenInSequence 方法
檢索指定序列索引的序列中的最後一個令牌。
public int getLastTokenInSequence(long sequenceIndex) throws GenAIException
引數
sequenceIndex:序列的索引。
丟擲
GenAIException - 如果呼叫 GenAI 原生 API 失敗。
返回
序列中的最後一個令牌。
Sequences 類
表示編碼提示/響應的集合。
numSequences 方法
獲取集合中的序列數量。這等同於批次大小。
public long numSequences()
返回
序列的數量。
示例
int numSequences = (int) sequences.numSequences();
getSequence 方法
獲取指定索引處的序列。
public int[] getSequence(long sequenceIndex)
引數
sequenceIndex:序列的索引。
返回
序列(整數陣列形式)。
介面卡類
即將推出!