在邊緣執行 PyTorch 模型
作者:Natalie Kershaw 和 Prasanth Pulavarthi
2023年10月12日
大多數現代機器學習模型都使用 PyTorch 開發。PyTorch 在建立和訓練模型方面提供的敏捷性和靈活性使其成為當今最流行的深度學習框架。典型的工作流程是在雲端訓練這些模型,並也在雲端執行它們。然而,許多場景正在出現,使得在裝置本地執行模型更具吸引力,甚至在某些情況下成為必需。這些場景包括:
- 避免與雲端的網路往返(例如在音訊和影片處理中)
- 將使用者資料保留在裝置上(用於隱私保護或滿足法規要求)
- 雲資源成本高昂(尤其是在裝置能力未得到充分利用時)
- 應用程式需要在沒有網際網路連線的情況下執行
在本文中,我們將揭秘如何在邊緣執行 PyTorch 模型。我們將“邊緣”定義為雲端之外的任何地方,從資源充足的大型個人電腦到手機等小型裝置。過去,完成這項任務一直充滿挑戰,但模型最佳化和 ONNX Runtime 等軟體的新進展使其變得更加可行——即使是對於 Stable Diffusion、Whisper 和 Llama2 等新型生成式 AI 和大型語言模型。
在邊緣執行 PyTorch 模型的考量因素
在考慮在邊緣執行 PyTorch 模型時,有幾個因素需要牢記。
- 大小:現代模型可以達到幾千兆位元組(因此得名大型語言模型!)。在雲端,模型大小通常不是一個問題,除非它變得太大而無法容納在單個 GPU 上。那時,有各種成熟的解決方案可以在多個 GPU 上執行。對於邊緣裝置,我們需要找到能夠適應裝置限制的模型。這有時需要權衡模型質量。大多數現代模型都有多種尺寸(例如10億引數、130億引數、700億引數等),因此您可以選擇適合您裝置的變體。通常會應用量化等技術來減少表示引數的位數,從而進一步減小模型大小。應用程式的大小也受到應用商店的限制,因此引入幾千兆位元組的庫在邊緣裝置上是不可行的。
- 應用程式整合 API:在雲端,模型通常被打包成 Docker 容器,這些容器公開一個供應用程式或服務呼叫的端點。在邊緣裝置上,Docker 容器可能會佔用過多資源,甚至可能不受支援。透過使用像 ONNX Runtime 這樣的最佳化引擎,可以消除對 Python 和 Docker 容器的依賴。ONNX Runtime 還提供多種語言的 API,包括 C、C++、C#、Rust、Java、JavaScript、Objective-C 和 Swift,以便更輕鬆地與宿主應用程式進行原生整合。
- 效能:在雲端,憑藉大量記憶體、無功耗限制和強大的計算能力,執行未最佳化的模型是可能的。在邊緣裝置上,這些“奢侈品”不存在,因此最佳化至關重要。例如,ONNX Runtime 最佳化記憶體分配、融合模型運算子、減少核心啟動時間、最小化處理單元之間的張量傳輸,並應用經過調優的矩陣數學演算法。它還能夠利用裝置特定的編譯器和引擎,為您的應用程式提供通用介面,同時在每個裝置上發揮最佳效能。
- 可維護性:在雲端,更新模型就像部署新的容器映象和增加流量一樣簡單。在邊緣端,您需要考慮如何分發模型更新。有時這涉及嚮應用商店釋出更新,有時可能需要在您的應用程式中實現資料更新機制並下載新的模型檔案,甚至只下載增量更新。有許多可能的途徑,因此本文不會深入探討此主題,但這是您在規劃生產時需要牢記的一個方面。
- 混合模式:您可以選擇同時利用雲端和裝置端,而非僅選擇其一。如今,Office 等應用程式在生產中使用了多種混合模式。一種模式是根據網路條件或輸入特性動態決定在裝置上還是在雲端執行。另一種模式是在裝置上執行模型管道的一部分,在雲端執行另一部分。這對於具有獨立編碼器和解碼器階段的現代模型管道尤其有用。使用像 ONNX Runtime 這樣同時支援雲端和裝置端的引擎可以簡化開發。我們將在後續文章中更詳細地討論混合場景。
- 個性化:在許多情況下,PyTorch 模型只是簡單地在裝置上執行。然而,您也可能遇到需要在裝置上個性化模型而無需將資料傳送到雲端的場景。推薦和內容定向就是可以透過根據裝置上的活動更新模型來提高質量的示例場景。在裝置上使用 PyTorch 進行微調和訓練可能不可行(由於效能和大小問題),但使用像 ONNX Runtime 這樣的引擎可以允許 PyTorch 模型在本地進行更新和個性化。相同的機制還支援聯邦學習,這有助於減輕使用者資料暴露的風險。
在邊緣執行 PyTorch 模型的工具
我們前面多次提到 ONNX Runtime。ONNX Runtime 是一個緊湊、基於標準的引擎,與 PyTorch 深度整合。透過使用 PyTorch 的 ONNX API,您的 PyTorch 模型可以在各種邊緣裝置上使用 ONNX Runtime 執行。
在邊緣執行 PyTorch 模型的第一步是將其轉換為輕量級格式,使其不再需要 PyTorch 框架及其數千兆位元組的依賴。PyTorch 已經考慮到了這一點,並提供了一個專門實現此功能的 API——torch.onnx。ONNX 是一個開放標準,定義了構成模型的運算子。PyTorch ONNX API 將 Python 風格的 PyTorch 程式碼轉換為功能圖,該圖捕獲了無需 Python 即可執行模型所需的運算子。像機器學習中的所有事物一樣,也存在一些需要注意的限制。某些 PyTorch 模型無法表示為單個圖——在這種情況下,您可能需要輸出多個圖並在自己的管道中將它們拼接起來。
流行的 Hugging Face 庫也提供了基於 torch.onnx 功能的 API,用於將模型匯出為 ONNX 格式。超過 130,000 個模型受支援,這意味著您關心的模型很可能就在其中。
在本文中,我們將透過多種語言(從 C# 到 JavaScript 再到 Swift),向您展示在流行裝置(如 Windows 筆記型電腦、手機和網頁瀏覽器)上執行最先進的 PyTorch 模型(如 Whisper 和 Stable Diffusion)的幾個示例。
在邊緣執行 PyTorch 模型的示例
在 Windows 上執行 Stable Diffusion
Stable Diffusion 管道由五個 PyTorch 模型組成,這些模型根據文字描述生成影像。擴散過程會迭代隨機畫素,直到輸出影像與描述匹配。
為了在邊緣執行,其中四個模型可以從 HuggingFace 匯出為 ONNX 格式。
from optimum.onnxruntime import ORTStableDiffusionPipeline
pipeline = ORTStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", export=True)
pipeline.save_pretrained("./onnx-stable-diffusion") 您無需匯出第五個模型 ClipTokenizer,因為它在 ONNX Runtime 擴充套件中可用,該庫用於 PyTorch 模型的預處理和後處理。
為了將這一系列模型作為 .NET 應用程式執行,我們使用 C# 構建了管道程式碼。如果您的機器上有可用的 CPU、GPU 或 NPU,此程式碼可以使用 ONNX Runtime 的裝置專用硬體加速器在其上執行。這透過下面的 ExecutionProviderTarget 進行配置。
static void Main(string[] args)
{
var prompt = "Two golden retriever puppies playing in the grass.";
var config = new StableDiffusionConfig
{
NumInferenceSteps = 50,
GuidanceScale = 7.5,
ExecutionProviderTarget = StableDiffusionConfig.ExecutionProvider.Cpu,
DeviceId = 0,
TokenizerOnnxPath = ".\models\tokenizer\model.onnx",
TextEncoderOnnxPath = ".\models\text_encoder\model.onnx",
UnetOnnxPath = ".\models\unet\model.onnx",
VaeDecoderOnnxPath = ".\models\vae_decoder\model.onnx",
SafetyModelPath = ".\models\safety_checker\model.onnx",
};
var image = UNet.Inference(prompt, config);
if (image == null)
{
Console.WriteLine("Unable to create image, please try again.");
}
} 這是模型管道的輸出,以 50 次推理迭代執行
您可以按照此 教程中顯示的詳細步驟,在 Windows 上構建並執行該應用程式。
瀏覽器中的文字生成
使用 transformers.js 庫,在瀏覽器中本地執行 PyTorch 模型不僅可能,而且非常簡單。Transformers.js 使用 ONNX Runtime Web 作為其後端。許多模型已經轉換為 ONNX 格式並透過 transformers.js CDN 提供服務,使得在瀏覽器中進行推理只需編寫幾行 HTML 程式碼即可。
<html>
<body>
<h1>Enter starting text …</h1>
<form id="form">
<input type="text" id="inputText">
<button type="submit" id="submitButton">Submit</button>
</form>
<div id="output"></div>
<script type="module">
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.6.2';
let inputText = document.getElementById('inputText');
let outputDiv = document.getElementById('output');
let submitButton = document.getElementById('submitButton');
submitButton.addEventListener('click', async (e) => {
e.preventDefault();
let generator = await pipeline('text-generation', 'Xenova/LaMini-Neo-125M');
let result = await generator(inputText.value,
{ max_new_tokens: 200,
temperature: 2,
repetition_penalty: 1.5,
no_repeat_ngram_size: 2,
num_beams: 2,
num_return_sequences: 1,
});
outputDiv.innerHTML = result[0].generated_text;
});
</script>
</body>
</html> 您還可以使用純 JavaScript 或在 React 或 Next.js 等 Web 應用程式中嵌入對 transformers 管道的呼叫,或者編寫瀏覽器擴充套件。
ONNX Runtime Web 目前使用 WebAssembly 在 CPU 上執行模型。這對於許多模型來說已經足夠,但如果裝置上存在 GPU,利用 GPU 可以改善使用者體驗。ONNX Runtime Web 對 WebGPU 的支援即將推出,這將使您能夠利用 GPU,同時使用相同的推理 API。
在移動裝置上使用 Whisper 進行語音識別
OpenAI 的 Whisper 是一個 PyTorch 語音識別模型。Whisper 有多種尺寸變體——最小的 Whisper Tiny 適合在移動裝置上執行。使用 Olive 框架可以將 Whisper Tiny 模型的所有元件(音訊解碼器、編碼器、解碼器和文字序列生成)組合並匯出為單個 ONNX 模型。要將此模型作為移動應用程式的一部分執行,您可以使用 ONNX Runtime Mobile,它支援 Android、iOS、React Native 和 MAUI/Xamarin。
ONNX Runtime Mobile 透過 NNAPI(在 Android 上)、CoreML(在 iOS 上)和 XNNPACK(在 iOS 和 Android 上)支援硬體加速。
下面顯示了一個在短音訊樣本上執行語音轉錄的 Android 移動應用示例的相關程式碼片段。
init {
val env = OrtEnvironment.getEnvironment()
val sessionOptions = OrtSession.SessionOptions()
sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
session = env.createSession(modelBytes, sessionOptions)
val nMels: Long = 80
val nFrames: Long = 3000
baseInputs = mapOf(
"min_length" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
"max_length" to createIntTensor(env, intArrayOf(200), tensorShape(1)),
"num_beams" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
"num_return_sequences" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
"length_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
"repetition_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
)
}
data class Result(val text: String, val inferenceTimeInMs: Long)
fun run(audioTensor: OnnxTensor): Result {
val inputs = mutableMapOf()
baseInputs.toMap(inputs)
inputs["audio_pcm"] = audioTensor
val startTimeInMs = SystemClock.elapsedRealtime()
val outputs = session.run(inputs)
val elapsedTimeInMs = SystemClock.elapsedRealtime() - startTimeInMs
val recognizedText = outputs.use {
@Suppress("UNCHECKED_CAST")
(outputs[0].value as Array>)[0][0]
}
return Result(recognizedText, elapsedTimeInMs)
} 您可以錄製一小段音訊片段進行轉錄。
在移動裝置上訓練模型以識別您的聲音
ONNX Runtime 還可以接受預訓練模型並使其適應新資料。它可以在邊緣端做到這一點——特別是在移動裝置上,在那裡很容易錄製您的聲音、訪問您的照片和其他個性化資料。重要的是,您的資料在訓練期間不會離開裝置。
例如,您可以訓練一個 PyTorch 模型,在您的手機上僅識別您自己的聲音,用於身份驗證場景。
PyTorch 模型是在您的開發環境中從 HuggingFace 獲取的,並添加了額外的層來執行說話人分類。
from transformers import Wav2Vec2ForSequenceClassification, AutoConfig
import torch
config = AutoConfig.from_pretrained("superb/wav2vec2-base-superb-sid")
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid")
model.classifier = torch.nn.Linear(256, 2) 訓練所需的模型和其他元件(用於衡量模型質量的損失函式和用於指導訓練期間權重調整的最佳化器)均使用 ONNX Runtime Training 匯出。
artifacts.generate_artifacts(
onnx_model,
requires_grad=requires_grad,
frozen_params=frozen_params,
loss=CustomCELoss(),
optimizer=artifacts.OptimType.AdamW,
artifact_directory="MyVoice/artifacts",
) 這組工件現在已準備好由移動應用程式載入,這裡以 iOS Swift 程式碼的形式顯示。該應用程式會要求使用者提供語音樣本,模型將使用這些樣本進行訓練。
func trainStep(inputData: [Data], labels: [Int64]) throws {
let inputs = [try getORTValue(dataList: inputData), try getORTValue(labels: labels)]
try trainingSession.trainStep(withInputValues: inputs)
try trainingSession.optimizerStep()
try trainingSession.lazyResetGrad()
} 模型訓練完成後,您可以執行它來驗證語音樣本是否是您本人!
您可以閱讀完整的 說話人驗證教程,並從原始碼構建並執行該應用程式。
下一步是什麼?
在本文中,我們展示了為何要在邊緣執行 PyTorch 模型以及需要考慮的方面。我們還分享了幾個包含程式碼的示例,您可以用於使用 ONNX Runtime 在邊緣執行最先進的 PyTorch 模型。我們還展示了 ONNX Runtime 如何為效能和跨平臺執行而構建,使其成為在邊緣執行 PyTorch 模型的理想方式。使用 ONNX Runtime 在邊緣執行 PyTorch 模型,盡情享受吧!
您可能已經注意到,儘管 ONNX Runtime 經過最佳化可以執行 Llama2,但我們並未包含 Llama2 的示例。那是因為出色的 Llama2 模型值得單獨撰寫一篇文章,敬請期待!