文章目錄
- 模塊結構與功能解析
- 分佈式處理機制
- 關鍵計算流程
- 設計意義分析
class MoE(nn.Module):
"""
Mixture-of-Experts (MoE) module.
Attributes:
dim (int): Dimensionality of input features.
n_routed_experts (int): Total number of experts in the model.
n_local_experts (int): Number of experts handled locally in distributed systems.
n_activated_experts (int): Number of experts activated for each input.
gate (nn.Module): Gating mechanism to route inputs to experts.
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.
Args:
args (ModelArgs): Model arguments containing MoE parameters.
"""
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)])
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
z = self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return (y + z).view(shape)
模塊結構與功能解析
Gate模塊
負責輸入的路由決策,計算每個輸入應分配給哪些專家的權重和索引。輸出包含兩個部分:權重矩陣(表示每個專家對輸入的重要性)和索引矩陣(指示哪些專家被激活)。
Experts模塊
由多個獨立的專家網絡組成,每個專家是一個MLP結構。這些專家專門處理特定類型的輸入模式,通過門控機制動態選擇部分專家參與計算。n_local_experts表示當前設備負責的專家數量,實現分佈式計算負載均衡。
Shared Experts模塊
全局共享的MLP網絡,對所有輸入進行處理。其作用是提供基礎特徵變換,與路由專家形成互補。設計目的是保證即使某些輸入未被任何路由專家處理,仍能通過共享專家獲得基本特徵更新。
跨設備通信
使用dist.all_reduce同步所有設備的專家計算結果。該操作對分佈式環境中各GPU計算的專家輸出進行求和聚合,保證最終結果的完整性。通信僅發生在專家計算階段,共享專家部分無需同步。
關鍵計算流程
門控計算階段
輸入張量展平後通過門控網絡,得到形狀為(batch*n_tokens, n_activated_experts)的權重和索引。n_activated_experts控制每個輸入激活的專家數量,實現計算稀疏性。
專家處理階段
通過torch.bincount統計各專家的負載情況,僅處理當前設備負責的非空專家。使用torch.where篩選屬於特定專家的輸入,進行加權計算。未激活的專家跳過計算以提升效率。
結果融合階段
路由專家輸出與共享專家輸出相加,保留原始輸入形狀。這種殘差連接設計確保:1) 路由專家捕獲特異性特徵 2) 共享專家提供基礎特徵 3) 兩者互補增強模型容量。
設計意義分析
計算效率優化
通過n_activated_experts實現條件計算,典型設置如激活2-4個專家(總專家數64+),使模型參數量增長而計算量基本不變。門控網絡採用輕量級設計,計算開銷遠小於專家前向傳播。
分佈式擴展性
專家均勻分佈和all_reduce通信模式支持線性擴展。當增加GPU數量時,每個設備管理的專家數量同比減少,通信量保持不變(始終需要全局聚合)。
模型容量提升
MoE結構突破傳統DNN的固定計算路徑限制,使模型總參數量可突破萬億規模(如Switch Transformer),同時保持實際計算量在合理範圍內。共享專家提供保底性能,避免路由失敗導致的性能下降。