ONNX Runtime generate() Java API

注意:此 API 處於預覽階段,可能會有更改。

安裝和匯入

Java API 由 ai.onnxruntime.genai Java 包提供。包的釋出正在進行中。要從原始碼構建包,請參閱從原始碼構建指南

import ai.onnxruntime.genai.*;

SimpleGenAI 類

SimpleGenAI 類提供了 GenAI API 的一個簡單使用示例。它使用一個基於提示生成文字的模型,一次處理一個提示。用法

使用模型路徑建立該類的例項。該路徑也應包含 GenAI 配置檔案。

SimpleGenAI genAI = new SimpleGenAI(folderPath);

使用提示文字呼叫 createGeneratorParams。根據需要使用 setSearchOption 透過 GeneratorParams 物件設定任何其他搜尋選項。

GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()

使用 GeneratorParams 物件和可選的監聽器呼叫 generate。

String fullResponse = genAI.generate(generatorParams, listener);

監聽器用作回撥機制,以便可以在生成令牌時使用它們。建立一個實現 Consumer<String> 介面的類,並提供該類的一個例項作為 listener 引數。

建構函式

public SimpleGenAI(String modelPath) throws GenAIException

丟擲

GenAIException - 失敗時。

生成方法

根據 GeneratorParams 中的提示和設定生成文字。

注意:這隻處理單個輸入序列(即單個提示,相當於批處理大小為 1)。

public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException

引數

  • generatorParams:用於執行模型的提示和設定。
  • listener:可選的回撥,用於在令牌生成時提供它們。

注意:令牌生成將被阻塞,直到監聽器的 accept 方法返回。

丟擲

GenAIException - 失敗時。

返回

生成的文字。

示例

SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);

logger.info("Result: " + result);

createGenerateParams 方法

建立生成器引數並新增提示文字。使用者可以在執行 generate 之前透過 GeneratorParams 物件設定其他搜尋選項。

public GeneratorParams createGeneratorParams(String prompt) throws GenAIException

引數

  • prompt:要編碼的提示文字。

丟擲

GenAIException - 失敗時。

返回

生成器引數。

異常類

一個異常,包含由原生層生成的錯誤訊息和程式碼。

建構函式

public GenAIException(String message)

示例

catch (GenAIException e) {
  throw new GenAIException("Token generation loop failed.", e);
}

模型類

建構函式

Model(String modelPath)

建立分詞器方法

為此模型建立 Tokenizer 例項。模型包含確定要使用的分詞器的配置資訊。

public Tokenizer createTokenizer() throws GenAIException

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗

返回

分詞器例項。

生成方法

public Sequences generate(GeneratorParams generatorParams) throws GenAIException

引數

  • generatorParams:生成器引數。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

生成的序列。

示例

Sequences output = model.generate(generatorParams);

createGeneratorParams 方法

建立用於執行模型的 GeneratorParams 例項。

注意:GeneratorParams 內部使用 Model,因此 Model 例項必須保持有效。

public GeneratorParams createGeneratorParams() throws GenAIException

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

GeneratorParams 例項。

示例

GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");

分詞器類

編碼方法

將字串編碼為令牌 ID 序列。

public Sequences encode(String string) throws GenAIException

引數

  • string:要編碼為令牌 ID 的文字。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

一個包含單個序列的 Sequences 物件。

示例

Sequences encodedPrompt = tokenizer.encode(prompt);

解碼方法

將令牌 ID 序列解碼為文字。

public String decode(int[] sequence) throws GenAIException

引數

  • sequence:要解碼為文字的令牌 ID 集合

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

序列的文字表示。

示例

String result = tokenizer.decode(output_ids);

encodeBatch 方法

將字串陣列編碼為每個輸入的令牌 ID 序列。

public Sequences encodeBatch(String[] strings) throws GenAIException

引數

  • strings:要編碼為令牌 ID 的字串集合。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

一個 Sequences 物件,每個輸入字串包含一個序列。

示例

Sequences encoded = tokenizer.encodeBatch(inputs);

decodeBatch 方法

將一批令牌 ID 序列解碼為文字。

public String[] decodeBatch(Sequences sequences) throws GenAIException

引數

  • sequences:一個 Sequences 物件,包含一個或多個令牌 ID 序列。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

一個字串陣列,其中包含每個序列的文字表示。

示例

