ONNX Runtime generate() C API

注意:此 API 處於預覽階段,可能會有所更改。

概述

模型 API

建立模型

從給定目錄建立模型。目錄應包含名為 genai_config.json 的檔案,該檔案對應於配置規範

引數

  • 輸入:config_path 模型配置目錄的路徑。路徑應採用 UTF-8 編碼。
  • 輸出:out 建立的模型。

返回值

OgaResult,如果模型建立失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out);

銷燬模型

銷燬給定模型。

引數

  • 輸入:model 要銷燬的模型。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model);

生成

根據給定的生成器引數,從模型執行生成令牌陣列的陣列。

引數

  • 輸入:model 用於生成的模型。
  • 輸入:generator_params 用於生成的引數。
  • 輸出:out 生成的令牌序列。呼叫者在使用完序列後,有責任使用 OgaDestroySequences 釋放序列。

返回值

OgaResult,如果生成失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out);

分詞器 API

建立分詞器

引數

  • 輸入:model. 應為其建立分詞器的模型

返回值

OgaResult,如果分詞器建立失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizer(const OgaModel* model, OgaTokenizer** out);

銷燬分詞器

OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizer(OgaTokenizer*);

編碼

編碼單個字串並將編碼後的令牌序列新增到 OgaSequences。OgaSequences 不再需要時必須使用 OgaDestroySequences 釋放。

引數

返回值

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences);

解碼

解碼單個令牌序列並返回一個以 null 結尾的 utf8 字串。out_string 必須使用 OgaDestroyString 釋放。

引數

返回值

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer*, const int32_t* tokens, size_t token_count, const char** out_string);

批次編碼

引數

  • OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeBatch(const OgaTokenizer*, const char** strings, size_t count, TokenSequences** out);
    

批次解碼

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecodeBatch(const OgaTokenizer*, const OgaSequences* tokens, const char*** out_strings);

銷燬分詞器字串

OGA_EXPORT void OGA_API_CALL OgaTokenizerDestroyStrings(const char** strings, size_t count);

建立分詞器流

OgaTokenizerStream 用於增量解碼令牌字串,一次一個令牌。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizerStream(const OgaTokenizer*, OgaTokenizerStream** out);

銷燬分詞器流

引數

OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizerStream(OgaTokenizerStream*);

解碼流

解碼流中的單個令牌。如果這導致生成一個單詞,它將以“out”形式返回。呼叫者有責任將每個塊連線起來以生成完整的結果。“out”在下次呼叫 OgaTokenizerStreamDecode 或 OgaTokenizerStream 被銷燬之前有效。

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerStreamDecode(OgaTokenizerStream*, int32_t token, const char** out);

生成器引數 API

建立生成器引數

從給定模型建立 OgaGeneratorParams。

引數

  • 輸入:model 用於生成的模型。
  • 輸出:out 建立的生成器引數。

返回值

OgaResult,如果生成器引數建立失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out);

銷燬生成器引數

銷燬給定的生成器引數。

引數

  • 輸入:generator_params 要銷燬的生成器引數。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* generator_params);

設定搜尋選項(數字)

設定一個數值型別的搜尋選項

引數

  • generator_params: 要設定引數的生成器引數物件
  • name: 引數名稱
  • value: 要設定的值

返回值

OgaResult,如果生成器引數建立失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value);

設定搜尋選項(布林值)

設定一個布林值型別的搜尋選項。

引數

  • generator_params: 要設定引數的生成器引數物件
  • name: 引數名稱
  • value: 要設定的值

返回值

OgaResult,如果生成器引數建立失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* generator_params, const char* name, bool value);

嘗試使用最大批處理大小進行圖捕獲

圖捕獲將計算圖的動態元素固定為常量值。它可以在某些環境中提供更高效的執行。要在圖捕獲模式下執行,需要提前知道最大批處理大小。如果記憶體不足以分配指定的最大批處理大小,此函式可能會失敗。

引數

  • generator_params: 要設定引數的生成器引數物件
  • max_batch_size: 要分配的最大批處理大小

返回值

OgaResult,如果無法使用指定的批處理大小配置圖捕獲模式,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size);

設定輸入

為生成器引數設定輸入 ID。輸入 ID 用於啟動生成。

引數

  • 輸入:generator_params 要設定輸入 ID 的生成器引數。
  • 輸入:input_ids 大小為 input_ids_count = batch_size * sequence_length 的輸入 ID 陣列。
  • 輸入:input_ids_count 輸入 ID 的總數。
  • 輸入:sequence_length 輸入 ID 的序列長度。
  • 輸入:batch_size 輸入 ID 的批處理大小。

