推理 PyTorch 模型
瞭解 PyTorch 以及如何使用 PyTorch 模型進行推理。
PyTorch 以其易於理解和靈活的 API;大量可用的現成模型,特別是在自然語言處理 (NLP) 領域;以及其領域特定的庫,引領著深度學習領域。
越來越多的開發者和應用程式希望使用 PyTorch 構建的模型,本文快速介紹了 PyTorch 模型的推理。PyTorch 模型有多種不同的推理方式;下面將列舉這些方式。
本文假定您正在尋找有關如何使用 PyTorch 模型進行推理的資訊,而不是如何訓練 PyTorch 模型。
目錄
PyTorch 概述
PyTorch 的核心是 nn.Module,它是一個代表整個深度學習模型或單個層的類。模組可以組合或擴充套件以構建模型。要編寫自己的模組,您需要實現一個根據模型輸入和模型訓練權重計算輸出的前向函式。如果您正在編寫自己的 PyTorch 模型,那麼您很可能也在訓練它。或者,您可以使用 PyTorch 本身或其他庫(例如 HuggingFace)的預訓練模型。
使用 PyTorch 本身編寫影像處理模型
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet18, ResNet18_Weights
class Predictor(nn.Module):
def __init__(self):
super().__init__()
weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms()
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x = self.transforms(x)
y_pred = self.resnet18(x)
return y_pred.argmax(dim=1)
要使用 HuggingFace 庫建立語言模型,您可以
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
model = transformers.BertForQuestionAnswering.from_pretrained(model_name)
建立或匯入訓練好的模型後,如何執行它進行推理?下面我們介紹幾種您可以在 PyTorch 中進行推理的方法。
推理選項
使用原生 PyTorch 進行推理
如果您對效能或大小不敏感,並且在包含 Python 可執行檔案和庫的環境中執行,您可以在原生 PyTorch 中執行您的應用程式。
獲得訓練好的模型後,您(或您的資料科學團隊)可以使用兩種方法來儲存和載入模型以進行推理:
-
儲存和載入整個模型
# Save the entire model to PATH torch.save(model, PATH) # Load the model from PATH and set eval mode for inference model = torch.load(PATH) model.eval() -
儲存模型引數,重新宣告模型,然後載入引數
# Save the model parameters torch.save(model.state_dict(), PATH) # Redeclare the model and load the saved parameters model = TheModel(...) model.load_state_dict(torch.load(PATH)) model.eval()
您使用哪種方法取決於您的配置。儲存和載入整個模型意味著您無需重新宣告模型,甚至無需訪問模型程式碼本身。但缺點是,儲存環境和載入環境必須在可用的類、方法和引數方面匹配(因為這些是直接序列化和反序列化的)。
只要您可以訪問原始模型程式碼,儲存模型訓練好的引數(狀態字典或 state_dict)比第一種方法更靈活。
您可能不想使用原生 PyTorch 對模型進行推理的主要原因有兩個。首先,您必須在包含 Python 執行時、PyTorch 庫和相關依賴項的環境中執行——這些檔案加起來有幾千兆位元組。如果您想在行動電話、Web 瀏覽器或專用硬體等環境中執行,使用原生 PyTorch 進行 PyTorch 推理將無法工作。第二個原因是效能。開箱即用的 PyTorch 模型可能無法提供您的應用程式所需的效能。
使用 TorchScript 進行推理
如果您在更受限制的環境中執行,無法安裝 PyTorch 或其他 Python 庫,則可以選擇使用已轉換為 TorchScript 的 PyTorch 模型進行推理。TorchScript 是 Python 的一個子集,它允許您建立可序列化的模型,這些模型可以在非 Python 環境中載入和執行。
# Export to TorchScript
script = torch.jit.script(model, example)
# Save scripted model
script.save(PATH)
# Load scripted model
model = torch.jit.load(PATH)
model.eval()
#include <torch/script.h>
...
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule
module = torch::jit::load(PATH);
}
catch (const c10::Error& e) {
...
}
...
雖然您不需要在環境中擁有 Python 執行時即可使用 TorchScript 方法對 PyTorch 模型進行推理,但您確實需要安裝 libtorch 二進位制檔案,這些檔案可能對於您的環境來說太大了。您也可能無法獲得應用程式所需的效能。
使用 ONNXRuntime 進行推理
當效能和可移植性至關重要時,您可以使用 ONNXRuntime 對 PyTorch 模型進行推理。透過 ONNXRuntime,您可以減少延遲和記憶體佔用,並提高吞吐量。您還可以使用 ONNXRuntime 提供的語言繫結和庫,在雲、邊緣、Web 或移動裝置上執行模型。
第一步是使用 PyTorch ONNX 匯出器將您的 PyTorch 模型匯出為 ONNX 格式。
# Specify example data
example = ...
# Export model to ONNX format
torch.onnx.export(model, PATH, example)
匯出為 ONNX 格式後,您可以選擇在 Netron 檢視器中檢視模型,以瞭解模型圖、輸入和輸出節點名稱和形狀,以及哪些節點具有可變大小的輸入和輸出(動態軸)。
然後您可以在您選擇的環境中執行 ONNX 模型。ONNXRuntime 引擎用 C++ 實現,並提供 C++、Python、C#、Java、Javascript、Julia 和 Ruby 的 API。ONNXRuntime 可以在 Linux、Mac、Windows、iOS 和 Android 上執行您的模型。例如,以下程式碼片段展示了一個 C++ 推理應用程式的骨架。
// Allocate ONNXRuntime session
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Env env;
Ort::Session session{env, ORT_TSTR("model.onnx"), Ort::SessionOptions{nullptr}};
// Allocate model inputs: fill in shape and size
std::array<float, ...> input{};
std::array<int64_t, ...> input_shape{...};
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input.data(), input.size(), input_shape.data(), input_shape.size());
const char* input_names[] = {...};
// Allocate model outputs: fill in shape and size
std::array<float, ...> output{};
std::array<int64_t, ...> output_shape{...};
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output.data(), output.size(), output_shape.data(), output_shape.size());
const char* output_names[] = {...};
// Run the model
session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
開箱即用,ONNXRuntime 會對 ONNX 圖應用一系列最佳化,儘可能地組合節點並提取常量值(常量摺疊)。ONNXRuntime 還透過其執行提供程式介面與多種硬體加速器整合,包括 CUDA、TensorRT、OpenVINO、CoreML 和 NNAPI,具體取決於您所針對的硬體平臺。
您可以透過量化 ONNX 模型來進一步提高其效能。
如果應用程式在移動和邊緣等受限環境中執行,您可以根據應用程式執行的模型或模型集構建一個精簡版執行時。
要在您選擇的語言和環境中開始使用,請參閱ONNX Runtime 快速入門