將 ONNX Runtime generate() API 從 0.5.2 遷移到 0.6.0

瞭解如何將 ONNX Runtime generate() 版本從 0.5.2 遷移到 0.6.0。

版本 0.6.0 增加了對“聊天模式”的支援,也稱為延續連續解碼互動式解碼。隨著聊天模式的引入,API 發生了重大更改。

總而言之,新 API 向 Generator 添加了一個 AppendTokens 方法,允許進行多輪對話。以前,輸入是在建立 Generator 之前在 GeneratorParams 中設定的。

在對話迴圈之外呼叫 AppendTokens 可用於實現系統提示快取。

注意:聊天模式和系統提示快取僅支援批處理大小為 1。此外,它們目前在 CPU、使用 CUDA EP 的 NVIDIA GPU 以及所有使用 Web GPU 本機 EP 的 GPU 上受支援。它們不支援 NPU 或使用 DirecML EP 執行的 GPU。對於問答 (Q&A) 模式,下述遷移仍然是必需的。

Python

將 Python 問答(單輪)程式碼遷移到 0.6.0

  1. 在生成器物件建立後,將對 params.input_ids = input_tokens 的呼叫替換為 generator.append_tokens(input_tokens)
  2. 移除對 generator.compute_logits() 的呼叫
  3. 如果應用程式有 Q&A 迴圈,請在 append_token 呼叫之間刪除 generator 以重置模型狀態。

將系統提示快取新增到 Python 應用程式

  1. 建立並標記系統提示,然後呼叫 generator.append_tokens(system_tokens)。此呼叫可以在向用戶請求其提示之前完成。

    system_tokens = tokenizer.encode(system_prompt)
    generator.append_tokens(system_tokens)
    

將聊天模式新增到 Python 應用程式

  1. 在應用程式中建立一個迴圈,並在使用者提供新輸入時每次呼叫 generator.append_tokens(prompt)

    while True:
        user_input = input("Input: ")
        input_tokens = tokenizer.encode(user_input)
        generator.append_tokens(input_tokens)
    
        while not generator.is_done():
            generator.generate_next_token()
    
            new_token = generator.get_next_tokens()[0]
            print(tokenizer_stream.decode(new_token), end='', flush=True)
         except KeyboardInterrupt:
         print()
    

C++

將 C++ 問答(單輪)程式碼遷移到 0.6.0

  1. 將對 params->SetInputSequences(*sequences) 的呼叫替換為 generator->AppendTokenSequences(*sequences)
  2. 移除對 generator->ComputeLogits() 的呼叫

將系統提示快取新增到 C++ 應用程式

  1. 建立並標記系統提示,然後呼叫 generator->AppendTokenSequences(*sequences)。此呼叫可以在向用戶請求其提示之前完成。

    auto sequences = OgaSequences::Create();
    tokenizer->Encode(system_prompt.c_str(), *sequences);
    generator->AppendTokenSequences(*sequences);
    generator.append_tokens(system_tokens)
    

將聊天模式新增到您的 C++ 應用程式

  1. 將聊天迴圈新增到您的應用程式
    std::cout << "Generating response..." << std::endl;
    auto params = OgaGeneratorParams::Create(*model);
    params->SetSearchOption("max_length", 1024);
    
    auto generator = OgaGenerator::Create(*model, *params);
    
    while (true) {
      std::string text;
      std::cout << "Prompt: "  << std::endl;
      std::getline(std::cin, prompt);
    
      auto sequences = OgaSequences::Create();
      tokenizer->Encode(prompt.c_str(), *sequences);
    
      generator->AppendTokenSequences(*sequences);
    
      while (!generator->IsDone()) {
        generator->GenerateNextToken();
    
        const auto num_tokens = generator->GetSequenceCount(0);
        const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
       std::cout << tokenizer_stream->Decode(new_token) << std::flush;
       }
    }
    

C#

將 C# 問答(單輪)程式碼遷移到 0.6.0

  1. 將對 generatorParams.SetInputSequences(sequences) 的呼叫替換為 generator.AppendTokenSequences(sequences)`
  2. 移除對 generator.ComputeLogits() 的呼叫

將系統提示快取新增到您的 C# 應用程式

  1. 建立並標記系統提示,然後呼叫 generator->AppendTokenSequences()。此呼叫可以在向用戶請求其提示之前完成。

    var systemPrompt = "..."
    auto sequences = OgaSequences::Create();
    tokenizer->Encode(systemPrompt, *sequences);
    generator->AppendTokenSequences(*sequences);
    

將聊天模式新增到您的 C# 應用程式

  1. 將聊天迴圈新增到您的應用程式
    using var tokenizerStream = tokenizer.CreateStream();
    using var generator = new Generator(model, generatorParams);
    Console.WriteLine("Prompt:");
    prompt = Console.ReadLine();
    
    // Example Phi-3 template
    var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>");
    
    do
    {
       generator.AppendTokenSequences(sequences);
       var watch = System.Diagnostics.Stopwatch.StartNew();
       while (!generator.IsDone())
       {
          generator.GenerateNextToken();
          Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1]));
       }
       Console.WriteLine();
       watch.Stop();
       var runTimeInSeconds = watch.Elapsed.TotalSeconds;
       var outputSequence = generator.GetSequence(0);
       var totalTokens = outputSequence.Length;
       Console.WriteLine($"Streaming Tokens: {totalTokens} Time: {runTimeInSeconds:0.00} Tokens per second: {totalTokens / runTimeInSeconds:0.00}");
       Console.WriteLine("Next prompt:");
       var nextPrompt = Console.ReadLine();
       sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>");
    } while (prompt != null);
    
    

Java

即將推出