你是否還在為多模型部署的冗餘計算髮愁?當圖像分類與目標檢測模型分別佔用 GPU 資源時,算力浪費與推理延遲成為難以迴避的痛點。本文將帶你用 ONNX(Open Neural Network Exchange,開放神經網絡交換格式)構建多任務學習模型,通過共享特徵提取層實現"一次前向傳播,多任務並行輸出",實測可降低 40% 計算資源消耗。

讀完本文你將掌握:

  • 多任務模型的特徵共享設計模式
  • ONNX 函數(Function)實現模塊化複用
  • 動態圖到靜態 IR 的轉換技巧
  • 模型可視化與性能驗證方法

多任務學習與 ONNX 優勢

多任務學習通過共享底層特徵提取器,使單個模型同時完成分類、檢測等任務。這種架構在自動駕駛(行人檢測+車道線識別)、智能醫療(病灶分類+分割)等場景中廣泛應用。ONNX 作為跨框架的開放標準,提供了三大關鍵能力:

  1. 計算圖統一表示:無論使用 PyTorch 還是 TensorFlow 定義多分支網絡,最終都能轉換為標準化的 ONNX 中間表示(IR)
  2. 算子級優化支持:通過 ONNX 算子集 定義的 Conv、BatchNorm 等標準化操作,確保特徵提取層在不同硬件上的一致性執行
  3. 函數複用機制:利用 FunctionProto 封裝共享組件,避免重複定義

ONNX 新特性大解讀和最佳實踐分享|直播預告_多任務

圖 1:ONNX 多任務模型的典型架構,共享特徵層後接任務特定頭

共享特徵提取架構設計

核心組件劃分

一個標準的多任務 ONNX 模型包含:

  • 共享特徵 backbone:通常由卷積層或 transformer 塊組成
  • 任務分支頭:針對分類、迴歸等不同任務的輸出層
  • 路由邏輯:控制特徵流向的條件節點(可選)

以下是圖像分類+目標檢測的多任務示例,使用 ONNX Python API 構建計算圖:

import onnx
from onnx import helper, TensorProto

# 1. 定義共享特徵提取層
conv1 = helper.make_node(
    "Conv", ["input", "conv1_w", "conv1_b"], ["conv1_out"],
    kernel_shape=[3,3], pads=[1,1,1,1], name="shared_conv"
)
bn1 = helper.make_node(
    "BatchNormalization", ["conv1_out", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
    ["bn1_out"], name="shared_bn"
)

# 2. 分類任務頭
cls_fc = helper.make_node("Gemm", ["bn1_out", "cls_w", "cls_b"], ["class_logits"], name="cls_head")

# 3. 檢測任務頭
det_conv = helper.make_node(
    "Conv", ["bn1_out", "det_w", "det_b"], ["det_out"],
    kernel_shape=[1,1], name="det_head"
)

# 構建完整計算圖
graph = helper.make_graph(
    [conv1, bn1, cls_fc, det_conv],
    "multitask_model",
    inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [1,3,224,224])],
    outputs=[
        helper.make_tensor_value_info("class_logits", TensorProto.FLOAT, [1,1000]),
        helper.make_tensor_value_info("det_out", TensorProto.FLOAT, [1,4,100,100])
    ],
    initializer=[  # 權重初始化(省略具體數值)
        helper.make_tensor("conv1_w", TensorProto.FLOAT, [64,3,3,3], []),
        # ... 其他權重張量
    ]
)

model = helper.make_model(graph, producer_name="multitask-demo")
onnx.checker.check_model(model)
onnx.save(model, "multitask.onnx")

代碼 1:使用 helper.make_node 構建多任務計算圖

函數封裝實現複用

當共享組件跨模型複用或包含複雜邏輯時,推薦使用 ONNX 函數封裝。例如將 ResNet 塊定義為可複用函數:

# 定義 ResNet 殘差塊函數
def make_residual_block(name, input_name, output_name, channels):
    conv1 = helper.make_node("Conv", [input_name, f"{name}_w1", f"{name}_b1"], [f"{name}_conv1"],
                            kernel_shape=[3,3], pads=[1,1,1,1])
    conv2 = helper.make_node("Conv", [f"{name}_conv1", f"{name}_w2", f"{name}_b2"], [f"{name}_conv2"],
                            kernel_shape=[3,3], pads=[1,1,1,1])
    add = helper.make_node("Add", [input_name, f"{name}_conv2"], [output_name])
    return helper.make_function(
        domain="ai.onnx.contrib",
        fname=f"ResidualBlock_{channels}",
        inputs=[input_name],
        outputs=[output_name],
        nodes=[conv1, conv2, add],
        attrs=[helper.make_attribute("channels", channels)]
    )

