詳解圖神經網絡(GNN):方法與應用(一)

圖神經網絡(Graph Neural Networks, GNN)是深度學習領域針對非歐幾里得數據(如社交網絡、分子結構、知識圖譜)的核心模型家族。與CNN(處理網格結構圖像)、RNN(處理序列結構文本)不同,GNN能夠天然捕捉圖數據中的節點關聯關係拓撲結構信息,已成為解決複雜關聯問題的關鍵技術。本系列將分三篇詳解GNN:第一篇聚焦基礎概念、核心原理與經典模型;第二篇深入進階方法與優化技巧;第三篇拆解工業級應用與實戰案例。

一、圖數據基礎:為什麼需要GNN?

1. 什麼是圖(Graph)?

圖是一種描述對象間關聯關係的數據結構,形式化定義為 \( G = (V, E) \),其中:

  • \( V = \{v_1, v_2, ..., v_N\} \):節點(Vertex/Node)集合,代表實體(如社交網絡中的人、分子中的原子、論文中的作者);
  • \( E = \{e_{12}, e_{13}, ..., e_{ij}\} \):邊(Edge)集合,代表節點間的關係(如朋友關係、化學鍵、論文引用);
  • 可選屬性:
  • 節點屬性 \( X \in \mathbb{R}^{N \times F} \):每個節點的特徵向量(如用户的年齡/性別、原子的化學性質),\( F \) 為特徵維度;
  • 邊屬性 \( E \in \mathbb{R}^{M \times D} \):每條邊的特徵(如朋友關係的親密度、化學鍵的強度),\( M \) 為邊的數量;
  • 圖屬性 \( y \in \mathbb{R}^C \):整個圖的標籤(如分子的毒性分類、社交網絡的主題標籤)。

2. 圖數據的核心特點(與傳統數據的區別)

數據類型

結構特點

代表模型

核心侷限

網格數據(圖像)

規則網格、固定鄰居數

CNN

無法處理不規則關聯

序列數據(文本)

線性順序、固定長度

RNN/Transformer

忽略非順序的複雜關聯

圖數據(網絡)

不規則拓撲、動態鄰居數

GNN

天然適配關聯依賴

關鍵問題:傳統深度學習模型依賴數據的“平移不變性”(如CNN的卷積核滑動),但圖數據的節點鄰居數量不固定、拓撲結構不規則,無法直接套用傳統模型。GNN的核心創新是:通過“鄰居聚合”機制,將節點的局部拓撲信息與自身特徵融合,生成具有結構感知能力的節點嵌入

3. 圖數據的常見類型

  • 無向圖(Undirected Graph):邊無方向(如朋友關係),\( e_{ij} = e_{ji} \);
  • 有向圖(Directed Graph):邊有方向(如論文引用、網頁跳轉),\( e_{ij} \neq e_{ji} \);
  • 加權圖(Weighted Graph):邊帶有權重(如道路長度、社交親密度);
  • 異構圖(Heterogeneous Graph):節點/邊類型不同(如知識圖譜:人、地點、事件為不同節點,“居住”“參與”為不同邊);
  • 動態圖(Dynamic Graph):節點/邊的存在或屬性隨時間變化(如實時社交網絡、動態交通網絡)。

二、GNN的核心原理:鄰居聚合與消息傳遞

GNN的本質是迭代式地聚合節點自身特徵與鄰居節點特徵,讓每個節點“感知”到局部乃至全局的拓撲結構信息。其核心思想可概括為:

一個節點的嵌入向量 = 自身特徵 + 鄰居節點特徵的聚合結果

1. 通用框架:消息傳遞神經網絡(Message Passing Neural Networks, MPNN)

MPNN是絕大多數GNN的基礎框架,由兩個核心階段組成(2017年提出,統一了早期GNN模型):

(1)消息傳遞階段(Message Passing)

每個節點向其鄰居發送“消息”,消息由自身特徵和邊特徵(可選)轉換得到。對於節點 \( v_i \) 和其鄰居 \( v_j \),消息定義為: \[ m_{ij}^{(t)} = \phi \left( h_i^{(t-1)}, h_j^{(t-1)}, e_{ij} \right) \]

  • \( h_i^{(t-1)} \):節點 \( v_i \) 在第 \( t-1 \) 層的嵌入(初始嵌入 \( h_i^{(0)} = x_i \),即節點原始特徵);
  • \( e_{ij} \):邊 \( (i,j) \) 的特徵(無則忽略);
  • \( \phi \):消息函數(如線性變換、MLP),用於轉換特徵。