返回值

OgaResult,如果設定輸入 ID 失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size);

設定輸入序列

為生成器引數設定輸入 ID 序列。輸入 ID 序列用於啟動生成。

引數

  • 輸入:generator_params 要設定輸入 ID 的生成器引數。
  • 輸入:sequences 輸入 ID 序列。

返回值

OgaResult,如果設定輸入 ID 序列失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences);

設定模型輸入

除了 input_ids 之外,設定一個額外的模型輸入。

引數

  • generator_params: 要設定輸入的生成器引數
  • name: 要設定的引數名稱
  • tensor: 引數的值

返回值

OgaResult,如果設定輸入失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, OgaTensor* tensor);

生成器 API

建立生成器

從給定模型和生成器引數建立生成器。

引數

  • 輸入:model 用於生成的模型。
  • 輸入:params 用於生成的引數。
  • 輸出:out 建立的生成器。

返回值

OgaResult,如果生成器建立失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);

銷燬生成器

銷燬給定的生成器。

引數

  • 輸入:generator 要銷燬的生成器。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator);

檢查生成是否完成

如果生成器已完成所有序列的生成,則返回 true。

引數

  • 輸入:generator 檢查生成器是否已完成所有序列生成的生成器。

返回值

如果生成器已完成所有序列的生成,則為 True,否則為 false。

OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);

執行模型的一次迭代

根據輸入 ID 和過去狀態計算模型的 logits。計算出的 logits 儲存在生成器中。

引數

  • 輸入:generator 計算 logits 的生成器。

返回值

OgaResult,如果 logits 計算失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);

生成下一個令牌

使用配置的生成引數,根據計算出的 logits 生成下一個令牌。

引數

  • 輸入:generator 生成下一個令牌的生成器。

返回值

OgaResult,如果下一個令牌生成失敗,則包含錯誤訊息。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);

獲取令牌數量

返回給定索引處序列中的令牌數量。

引數

  • 輸入:generator 獲取給定索引處序列令牌計數的生成器。
  • 輸入:index. 要返回令牌的索引

返回值

給定索引處序列中的令牌數量。

OGA_EXPORT size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* generator, size_t index);

獲取序列

返回指向給定索引處序列資料的指標。序列中的令牌數量由 OgaGenerator_GetSequenceCount 提供。

引數

  • 輸入:generator 獲取給定索引處序列資料的生成器。指向給定索引處序列資料的指標。序列資料由 OgaGenerator 擁有,並將在 OgaGenerator 銷燬時釋放。如果需要在 OgaGenerator 銷燬後使用資料,呼叫者必須複製資料。
  • 輸入:index. 獲取序列的索引。

返回值

指向令牌序列的指標

OGA_EXPORT const int32_t* OGA_API_CALL OgaGenerator_GetSequenceData(const OgaGenerator* generator, size_t index);

設定執行時選項

一個用於設定執行時選項的 API,更多引數將新增到此通用 API 以支援執行時選項。使用此 API 終止當前會話的示例是呼叫 SetRuntimeOption,其中鍵為“terminate_session”,值為“1”:OgaGenerator_SetRuntimeOption(generator, “terminate_session”, “1”)

有關當前執行時選項的更多詳細資訊,請參見此處

引數

  • 輸入:generator 需要設定執行時選項的生成器。
  • 輸入:key 設定執行時選項的鍵。
  • 輸入:value 為提供的鍵設定的值。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value);

介面卡 API

此 API 用於載入和切換微調適配器,例如 LoRA 介面卡。

建立介面卡

建立管理介面卡的物件。此物件用於載入所有模型介面卡。它負責載入介面卡的引用計數。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateAdapters(const OgaModel* model, OgaAdapters** out);

引數

  • model: 之前建立的 OgaModel

結果

  • out: 對建立的 OgaAdapters 列表的引用

載入介面卡

從給定的介面卡檔案路徑和介面卡名稱載入模型介面卡。

OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAdapter(OgaAdapters* adapters, const char* adapter_file_path, const char* adapter_name);

引數

  • adapters: 要載入介面卡的 OgaAdapters 物件。
  • adapter_file_path: 要載入的介面卡檔案路徑。
  • adapter_name: 介面卡的唯一識別符號,用於介面卡查詢

返回值

OgaResult,如果介面卡載入失敗,則包含錯誤訊息。

解除安裝介面卡

從先前載入的介面卡集中解除安裝具有給定識別符號的介面卡。如果未找到介面卡,或者無法解除安裝(在使用中),則返回錯誤。

OGA_EXPORT OgaResult* OGA_API_CALL OgaUnloadAdapter(OgaAdapters* adapters, const char* adapter_name);

