Qwen3 模型用於因果語言建模(Causal Language Modeling, CLM)的主類 Qwen3ForCausalLM,它是整個大模型在推理和訓練階段的核心接口。
🧱 1. 類定義
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
繼承關係説明:
|
基類
|
功能
|
|
|
提供權重初始化、配置加載、HuggingFace 集成支持
|
|
|
提供生成能力: |
✅ 這意味着該模型可以直接調用
.generate()來進行文本生成!
🔗 類屬性(關鍵元數據)
(1) _tied_weights_keys = ["lm_head.weight"]
- 表示
lm_head.weight和詞嵌入層model.embed_tokens.weight共享權重(weight tying) - 即:
self.lm_head.weight = self.model.embed_tokens.weight
- 優點:
- 減少參數量;
- 提升語言建模性能(輸入/輸出語義對齊更好);
- 是標準做法(GPT、BERT 等也這麼做)。
⚠️ 注意:只有當
hidden_size == vocab_size的因數時才合理,但現代模型常直接 tie。
(2) _tp_plan = {"lm_head": "colwise_rep"}
- TP = Tensor Parallelism(張量並行)
colwise_rep: 表示lm_head層在列切分時需要“複製”而非分割 —— 可能是為了避免 all-gather?- 實際含義依賴具體分佈式訓練框架(如 DeepSpeed、ColossalAI、vLLM)。
- 通常
lm_head是(d_model, vocab_size),若vocab_size很大,需特殊處理。
📌 目的是指導模型如何在多 GPU 上拆分 lm_head。
(3) _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
- PP = Pipeline Parallelism(流水線並行)
- 定義了
lm_head模塊的輸入輸出邊界:
- 輸入:
hidden_states - 輸出:
logits
- 用於構建 pipeline stages,告訴系統哪部分屬於前一 stage,哪部分屬於後一 stage。
✅ 這些
_tp_plan,_pp_plan是為 大規模分佈式訓練/推理優化 設計的元信息。
⚙️ 構造函數 __init__
def __init__(self, config):
super().__init__(config)
self.model = Qwen3Model(config) # 主幹 Transformer
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init() # 初始化權重 + 後處理
關鍵組件:
|
組件
|
作用
|
|
|
包含所有 |
|
|
解碼頭:將最後一層 hidden state 映射到詞彙表維度的 logits
|
|
|
因為一般 tied weight 後 bias 不必要,且容易出錯
|
post_init() 做什麼?
- 調用父類定義的權重初始化策略(如正態分佈初始化);
- 可能應用特殊初始化規則到
lm_head; - 是 HuggingFace 標準流程的一部分。
📤 前向傳播 forward
@can_return_tuple
@auto_docstring
def forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
參數詳解
|
參數
|
用途
|
|
|
token ID 輸入 |
|
|
替代 input_ids 的嵌入表示(兩者互斥)
|
|
|
防止 padding 或未來 token 被關注
|
|
|
顯式位置索引(配合 RoPE 使用)
|
|
|
KV Cache,用於緩存歷史 K/V,加速自迴歸生成
|
|
|
是否啓用 KV 緩存(推理時設為 True)
|
|
|
訓練標籤,用於計算 loss(shifted right)
|
|
|
控制只計算最後幾個 token 的 logits(節省顯存)
|
|
|
當前 token 在緩存中的位置(用於增量解碼)
|
💡 支持
input_ids和inputs_embeds二選一,靈活性高。
🔁 數據流詳解
Step 1: 主幹模型前向傳播
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
- 返回類型:
BaseModelOutputWithPast - 包含:
last_hidden_state: 最終隱藏狀態[B, S, D]past_key_values: 更新後的 KV Cache(如果use_cache=True)hidden_states,attentions: 可選中間輸出
Step 2: 提取最後隱藏層 & 計算 logits
hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
關鍵技巧:Partial Logits Computation
logits_to_keep=0→ 默認不提前算任何 logits?可能是 lazy 計算設計。- 若
logits_to_keep=5→ 只計算最後 5 個 token 的輸出 logits。 - 若傳入 tensor → 自定義哪些位置要計算。
✅ 目的:大幅減少顯存佔用,尤其在長序列生成或批處理時。
🌟 這是一種高級優化技術,在 vLLM、FlashAttention 中也有類似思想。
Step 3: 損失計算(僅訓練時)
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs
)
loss_function通常是交叉熵損失(CrossEntropyLoss),但做了封裝以支持:
- label smoothing
- ignore_index=-100
- 分佈式訓練下的 loss reduce
- 注意:
labels是原始input_ids的右移版本(因果語言模型標準做法)。
Step 4: 返回結果
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
這是 HuggingFace 標準輸出結構,包含:
|
字段
|
用途
|
|
|
標量損失值(訓練用)
|
|
|
歸一化前的詞彙分數 |
|
|
KV Cache,用於下一時刻生成
|
|
|
分析中間行為(可選)
|
🔄 整體架構圖示
Input (input_ids)
│
↓
Qwen3Model (Transformer stack)
│
↓
last_hidden_state [B, S, D]
│
↓
lm_head (Linear): [B, S, D] → [B, S, V]
│
├───→ logits ───┐
│ ↓
│ (optional) loss ← labels
↓
CausalLMOutputWithPast(loss, logits, past_key_values, ...)