(2)聚合階段(Aggregation)

節點 \( v_i \) 收集所有鄰居的消息,通過聚合函數彙總為“鄰居特徵表示”: \[ m_i^{(t)} = \psi \left( \{ m_{ij}^{(t)} | j \in \mathcal{N}(i) \} \right) \]

  • \( \mathcal{N}(i) \):節點 \( v_i \) 的鄰居集合;
  • \( \psi \):聚合函數,需滿足置換不變性(即鄰居順序不影響結果,如求和、平均、最大值)。
(3)更新階段(Update)

將自身特徵與聚合後的鄰居特徵融合,得到節點 \( v_i \) 在第 \( t \) 層的新嵌入: \[ h_i^{(t)} = \gamma \left( h_i^{(t-1)}, m_i^{(t)} \right) \]

  • \( \gamma \):更新函數(如殘差連接、MLP),用於融合特徵。

2. 核心設計原則

  • 局部性:節點嵌入僅依賴其局部鄰居(符合圖數據的稀疏性,降低計算複雜度);
  • 迭代性:通過多層GNN,節點可間接聚合更遠距離鄰居的特徵(如2層GNN可感知2-hop鄰居);
  • 置換不變性:聚合函數對鄰居順序不敏感(保證模型對圖的節點編號無關性)。

三、經典GNN模型:從基礎到進階

1. 圖卷積網絡(Graph Convolutional Network, GCN)

(1)核心思想

GCN是最經典的GNN模型(2017年提出),其核心是基於圖的鄰接矩陣,對節點特徵進行加權卷積。它簡化了MPNN的消息傳遞過程,直接通過鄰接矩陣實現鄰居特徵的加權聚合。

(2)數學公式

對於無向無權圖,GCN第 \( t \) 層的傳播公式為: \[ H^{(t)} = \hat{A} \cdot H^{(t-1)} \cdot W^{(t)} \]

  • \( H^{(t)} \in \mathbb{R}^{N \times F^{(t)}} \):第 \( t \) 層所有節點的嵌入矩陣,\( F^{(t)} \) 為當前層特徵維度;
  • \( W^{(t)} \in \mathbb{R}{F{(t-1)} \times F^{(t)}} \):可學習的權重矩陣(類似CNN的卷積核);
  • \( \hat{A} = \tilde{D}^{-1/2} \cdot (A + I) \cdot \tilde{D}^{-1/2} \):歸一化後的鄰接矩陣(關鍵!):
  • \( A \):原始鄰接矩陣(\( A_{ij}=1 \) 表示節點 \( i \) 和 \( j \) 相連,否則為0);
  • \( I \):單位矩陣(添加自環,讓節點聚合自身特徵);
  • \( \tilde{D} \):度矩陣(對角矩陣,\( \tilde{D}{ii} = \sum_j (A + I){ij} \),即節點 \( i \) 的度+1);
  • 歸一化目的:避免因節點度數差異導致的嵌入值過大或梯度消失。
(3)直觀理解

GCN的卷積過程可拆解為3步:

  1. 給每個節點添加自環(保證自身特徵被考慮);
  2. 對每個節點的鄰居特徵按“度的倒數平方根”加權(度數越大的鄰居,權重越小,避免過度影響);
  3. 對加權後的鄰居特徵求和,再通過線性變換(\( W^{(t)} \))更新節點嵌入。
(4)代碼實現(PyTorch Geometric)
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # 定義兩層GCN
        self.conv1 = GCNConv(in_channels, hidden_channels)  # 第一層:輸入維度→隱藏層維度
        self.conv2 = GCNConv(hidden_channels, out_channels)  # 第二層:隱藏層維度→輸出維度

    def forward(self, x, edge_index):
        # x: 節點特徵矩陣 (N, in_channels)
        # edge_index: 邊索引矩陣 (2, M),PyG中默認存儲方式(第一行是源節點,第二行是目標節點)
        
        # 第一層GCN + 激活函數
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)  #  dropout防止過擬合
        
        # 第二層GCN(輸出層,無激活函數,用於分類/迴歸)
        x = self.conv2(x, edge_index)
        return x

