建立 Float16 和混合精度模型

將模型轉換為使用 float16 而不是 float32 可以減小模型大小(最多一半)並提高某些 GPU 上的效能。可能會有一些精度損失,但在許多模型中,新的精度是可接受的。float16 轉換不需要調優資料,這使其優於量化。

目錄

Float16 轉換

按照以下步驟將模型轉換為 float16:

  1. 安裝 onnx 和 onnxconverter-common

    pip install onnx onnxconverter-common

  2. 在 python 中使用 convert_float_to_float16 函式。

     import onnx
     from onnxconverter_common import float16
    
     model = onnx.load("path/to/model.onnx")
     model_fp16 = float16.convert_float_to_float16(model)
     onnx.save(model_fp16, "path/to/model_fp16.onnx")
    

Float16 工具引數

如果轉換後的模型不起作用或精度較差,您可能需要設定額外的引數。

convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False,
                         disable_shape_infer=False, op_block_list=None, node_block_list=None)
  • model:要轉換的 ONNX 模型。
  • min_positive_val, max_finite_val:常量值將被裁剪到這些邊界。 0.0, nan, inf, 和 -inf 將保持不變。
  • keep_io_types:模型輸入/輸出是否應保留為 float32。
  • disable_shape_infer:跳過執行 onnx 形狀/型別推斷。當形狀推斷崩潰、模型中已存在形狀/型別或不需要型別時(型別用於確定不受支援/被阻止的運算元需要插入 cast 運算元的位置),此引數很有用。
  • op_block_list:要保留為 float32 的運算元型別列表。預設使用 float16.DEFAULT_OP_BLOCK_LIST 中的列表。此列表包含 ONNX Runtime 中不支援 float16 的運算元。
  • node_block_list:要保留為 float32 的節點名稱列表。

注意:被阻止的運算元周圍將插入從 float16/float32 到 float32/float16 的 cast 運算元。目前,如果兩個被阻止的運算元相鄰,仍然會插入 cast 運算元,從而建立冗餘對。ORT 會在執行時最佳化掉這對冗餘運算元,因此結果將保持全精度。

混合精度

如果 float16 轉換導致結果不佳,您可以將大多數運算元轉換為 float16,但保留一些運算元為 float32。auto_mixed_precision.auto_convert_mixed_precision 工具會找到一個最小的運算元集來跳過轉換,同時保持一定的精度水平。您需要為模型提供一個示例輸入。

由於 ONNX Runtime 的 CPU 版本不支援 float16 運算元,並且該工具需要測量精度損失,因此混合精度工具必須在帶有 GPU 的裝置上執行

from onnxconverter_common import auto_mixed_precision
import onnx

model = onnx.load("path/to/model.onnx")
# Assuming x is the input to the model
feed_dict = {'input': x.numpy()}
model_fp16 = auto_convert_mixed_precision(model, feed_dict, rtol=0.01, atol=0.001, keep_io_types=True)
onnx.save(model_fp16, "path/to/model_fp16.onnx")

混合精度工具引數

auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol=None, keep_io_types=False)
  • model:要轉換的 ONNX 模型。
  • feed_dict:用於在轉換期間測量模型精度的測試資料。格式類似於 InferenceSession.run(輸入名稱到值的對映)。
  • validate_fn:一個函式,接受兩個 numpy 陣列列表(分別是 float32 模型和混合精度模型的輸出),如果結果足夠接近則返回 True,否則返回 False。可以代替或補充 rtolatol 使用。
  • rtol, atol:用於驗證的絕對和相對容差。更多資訊請參閱 numpy.allclose
  • keep_io_types:模型輸入/輸出是否應保留為 float32。

混合精度工具透過將運算元簇轉換為 float16 來工作。如果一個簇失敗,它會被分成兩半,並獨立嘗試兩個簇。工具執行時會列印簇大小的視覺化。