在 ONNX Runtime 中使用裝置張量
在構建高效的 AI 流水線時,使用裝置張量是至關重要的一環,尤其是在異構記憶體系統中。這類系統的典型例子是任何配有專用 GPU 的 PC。雖然 最新的 GPU 本身具有約 1TB/s 的記憶體頻寬,但連線到 CPU 的互連 PCI 4.0 x16 往往是瓶頸,其頻寬僅為約 32GB/s。因此,最好儘可能地將資料保留在 GPU 本地,或者透過計算來隱藏緩慢的記憶體流量,因為 GPU 能夠同時執行計算和 PCI 記憶體流量。
在記憶體已本地化到推理裝置的這些場景中,一個典型的用例是 GPU 加速的編碼影片流處理,該影片流可以透過 GPU 解碼器進行解碼。另一個常見情況是迭代網路,例如擴散網路或大型語言模型,其間中間張量不必複製回 CPU。針對高解析度影像的基於瓦片的推理是另一個用例,其中自定義記憶體管理對於減少 PCI 複製期間的 GPU 空閒時間至關重要。與順序處理每個瓦片不同,可以重疊 PCI 複製和 GPU 上的處理,並以此方式進行工作流水線化。

CUDA
ONNX Runtime 中的 CUDA 有兩種自定義記憶體型別:"CudaPinned" 和 "Cuda" 記憶體,其中 CUDA pinned(CUDA 頁面鎖定)記憶體 實際上是 CPU 記憶體,可由 GPU 直接訪問,允許使用 cudaMemcpyAsync 進行完全非同步的記憶體上傳和下載。普通的 CPU 張量只允許從 GPU 到 CPU 的同步下載,而從 CPU 到 GPU 的複製總是可以非同步執行。
使用 Ort::Sessions 的分配器分配張量非常簡單,透過 C++ API 進行,它直接對映到 C API。
Ort::Session session(ort_env, model_path_cstr, session_options);
Ort::MemoryInfo memory_info_cuda("Cuda", OrtArenaAllocator, /*device_id*/0,
OrtMemTypeDefault);
Ort::Allocator gpu_allocator(session, memory_info_cuda);
auto ort_value = Ort::Value::CreateTensor(
gpu_allocator, shape.data(), shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
外部分配的資料也可以包裝到 Ort::Value 中而無需複製。
Ort::MemoryInfo memory_info_cuda("Cuda", OrtArenaAllocator, device_id,
OrtMemTypeDefault);
std::array<int64_t, 4> shape{1, 4, 64, 64};
size_t cuda_buffer_size = 4 * 64 * 64 * sizeof(float);
void *cuda_resource;
CUDA_CHECK(cudaMalloc(&cuda_resource, cuda_buffer_size));
auto ort_value = Ort::Value::CreateTensor(
memory_info_cuda, cuda_resource, cuda_buffer_size,
shape.data(), shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
這些已分配的張量可以作為 I/O 繫結 使用,以消除網路上的複製操作並將責任轉移給使用者。透過此類 I/O 繫結,可以進行更多的效能調優:
- 由於張量地址固定,可以捕獲 CUDA 圖以減少 CPU 上的 CUDA 啟動延遲
- 由於可以完全非同步下載到頁面鎖定記憶體,或透過使用裝置本地張量消除記憶體複製,CUDA 可以在其給定流上透過執行選項實現完全非同步執行
要為 CUDA 設定自定義計算流,請參考 V2 選項 API,它公開了 Ort[CUDA|TensorRT]ProviderOptionsV2* 不透明結構指標和函式 Update[CUDA|TensorRT]ProviderOptionsWithValue(options, "user_compute_stream", cuda_stream); 來設定其流成員。更多詳細資訊可在每個執行提供者文件中找到。
如果您想驗證您的最佳化,Nsight System 有助於關聯 CPU API 和 CUDA 操作的 GPU 執行。這也可以驗證是否進行了所需的同步以及是否有非同步操作回退到同步執行。它還用於 本次 GTC 演講,解釋了裝置張量的最佳使用。
Python API
Python API 支援與上述 C++ API 相同的效能最佳化機會。裝置張量 可以按此所示進行分配。此外,user_compute_stream 可以透過此 API 進行設定。
sess = onnxruntime.InferenceSession("model.onnx", providers=["TensorrtExecutionProvider"])
option = {}
s = torch.cuda.Stream()
option["user_compute_stream"] = str(s.cuda_stream)
sess.set_providers(["TensorrtExecutionProvider"], [option])
在 Python 中啟用非同步執行可以透過與 C++ API 相同的執行選項實現。
DirectML
透過 DirectX 資源可以實現相同的行為。要執行非同步處理,與 CUDA 一樣,對執行流進行相同的管理至關重要。對於 DirectX,這意味著管理裝置及其命令佇列,這可以透過 C API 實現。關於如何設定計算命令佇列的詳細資訊已在使用 SessionOptionsAppendExecutionProvider_DML1 的文件中說明。
如果為複製和計算使用獨立的命令佇列,則可以重疊 PCI 複製和執行,並使執行非同步化。
#include <onnxruntime/dml_provider_factory.h>
Ort::MemoryInfo memory_info_dml("DML", OrtDeviceAllocator, device_id,
OrtMemTypeDefault);
std::array<int64_t, 4> shape{1, 4, 64, 64};
void *dml_resource;
size_t d3d_buffer_size = 4 * 64 * 64 * sizeof(float);
const OrtDmlApi *ort_dml_api;
Ort::ThrowOnError(Ort::GetApi().GetExecutionProviderApi(
"DML", ORT_API_VERSION, reinterpret_cast<const void **>(&ort_dml_api)));
// Create d3d_buffer using D3D12 APIs
Microsoft::WRL::ComPtr<ID3D12Resource> d3d_buffer = ...;
// Create the dml resource from the D3D resource.
ort_dml_api->CreateGPUAllocationFromD3DResource(d3d_buffer.Get(), &dml_resource);
Ort::Value ort_value(Ort::Value::CreateTensor(memory_info_dml, dml_resource,
d3d_buffer_size, shape.data(), shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
一個單檔案示例可以在 GitHub 上找到,它展示瞭如何管理和建立複製和執行命令佇列。
Python API
儘管從 Python 分配 DirectX 輸入可能不是一個主要用例,但該 API 是可用的。這可能非常有利,特別是對於中間網路快取,例如大型語言模型 (LLM) 中的鍵值快取。
import onnxruntime as ort
import numpy as np
session = ort.InferenceSession("model.onnx",
providers=["DmlExecutionProvider"])
cpu_array = np.zeros((1, 4, 512, 512), dtype=np.float32)
dml_array = ort.OrtValue.ortvalue_from_numpy(cpu_array, "dml")
binding = session.io_binding()
binding.bind_ortvalue_input("data", dml_array)
binding.bind_output("out", "dml")
# if the output dims are known we can also bind a preallocated value
# binding.bind_ortvalue_output("out", dml_array_out)
session.run_with_iobinding(binding)