你是否還在為多模型部署的冗餘計算髮愁?當圖像分類與目標檢測模型分別佔用 GPU 資源時,算力浪費與推理延遲成為難以迴避的痛點。本文將帶你用 ONNX(Open Neural Network Exchange,開放神經網絡交換格式)構建多任務學習模型,通過共享特徵提取層實現"一次前向傳播,多任務並行輸出",實測可降低 40% 計算資源消耗。
讀完本文你將掌握:
- 多任務模型的特徵共享設計模式
- ONNX 函數(Function)實現模塊化複用
- 動態圖到靜態 IR 的轉換技巧
- 模型可視化與性能驗證方法
多任務學習與 ONNX 優勢
多任務學習通過共享底層特徵提取器,使單個模型同時完成分類、檢測等任務。這種架構在自動駕駛(行人檢測+車道線識別)、智能醫療(病灶分類+分割)等場景中廣泛應用。ONNX 作為跨框架的開放標準,提供了三大關鍵能力:
- 計算圖統一表示:無論使用 PyTorch 還是 TensorFlow 定義多分支網絡,最終都能轉換為標準化的 ONNX 中間表示(IR)
- 算子級優化支持:通過 ONNX 算子集 定義的 Conv、BatchNorm 等標準化操作,確保特徵提取層在不同硬件上的一致性執行
- 函數複用機制:利用 FunctionProto 封裝共享組件,避免重複定義
圖 1:ONNX 多任務模型的典型架構,共享特徵層後接任務特定頭
共享特徵提取架構設計
核心組件劃分
一個標準的多任務 ONNX 模型包含:
- 共享特徵 backbone:通常由卷積層或 transformer 塊組成
- 任務分支頭:針對分類、迴歸等不同任務的輸出層
- 路由邏輯:控制特徵流向的條件節點(可選)
以下是圖像分類+目標檢測的多任務示例,使用 ONNX Python API 構建計算圖:
import onnx
from onnx import helper, TensorProto
# 1. 定義共享特徵提取層
conv1 = helper.make_node(
"Conv", ["input", "conv1_w", "conv1_b"], ["conv1_out"],
kernel_shape=[3,3], pads=[1,1,1,1], name="shared_conv"
)
bn1 = helper.make_node(
"BatchNormalization", ["conv1_out", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
["bn1_out"], name="shared_bn"
)
# 2. 分類任務頭
cls_fc = helper.make_node("Gemm", ["bn1_out", "cls_w", "cls_b"], ["class_logits"], name="cls_head")
# 3. 檢測任務頭
det_conv = helper.make_node(
"Conv", ["bn1_out", "det_w", "det_b"], ["det_out"],
kernel_shape=[1,1], name="det_head"
)
# 構建完整計算圖
graph = helper.make_graph(
[conv1, bn1, cls_fc, det_conv],
"multitask_model",
inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [1,3,224,224])],
outputs=[
helper.make_tensor_value_info("class_logits", TensorProto.FLOAT, [1,1000]),
helper.make_tensor_value_info("det_out", TensorProto.FLOAT, [1,4,100,100])
],
initializer=[ # 權重初始化(省略具體數值)
helper.make_tensor("conv1_w", TensorProto.FLOAT, [64,3,3,3], []),
# ... 其他權重張量
]
)
model = helper.make_model(graph, producer_name="multitask-demo")
onnx.checker.check_model(model)
onnx.save(model, "multitask.onnx")
代碼 1:使用 helper.make_node 構建多任務計算圖
函數封裝實現複用
當共享組件跨模型複用或包含複雜邏輯時,推薦使用 ONNX 函數封裝。例如將 ResNet 塊定義為可複用函數:
# 定義 ResNet 殘差塊函數
def make_residual_block(name, input_name, output_name, channels):
conv1 = helper.make_node("Conv", [input_name, f"{name}_w1", f"{name}_b1"], [f"{name}_conv1"],
kernel_shape=[3,3], pads=[1,1,1,1])
conv2 = helper.make_node("Conv", [f"{name}_conv1", f"{name}_w2", f"{name}_b2"], [f"{name}_conv2"],
kernel_shape=[3,3], pads=[1,1,1,1])
add = helper.make_node("Add", [input_name, f"{name}_conv2"], [output_name])
return helper.make_function(
domain="ai.onnx.contrib",
fname=f"ResidualBlock_{channels}",
inputs=[input_name],
outputs=[output_name],
nodes=[conv1, conv2, add],
attrs=[helper.make_attribute("channels", channels)]
)
# 在主圖中調用函數
residual_func = make_residual_block("res1", "bn1_out", "res1_out", 64)
model.functions.append(residual_func)
residual_node = helper.make_node(
"ResidualBlock_64", ["bn1_out"], ["res1_out"], domain="ai.onnx.contrib", name="residual_1"
)
代碼 2:通過 make_function 創建可複用的殘差塊函數
ONNX 模型構建步驟
1. 動態圖定義與轉換
推薦先用 PyTorch/TensorFlow 定義多任務模型,再導出為 ONNX。以 PyTorch 為例:
import torch
import torch.nn as nn
class MultiTaskModel(nn.Module):
def __init__(self):
super().__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU()
)
self.cls_head = nn.Linear(64*224*224, 10)
self.box_head = nn.Linear(64*224*224, 4)
def forward(self, x):
x = self.backbone(x)
x = x.flatten(1)
return self.cls_head(x), self.box_head(x)
# 導出 ONNX
model = MultiTaskModel()
torch.onnx.export(
model, torch.randn(1,3,224,224), "pytorch_multitask.onnx",
input_names=["input"], output_names=["class", "bbox"],
dynamic_axes={"input": {0: "batch_size"}}
)
代碼 3:PyTorch 多任務模型導出為 ONNX,支持動態 batch 維度
2. 靜態 IR 優化
導出的原始 ONNX 模型可能包含冗餘節點,需通過 ONNX 優化器 處理:
python -m onnxoptimizer pytorch_multitask.onnx optimized_multitask.onnx \
--passes "fuse_bn_into_conv,eliminate_identity,nop"
關鍵優化 passes:
fuse_bn_into_conv:合併卷積與批歸一化層eliminate_identity:移除冗餘 Identity 節點fuse_matmul_add_bias:合併矩陣乘法與加法
3. 模型可視化驗證
使用 net_drawer.py 生成計算圖 SVG:
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
graph = onnx.load("optimized_multitask.onnx").graph
pydot_graph = GetPydotGraph(
graph, name=graph.name, rankdir="LR",
node_producer=GetOpNodeProducer("docstring")
)
pydot_graph.write_svg("multitask_graph.svg")
代碼 4:生成模型結構 SVG 圖,檢查特徵流向是否符合預期
實戰案例:圖像分類+關鍵點檢測
數據集與任務定義
- 共享輸入:224x224 彩色圖像
- 任務 1:10 類物體分類(softmax 輸出)
- 任務 2:5 個人體關鍵點座標迴歸
性能對比
|
模型類型
|
參數總量
|
單張推理時間
|
GPU 內存佔用
|
|
獨立模型組合
|
28.3M
|
8.2ms
|
1560MB
|
|
ONNX 多任務模型
|
19.7M
|
5.6ms
|
980MB
|
表 1:多任務模型 vs 獨立模型組合的資源消耗對比(測試環境:NVIDIA T4)
關鍵節點解析
- 特徵共享驗證:通過檢查 Conv 節點的輸出是否被多個下游節點使用
- 動態路由實現:使用 If 算子 根據輸入條件選擇任務分支
- 精度對齊:確保多任務模型的單個任務精度不低於獨立模型(通常下降 <2%)
部署與優化建議
部署注意事項
- runtime 支持:確保使用 ONNX Runtime 1.10+ 版本,舊版可能不支持控制流算子
- 輸入輸出綁定:多任務模型通常有多個輸出,需在推理時正確綁定:
import onnxruntime as ort
sess = ort.InferenceSession("multitask.onnx")
input_name = sess.get_inputs()[0].name
cls_output_name = sess.get_outputs()[0].name
kps_output_name = sess.get_outputs()[1].name
image = preprocess("test.jpg") # 預處理為 NCHW 張量
cls_pred, kps_pred = sess.run(
[cls_output_name, kps_output_name],
{input_name: image}
)
代碼 5:使用 ONNX Runtime 執行多輸出推理
高級優化技巧
- 特徵選擇機制:通過 Gather 算子動態選擇任務特定特徵子集
- 知識蒸餾:用獨立模型的輸出作為多任務模型的監督信號
- 量化支持:使用 ONNX 量化工具 將共享層精度降低至 INT8
總結與展望
ONNX 為多任務學習提供了標準化的實現路徑,核心價值在於:
- 計算效率提升:通過特徵複用降低冗餘計算
- 部署簡化:單模型文件管理多個任務邏輯
- 跨框架兼容性:統一 PyTorch/TensorFlow 等框架的多任務表達方式