將 ONNX Runtime generate() API 從 0.5.2 遷移到 0.6.0
瞭解如何將 ONNX Runtime generate() 版本從 0.5.2 遷移到 0.6.0。
版本 0.6.0 增加了對“聊天模式”的支援,也稱為延續、連續解碼和互動式解碼。隨著聊天模式的引入,API 發生了重大更改。
總而言之,新 API 向 Generator 添加了一個 AppendTokens 方法,允許進行多輪對話。以前,輸入是在建立 Generator 之前在 GeneratorParams 中設定的。
在對話迴圈之外呼叫 AppendTokens 可用於實現系統提示快取。
注意:聊天模式和系統提示快取僅支援批處理大小為 1。此外,它們目前在 CPU、使用 CUDA EP 的 NVIDIA GPU 以及所有使用 Web GPU 本機 EP 的 GPU 上受支援。它們不支援 NPU 或使用 DirecML EP 執行的 GPU。對於問答 (Q&A) 模式,下述遷移仍然是必需的。
Python
將 Python 問答(單輪)程式碼遷移到 0.6.0
- 在生成器物件建立後,將對
params.input_ids = input_tokens的呼叫替換為generator.append_tokens(input_tokens)。 - 移除對
generator.compute_logits()的呼叫 - 如果應用程式有 Q&A 迴圈,請在
append_token呼叫之間刪除generator以重置模型狀態。
將系統提示快取新增到 Python 應用程式
-
建立並標記系統提示,然後呼叫
generator.append_tokens(system_tokens)。此呼叫可以在向用戶請求其提示之前完成。system_tokens = tokenizer.encode(system_prompt) generator.append_tokens(system_tokens)
將聊天模式新增到 Python 應用程式
-
在應用程式中建立一個迴圈,並在使用者提供新輸入時每次呼叫
generator.append_tokens(prompt)while True: user_input = input("Input: ") input_tokens = tokenizer.encode(user_input) generator.append_tokens(input_tokens) while not generator.is_done(): generator.generate_next_token() new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end='', flush=True) except KeyboardInterrupt: print()
C++
將 C++ 問答(單輪)程式碼遷移到 0.6.0
- 將對
params->SetInputSequences(*sequences)的呼叫替換為generator->AppendTokenSequences(*sequences) - 移除對
generator->ComputeLogits()的呼叫
將系統提示快取新增到 C++ 應用程式
-
建立並標記系統提示,然後呼叫
generator->AppendTokenSequences(*sequences)。此呼叫可以在向用戶請求其提示之前完成。auto sequences = OgaSequences::Create(); tokenizer->Encode(system_prompt.c_str(), *sequences); generator->AppendTokenSequences(*sequences); generator.append_tokens(system_tokens)
將聊天模式新增到您的 C++ 應用程式
- 將聊天迴圈新增到您的應用程式
std::cout << "Generating response..." << std::endl; auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", 1024); auto generator = OgaGenerator::Create(*model, *params); while (true) { std::string text; std::cout << "Prompt: " << std::endl; std::getline(std::cin, prompt); auto sequences = OgaSequences::Create(); tokenizer->Encode(prompt.c_str(), *sequences); generator->AppendTokenSequences(*sequences); while (!generator->IsDone()) { generator->GenerateNextToken(); const auto num_tokens = generator->GetSequenceCount(0); const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; } }
C#
將 C# 問答(單輪)程式碼遷移到 0.6.0
- 將對
generatorParams.SetInputSequences(sequences)的呼叫替換為generator.AppendTokenSequences(sequences)` - 移除對
generator.ComputeLogits()的呼叫
將系統提示快取新增到您的 C# 應用程式
-
建立並標記系統提示,然後呼叫
generator->AppendTokenSequences()。此呼叫可以在向用戶請求其提示之前完成。var systemPrompt = "..." auto sequences = OgaSequences::Create(); tokenizer->Encode(systemPrompt, *sequences); generator->AppendTokenSequences(*sequences);
將聊天模式新增到您的 C# 應用程式
- 將聊天迴圈新增到您的應用程式
using var tokenizerStream = tokenizer.CreateStream(); using var generator = new Generator(model, generatorParams); Console.WriteLine("Prompt:"); prompt = Console.ReadLine(); // Example Phi-3 template var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>"); do { generator.AppendTokenSequences(sequences); var watch = System.Diagnostics.Stopwatch.StartNew(); while (!generator.IsDone()) { generator.GenerateNextToken(); Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1])); } Console.WriteLine(); watch.Stop(); var runTimeInSeconds = watch.Elapsed.TotalSeconds; var outputSequence = generator.GetSequence(0); var totalTokens = outputSequence.Length; Console.WriteLine($"Streaming Tokens: {totalTokens} Time: {runTimeInSeconds:0.00} Tokens per second: {totalTokens / runTimeInSeconds:0.00}"); Console.WriteLine("Next prompt:"); var nextPrompt = Console.ReadLine(); sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>"); } while (prompt != null);
Java
即將推出