ONNX Runtime 中的圖最佳化

ONNX Runtime 提供了各種圖最佳化以提高效能。圖最佳化本質上是圖級別的轉換,範圍從小的圖簡化和節點消除到更復雜的節點融合和佈局最佳化。

圖最佳化根據其複雜性和功能分為幾個類別(或級別)。它們可以線上離線執行。線上模式下,最佳化在執行推理之前完成;而在離線模式下,執行時將最佳化後的圖儲存到磁碟。ONNX Runtime 提供了 Python、C#、C++ 和 C API,以啟用不同的最佳化級別並選擇離線或線上模式。

下面我們提供有關最佳化級別、線上/離線模式以及控制它們的各種 API 的詳細資訊。

目錄

圖最佳化級別

圖最佳化分為三個級別

  1. 基本
  2. 擴充套件
  3. 佈局最佳化

屬於一個級別的最佳化在應用前一個級別的最佳化之後執行(例如,在應用基本最佳化之後應用擴充套件最佳化)。

所有最佳化預設啟用。

基本圖最佳化

這些是保留語義的圖重寫,用於移除冗餘節點和冗餘計算。它們在圖分割槽之前執行,因此適用於所有執行提供者。可用的基本圖最佳化如下:

  • 常量摺疊:靜態計算圖中僅依賴於常量初始化器的部分。這消除了在執行時計算它們的需要。

  • 冗餘節點消除:在不改變圖結構的情況下移除所有冗餘節點。目前支援的此類最佳化如下:
    • Identity 消除
    • Slice 消除
    • Unsqueeze 消除
    • Dropout 消除
  • 語義保留節點融合:將多個節點融合/摺疊成一個節點。例如,Conv Add 融合將 Add 運算子摺疊為 Conv 運算子的偏置。目前支援的此類最佳化如下:
    • Conv Add 融合
    • Conv Mul 融合
    • Conv BatchNorm 融合
    • Relu Clip 融合
    • Reshape 融合

擴充套件圖最佳化

這些最佳化包括複雜的節點融合。它們在圖分割槽之後執行,並且僅適用於分配給 CPU、CUDA 或 ROCm 執行提供者的節點。可用的擴充套件圖最佳化如下:

最佳化項 執行提供者 備註
GEMM 啟用融合 CPU  
Matmul Add 融合 CPU  
Conv 啟用融合 CPU  
GELU 融合 CPU, CUDA, ROCm  
層歸一化融合 CPU, CUDA, ROCm  
BERT 嵌入層融合 CPU, CUDA, ROCm 融合 BERT 嵌入層、層歸一化和注意力掩碼長度
注意力融合* CPU, CUDA, ROCm  
跳躍層歸一化融合 CPU, CUDA, ROCm 融合全連線層的偏置、跳躍連線和層歸一化
偏置 GELU 融合 CPU, CUDA, ROCm 融合全連線層的偏置和 GELU 啟用
GELU 近似* CUDA, ROCm 預設停用。透過 kOrtSessionOptionsEnableGeluApproximation 啟用。
Approximations (click to expand)

為最佳化 BERT 的效能,在 GELU 近似和注意力融合中,對 CUDA 和 ROCm 執行提供者使用了近似方法。根據我們的評估,對精度的影響可以忽略不計:SQuAD v1.1 上 BERT 模型的 F1 分數幾乎相同 (87.05 vs 87.03)。

佈局最佳化

這些最佳化改變了適用節點的資料佈局,以實現更高的效能改進。它們在圖分割槽之後執行,並且僅適用於分配給 CPU 執行提供者的節點。可用的佈局最佳化如下:

  • NCHWc 最佳化器:透過使用 NCHWc 佈局而不是 NCHW 佈局來最佳化圖。

線上/離線模式

所有最佳化既可以線上執行,也可以離線執行。線上模式下,在初始化推理會話時,我們會在執行模型推理之前應用所有已啟用的圖最佳化。每次啟動會話時都應用所有最佳化可能會增加模型啟動時間開銷(特別是對於複雜模型),這在生產場景中可能至關重要。這就是離線模式可以帶來巨大好處的地方。在離線模式下,執行圖最佳化後,ONNX Runtime 會將結果模型序列化到磁碟。隨後,我們可以透過使用已最佳化模型並停用所有最佳化來減少啟動時間。

備註:

  • 在離線模式下執行時,請確保使用與模型推理將執行的目標機器完全相同的選項(例如,執行提供者、最佳化級別)和硬體(例如,您不能在僅配備 CPU 的機器上執行為 GPU 執行提供者預最佳化的模型)。
  • 當啟用佈局最佳化時,離線模式只能在與儲存離線模型時的環境相容的硬體上使用。例如,如果模型針對 AVX2 進行了佈局最佳化,則離線模型將需要支援 AVX2 的 CPU。

用法

級別

ONNX Runtime 定義了 GraphOptimizationLevel 列舉,用於確定啟用上述哪些最佳化級別。選擇一個級別將啟用該級別的最佳化以及所有前置級別的最佳化。例如,啟用擴充套件最佳化也會啟用基本最佳化。這些級別到列舉的對映如下:

  • GraphOptimizationLevel::ORT_DISABLE_ALL -> 停用所有最佳化
  • GraphOptimizationLevel::ORT_ENABLE_BASIC -> 啟用基本最佳化
  • GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 啟用基本和擴充套件最佳化
  • GraphOptimizationLevel::ORT_ENABLE_ALL -> 啟用所有可用最佳化,包括佈局最佳化

離線模式

要啟用將最佳化後的模型序列化到磁碟,請設定 SessionOptions 選項 optimized_model_filepath

Python API 示例

import onnxruntime as rt

sess_options = rt.SessionOptions()

# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"

session = rt.InferenceSession("<model_path>", sess_options)

C API 示例

  const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
  OrtEnv* env;
  g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env);
  OrtSessionOptions* session_options;
  g_ort->CreateSessionOptions(&session_options)

  // Set graph optimization level
  g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED);

  // To enable model serialization after graph optimization set this
  const ORTCHAR_T* optimized_model_path = ORT_TSTR("optimized_model_path");
  g_ort->SetOptimizedModelFilePath(session_options, optimized_model_path);

  OrtSession* session;
  const ORTCHAR_T* model_path = ORT_TSTR("model_path");
  g_ort->CreateSession(env, model_path, session_options, &session);

C# API 示例

SessionOptions so = new SessionOptions();

// Set graph optimization level
so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;

// To enable model serialization after graph optimization set this
so.OptimizedModelFilePath = "model_output_path\optimized_model.onnx"

var session = new InferenceSession(modelPath, so);

C++ API 示例

Ort::SessionOptions session_options;

// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");

auto session_ = Ort::Session(env, "model_file_path", session_options);