String[] decoded = tokenizer.decodeBatch(encoded);

createStream 方法

建立 TokenizerStream 物件用於流式分詞。這與 Generator 類一起使用,以便在生成每個令牌時提供它們。

public TokenizerStream createStream() throws GenAIException

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

新的 TokenizerStream 例項。

TokenizerStream 類

此類用於在使用 Generator.generateNextToken 時轉換單個令牌。

解碼方法

public String decode(int token) throws GenAIException

引數

  • token:令牌的 int 值

丟擲

GenAIException

張量類

使用給定資料、形狀和元素型別構造一個張量。

public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException

引數

  • data:張量的資料。必須是直接 ByteBuffer。
  • shape:張量的形狀。
  • elementType:張量中元素的型別。

丟擲

GenAIException

示例

建立一個具有 32 位浮點資料的 2x2 張量。

long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});

Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);

GeneratorParams 類

GeneratorParams 類表示用於使用模型生成序列的引數。使用 setInput 設定提示,並使用 setSearchOption 設定任何其他搜尋選項。

建立一個 Generator Params 物件

GeneratorParams params = new GeneratorParams(model);

setSearchOption 方法

public void setSearchOption(String optionName, double value) throws GenAIException

丟擲

GenAIException

示例

設定搜尋選項以限制模型生成長度。

generatorParams.setSearchOption("max_length", 10);

setSearchOption 方法

public void setSearchOption(String optionName, boolean value) throws GenAIException

丟擲

GenAIException

示例

generatorParams.setSearchOption("early_stopping", true);

setInput 方法

設定模型執行的提示。 sequences 是透過使用 Tokenizer.Encode 或 EncodeBatch 建立的。

public void setInput(Sequences sequences) throws GenAIException

引數

  • sequences:包含編碼提示的序列。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

示例

generatorParams.setInput(encodedPrompt);

setInput 方法

設定模型執行的提示/令牌 ID。 tokenIds 是編碼後的引數。

public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
 throws GenAIException

引數

  • tokenIds:編碼提示的令牌 ID
  • sequenceLength:每個序列的長度。
  • batchSize:批次大小。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

注意:批次中的所有序列必須具有相同的長度。

示例

generatorParams.setInput(tokenIds, sequenceLength, batchSize);

Generator 類

Generator 類使用模型和生成器引數生成輸出。預期的用法是迴圈直到 isDone 返回 false。在迴圈中,先呼叫 computeLogits,然後呼叫 generateNextToken。

新生成的令牌可以使用 getLastTokenInSequence 獲取,並使用 TokenizerStream.Decode 解碼。

生成過程完成後,如果需要,可以使用 GetSequence 檢索完整的生成序列。

建立一個 Generator

使用給定模型和生成器引數構造一個 Generator 物件。

Generator(Model model, GeneratorParams generatorParams)

引數

  • model:模型。
  • params:生成器引數。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

isDone 方法

檢查生成過程是否完成。

public boolean isDone()

返回

如果生成過程完成,則返回 true,否則返回 false。

computeLogits 方法

計算序列中下一個令牌的 logits。

public void computeLogits() throws GenAIException

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

getSequence 方法

檢索指定序列索引的令牌 ID 序列。

public int[] getSequence(long sequenceIndex) throws GenAIException

引數

  • sequenceIndex:序列的索引。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

一個包含令牌 ID 序列的整數陣列。

示例

int[] outputIds = output.getSequence(i);

generateNextToken 方法

生成序列中的下一個令牌。

public void generateNextToken() throws GenAIException

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

getLastTokenInSequence 方法

檢索指定序列索引的序列中的最後一個令牌。

public int getLastTokenInSequence(long sequenceIndex) throws GenAIException

引數

  • sequenceIndex:序列的索引。

丟擲

GenAIException - 如果呼叫 GenAI 原生 API 失敗。

返回

序列中的最後一個令牌。

Sequences 類

表示編碼提示/響應的集合。

numSequences 方法

獲取集合中的序列數量。這等同於批次大小。

public long numSequences()

返回

序列的數量。

示例

int numSequences = (int) sequences.numSequences();

getSequence 方法

獲取指定索引處的序列。

public int[] getSequence(long sequenceIndex)

引數

  • sequenceIndex:序列的索引。

返回

序列(整數陣列形式)。

介面卡類

即將推出!