# 在主圖中調用函數
residual_func = make_residual_block("res1", "bn1_out", "res1_out", 64)
model.functions.append(residual_func)
residual_node = helper.make_node(
    "ResidualBlock_64", ["bn1_out"], ["res1_out"], domain="ai.onnx.contrib", name="residual_1"
)

代碼 2:通過 make_function 創建可複用的殘差塊函數

ONNX 模型構建步驟

1. 動態圖定義與轉換

推薦先用 PyTorch/TensorFlow 定義多任務模型,再導出為 ONNX。以 PyTorch 為例:

import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU()
        )
        self.cls_head = nn.Linear(64*224*224, 10)
        self.box_head = nn.Linear(64*224*224, 4)
        
    def forward(self, x):
        x = self.backbone(x)
        x = x.flatten(1)
        return self.cls_head(x), self.box_head(x)

# 導出 ONNX
model = MultiTaskModel()
torch.onnx.export(
    model, torch.randn(1,3,224,224), "pytorch_multitask.onnx",
    input_names=["input"], output_names=["class", "bbox"],
    dynamic_axes={"input": {0: "batch_size"}}
)

代碼 3:PyTorch 多任務模型導出為 ONNX,支持動態 batch 維度

2. 靜態 IR 優化

導出的原始 ONNX 模型可能包含冗餘節點,需通過 ONNX 優化器 處理:

python -m onnxoptimizer pytorch_multitask.onnx optimized_multitask.onnx \
    --passes "fuse_bn_into_conv,eliminate_identity,nop"

關鍵優化 passes:

  • fuse_bn_into_conv:合併卷積與批歸一化層
  • eliminate_identity:移除冗餘 Identity 節點
  • fuse_matmul_add_bias:合併矩陣乘法與加法

3. 模型可視化驗證

使用 net_drawer.py 生成計算圖 SVG:

from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

graph = onnx.load("optimized_multitask.onnx").graph
pydot_graph = GetPydotGraph(
    graph, name=graph.name, rankdir="LR",
    node_producer=GetOpNodeProducer("docstring")
)
pydot_graph.write_svg("multitask_graph.svg")

代碼 4:生成模型結構 SVG 圖,檢查特徵流向是否符合預期

實戰案例:圖像分類+關鍵點檢測

數據集與任務定義

  • 共享輸入:224x224 彩色圖像
  • 任務 1:10 類物體分類(softmax 輸出)
  • 任務 2:5 個人體關鍵點座標迴歸

性能對比

模型類型

參數總量

單張推理時間

GPU 內存佔用

獨立模型組合

28.3M

8.2ms

1560MB

ONNX 多任務模型

19.7M

5.6ms

980MB

表 1:多任務模型 vs 獨立模型組合的資源消耗對比(測試環境:NVIDIA T4)

關鍵節點解析

  1. 特徵共享驗證:通過檢查 Conv 節點的輸出是否被多個下游節點使用
  2. 動態路由實現:使用 If 算子 根據輸入條件選擇任務分支
  3. 精度對齊:確保多任務模型的單個任務精度不低於獨立模型(通常下降 <2%)

部署與優化建議

部署注意事項

  1. runtime 支持:確保使用 ONNX Runtime 1.10+ 版本,舊版可能不支持控制流算子
  2. 輸入輸出綁定:多任務模型通常有多個輸出,需在推理時正確綁定:
import onnxruntime as ort

sess = ort.InferenceSession("multitask.onnx")
input_name = sess.get_inputs()[0].name
cls_output_name = sess.get_outputs()[0].name
kps_output_name = sess.get_outputs()[1].name

image = preprocess("test.jpg")  # 預處理為 NCHW 張量
cls_pred, kps_pred = sess.run(
    [cls_output_name, kps_output_name],
    {input_name: image}
)

代碼 5:使用 ONNX Runtime 執行多輸出推理

高級優化技巧

  1. 特徵選擇機制:通過 Gather 算子動態選擇任務特定特徵子集
  2. 知識蒸餾:用獨立模型的輸出作為多任務模型的監督信號
  3. 量化支持:使用 ONNX 量化工具 將共享層精度降低至 INT8

總結與展望

ONNX 為多任務學習提供了標準化的實現路徑,核心價值在於:

  • 計算效率提升:通過特徵複用降低冗餘計算
  • 部署簡化:單模型文件管理多個任務邏輯
  • 跨框架兼容性:統一 PyTorch/TensorFlow 等框架的多任務表達方式