引數

  • adapters: 要解除安裝介面卡的 OgaAdapters 物件。
  • adapter_name: 要解除安裝的介面卡名稱。

返回值

OgaResult,如果介面卡解除安裝失敗,則包含錯誤訊息。這可能發生在呼叫方法時,介面卡未載入或已被仍在使用的 OgaGenerator 標記為活動狀態。

設定活動介面卡

將具有給定介面卡名稱的介面卡設定為給定 OgaGenerator 物件的活動介面卡。

OGA_EXPORT OgaResult* OGA_API_CALL OgaSetActiveAdapter(OgaGenerator* generator, OgaAdapters* adapters, const char* adapter_name);

引數

  • generator: 要設定活動介面卡的 OgaGenerator 物件。
  • adapters: 管理模型介面卡的 OgaAdapters 物件。
  • adapter_name: 要設定為活動的介面卡名稱。

返回值

OgaResult,如果介面卡無法設定為活動狀態,則包含錯誤訊息。這可能發生在呼叫方法時,介面卡未事先載入。

列舉和結構體

typedef enum OgaDataType {
  OgaDataType_int32,
  OgaDataType_float32,
  OgaDataType_string,  // UTF8 string
} OgaDataType;
typedef struct OgaResult OgaResult;
typedef struct OgaGeneratorParams OgaGeneratorParams;
typedef struct OgaGenerator OgaGenerator;
typedef struct OgaModel OgaModel;
typedef struct OgaBuffer OgaBuffer;

實用函式

設定 GPU 裝置 ID

OGA_EXPORT OgaResult* OGA_API_CALL OgaSetCurrentGpuDeviceId(int device_id);

獲取 GPU 裝置 ID

OGA_EXPORT OgaResult* OGA_API_CALL OgaGetCurrentGpuDeviceId(int* device_id);

獲取錯誤訊息

引數

  • 輸入:result 包含錯誤訊息的 OgaResult。

返回值

OgaResult 中包含的錯誤訊息。const char* 由 OgaResult 擁有,並將在 OgaResult 銷燬時釋放。

OGA_EXPORT const char* OGA_API_CALL OgaResultGetError(OgaResult* result);

銷燬結果

引數

  • 輸入:result 要銷燬的 OgaResult。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyResult(OgaResult*);

銷燬字串

引數

  • 輸入:要銷燬的字串

返回值

OGA_EXPORT void OGA_API_CALL OgaDestroyString(const char*);

銷燬緩衝區

引數

  • 輸入:要銷燬的緩衝區

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyBuffer(OgaBuffer*);

獲取緩衝區型別

引數

  • 輸入:緩衝區

返回值

緩衝區的型別

OGA_EXPORT OgaDataType OGA_API_CALL OgaBufferGetType(const OgaBuffer*);

獲取緩衝區的維度數量

引數

  • 輸入:緩衝區

返回值

緩衝區中的維度數量

OGA_EXPORT size_t OGA_API_CALL OgaBufferGetDimCount(const OgaBuffer*);

獲取緩衝區維度

獲取緩衝區的維度

引數

  • 輸入:緩衝區
  • 輸出:維度陣列

返回值

OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaBufferGetDims(const OgaBuffer*, size_t* dims, size_t dim_count);

獲取緩衝區資料

從緩衝區獲取資料

引數

返回值

void

OGA_EXPORT const void* OGA_API_CALL OgaBufferGetData(const OgaBuffer*);

建立序列

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out);

銷燬序列

引數

  • 輸入:sequences 要銷燬的 OgaSequences。

返回值

void

返回值

OGA_EXPORT void OGA_API_CALL OgaDestroySequences(OgaSequences* sequences);

獲取序列數量

返回 OgaSequences 中的序列數量

引數

  • 輸入:sequences

返回值

OgaSequences 中的序列數量

OGA_EXPORT size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* sequences);

獲取序列中的令牌數量

返回給定索引處序列中的令牌數量

引數

  • 輸入:sequences

返回值

給定索引處序列中的令牌數量

OGA_EXPORT size_t OGA_API_CALL OgaSequencesGetSequenceCount(const OgaSequences* sequences, size_t sequence_index);

獲取序列資料

返回指向給定索引處序列資料的指標。序列中的令牌數量由 OgaSequencesGetSequenceCount 提供。

引數

  • 輸入:sequences

返回值

指向給定索引處序列資料的指標。該指標在 OgaSequences 銷燬之前有效。

OGA_EXPORT const int32_t* OGA_API_CALL OgaSequencesGetSequenceData(const OgaSequences* sequences, size_t sequence_index);