# 示例:初始化模型並前向傳播
model = GCN(in_channels=10, hidden_channels=32, out_channels=2)  # 輸入特徵10維,輸出2分類
x = torch.randn(50, 10)  # 50個節點,每個節點10維特徵
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=torch.long)  # 6條邊
out = model(x, edge_index)
print("輸出節點嵌入形狀:", out.shape)  # (50, 2):每個節點2維輸出(對應2分類)
(5)優缺點
  • 優點:結構簡單、計算高效(利用稀疏矩陣乘法)、效果穩定,是GNN的“入門標配”;
  • 缺點:
  • 僅適用於無向圖(有向圖需修改鄰接矩陣處理);
  • 對節點度數敏感(歸一化雖緩解但未完全解決);
  • 全局拓撲信息捕捉能力有限(依賴層數堆疊,易過擬合)。

2. 圖注意力網絡(Graph Attention Network, GAT)

(1)核心改進:注意力機制替代固定權重

GCN的鄰居聚合權重是基於節點度數的固定值,而GAT引入自注意力機制,讓模型自動學習每個鄰居對當前節點的“重要性權重”。

(2)數學公式
  1. 線性變換:對每個節點特徵進行線性變換,得到中間特徵: \[ h_i' = W \cdot h_i \]
  2. 注意力分數計算:通過拼接(或點積)計算節點 \( i \) 與鄰居 \( j \) 的注意力分數: \[ e_{ij} = \text{LeakyReLU} \left( a^T \cdot [h_i' \parallel h_j'] \right) \]
  • \( a \in \mathbb{R}^{2F'} \):注意力權重向量(\( F' \) 為線性變換後的特徵維度);
  • \( [h_i' \parallel h_j'] \):拼接 \( h_i' \) 和 \( h_j' \),維度為 \( 2F' \);
  • LeakyReLU:激活函數,引入非線性。
  1. 歸一化注意力分數:使用softmax對鄰居的注意力分數歸一化,得到權重: \[ \alpha_{ij} = \text{softmax}j \left( e{ij} \right) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} \]
  2. 聚合與更新:加權聚合鄰居特徵,得到節點 \( i \) 的新嵌入: \[ h_i^{(t)} = \sigma \left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij} \cdot h_j' \right) \]
  • \( \sigma \):激活函數(如ReLU)。
(3)多頭注意力(Multi-Head Attention)

為了提高模型的表達能力,GAT引入“多頭注意力”:並行計算 \( K \) 組注意力權重,將結果拼接(或平均)作為最終嵌入: \[ h_i^{(t)} = \parallel_{k=1}^K \sigma \left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k \cdot h_j'^k \right) \]

  • \( \alpha_{ij}^k \):第 \( k \) 個頭的注意力權重;
  • \( \parallel \):拼接操作。
(4)代碼實現(PyTorch Geometric)
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        # 第一層GAT:多頭注意力(heads個注意力頭,輸出維度=hidden_channels*heads)
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.5)
        # 第二層GAT:單頭注意力(輸出最終分類維度)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, dropout=0.5)

    def forward(self, x, edge_index):
        # 第一層:多頭注意力 + ReLU激活
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # 第二層:單頭注意力(無激活,用於分類)
        x = self.conv2(x, edge_index)
        return x

# 示例:初始化模型(3個注意力頭)
model = GAT(in_channels=10, hidden_channels=16, out_channels=2, heads=3)
x = torch.randn(50, 10)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=torch.long)
out = model(x, edge_index)
print("輸出節點嵌入形狀:", out.shape)  # (50, 2)
(5)優缺點
  • 優點:
  • 自動學習鄰居重要性,適配不同拓撲結構;
  • 支持有向圖(注意力分數不對稱);
  • 表達能力強,在許多數據集上效果優於GCN;
  • 缺點:計算複雜度高於GCN(注意力分數計算),對超參數(注意力頭數)敏感。

3. 圖SAGE(Graph Sample and Aggregate)

(1)核心改進:解決大規模圖的計算瓶頸

GCN和GAT在處理大規模圖(如百萬級節點)時,需要加載整個鄰接矩陣,導致內存溢出。圖SAGE的核心創新是鄰居採樣:對每個節點,僅採樣部分鄰居進行聚合,而非全部鄰居,從而降低計算和內存開銷。

(2)核心流程
  1. 採樣(Sample):對每個節點 \( v_i \),從其鄰居集合 \( \mathcal{N}(i) \) 中隨機採樣 \( K \) 個鄰居(不足則重複採樣);
  2. 聚合(Aggregate):使用聚合函數(如平均、最大值、LSTM)聚合採樣後的鄰居特徵;
  3. 組合(Combine):融合自身特徵與聚合後的鄰居特徵,得到節點嵌入。
(3)經典聚合函數
  • 平均聚合(Mean Aggregator):\( \text{mean}(h_i^{(t-1)} \cup \{ h_j^{(t-1)} | j \in \text{sample}(\mathcal{N}(i)) \}) \)
  • 池化聚合(Pool Aggregator):\( \max( \text{MLP}(h_j^{(t-1)}) | j \in \text{sample}(\mathcal{N}(i)) ) \)
  • LSTM聚合(LSTM Aggregator):用LSTM對鄰居特徵序列(隨機排序)進行編碼(捕捉鄰居的順序信息,但破壞置換不變性,需謹慎使用)。
(4)代碼實現(PyTorch Geometric)
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # 第一層GraphSAGE(平均聚合)
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr="mean")
        # 第二層GraphSAGE(平均聚合)
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr="mean")

    def forward(self, x, edge_index):
        # 第一層:聚合 + ReLU + Dropout
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        # 第二層:輸出
        x = self.conv2(x, edge_index)
        return x

# 示例:初始化模型
model = GraphSAGE(in_channels=10, hidden_channels=32, out_channels=2)
x = torch.randn(1000, 10)  # 1000個節點(大規模圖示例)
edge_index = torch.randint(0, 1000, (2, 5000), dtype=torch.long)  # 5000條邊
out = model(x, edge_index)
print("輸出節點嵌入形狀:", out.shape)  # (1000, 2)
(5)優缺點
  • 優點:
  • 支持大規模圖(採樣降低複雜度);
  • 聚合函數靈活,可適配不同任務;
  • 泛化能力強(不依賴完整圖結構,適合動態圖);
  • 缺點:採樣過程引入隨機性,訓練時需固定隨機種子保證穩定性。

四、GNN的核心應用場景(入門級)

1. 節點分類(Node Classification)

  • 任務目標:預測圖中節點的類別(如社交網絡用户的興趣標籤、論文的研究領域、蛋白質的功能);
  • 典型數據集:Cora(論文引用網絡,7類論文,2708個節點)、Citeseer、PubMed;
  • 實現邏輯:用GCN/GAT/GraphSAGE學習節點嵌入,再接入全連接層+Softmax進行分類。

2. 鏈路預測(Link Prediction)

  • 任務目標:預測圖中缺失的邊或未來可能出現的邊(如社交網絡好友推薦、知識圖譜補全、蛋白質相互作用預測);
  • 核心思路:對節點對 \( (v_i, v_j) \),將其嵌入向量通過拼接/點積/差運算得到邊特徵,再用二分類模型(如MLP)預測邊是否存在;
  • 示例:好友推薦 = 預測用户節點間是否存在“朋友”邊。

3. 圖分類(Graph Classification)

  • 任務目標:預測整個圖的類別(如分子的毒性、化合物的屬性、社交網絡的主題);
  • 核心思路:先通過GNN學習每個節點的嵌入,再通過“圖池化”(Graph Pooling)將節點嵌入聚合為圖級嵌入(如求和、平均、Top-K池化),最後用全連接層分類;
  • 典型數據集:TUDataset(包含分子圖、社交網絡圖等)。

五、總結與後續預告

本文作為GNN系列第一篇,重點講解了:

  1. 圖數據的基礎概念與特點(非歐幾里得結構、關聯依賴);
  2. GNN的核心原理(消息傳遞框架:消息傳遞→聚合→更新);
  3. 三大經典模型(GCN、GAT、GraphSAGE)的原理、代碼實現與優缺點;
  4. 入門級應用場景(節點分類、鏈路預測、圖分類)。

下一篇預告:《詳解圖神經網絡:方法與應用(二)》

  • 進階模型:GIN(圖同構網絡)、GNN+Transformer、異構圖神經網絡(HGNN);
  • 關鍵技術:圖池化、注意力機制優化、動態圖處理;
  • 訓練技巧:正則化、超參數調優、大規模圖訓練框架(如PyTorch Geometric分佈式訓練)。

如果需要深入某部分內容(如特定模型的數學推導、某應用的完整實戰代碼),可以隨時告訴我!