Qwen3 模型用於因果語言建模(Causal Language Modeling, CLM)的主類 Qwen3ForCausalLM,它是整個大模型在推理和訓練階段的核心接口。

🧱 1. 類定義

@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):

繼承關係説明:

基類

功能

Qwen3PreTrainedModel

提供權重初始化、配置加載、HuggingFace 集成支持

GenerationMixin

提供生成能力:.generate() 方法,支持 greedy search、beam search、sampling 等

✅ 這意味着該模型可以直接調用 .generate() 來進行文本生成!

🔗 類屬性(關鍵元數據)

(1) _tied_weights_keys = ["lm_head.weight"]

  • 表示 lm_head.weight 和詞嵌入層 model.embed_tokens.weight 共享權重(weight tying)
  • 即:
self.lm_head.weight = self.model.embed_tokens.weight
  • 優點:
  • 減少參數量;
  • 提升語言建模性能(輸入/輸出語義對齊更好);
  • 是標準做法(GPT、BERT 等也這麼做)。

⚠️ 注意:只有當 hidden_size == vocab_size 的因數時才合理,但現代模型常直接 tie。


(2) _tp_plan = {"lm_head": "colwise_rep"}

  • TP = Tensor Parallelism(張量並行)
  • colwise_rep: 表示 lm_head 層在列切分時需要“複製”而非分割 —— 可能是為了避免 all-gather?
  • 實際含義依賴具體分佈式訓練框架(如 DeepSpeed、ColossalAI、vLLM)。
  • 通常 lm_head(d_model, vocab_size),若 vocab_size 很大,需特殊處理。

📌 目的是指導模型如何在多 GPU 上拆分 lm_head


(3) _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

  • PP = Pipeline Parallelism(流水線並行)
  • 定義了 lm_head 模塊的輸入輸出邊界:
  • 輸入:hidden_states
  • 輸出:logits
  • 用於構建 pipeline stages,告訴系統哪部分屬於前一 stage,哪部分屬於後一 stage。

✅ 這些 _tp_plan, _pp_plan 是為 大規模分佈式訓練/推理優化 設計的元信息。


⚙️ 構造函數 __init__

def __init__(self, config):
    super().__init__(config)
    self.model = Qwen3Model(config)                  # 主幹 Transformer
    self.vocab_size = config.vocab_size
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    self.post_init()  # 初始化權重 + 後處理

關鍵組件:

組件

作用

self.model

包含所有 Qwen3DecoderLayer 的主幹網絡(即 Transformer 主體)

self.lm_head

解碼頭:將最後一層 hidden state 映射到詞彙表維度的 logits

bias=False

因為一般 tied weight 後 bias 不必要,且容易出錯

post_init() 做什麼?

  • 調用父類定義的權重初始化策略(如正態分佈初始化);
  • 可能應用特殊初始化規則到 lm_head
  • 是 HuggingFace 標準流程的一部分。

📤 前向傳播 forward

@can_return_tuple
@auto_docstring
def forward(
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> CausalLMOutputWithPast:

參數詳解

參數

用途

input_ids

token ID 輸入 [B, S]

inputs_embeds

替代 input_ids 的嵌入表示(兩者互斥)

attention_mask

防止 padding 或未來 token 被關注

position_ids

顯式位置索引(配合 RoPE 使用)

past_key_values

KV Cache,用於緩存歷史 K/V,加速自迴歸生成

use_cache

是否啓用 KV 緩存(推理時設為 True)

labels

訓練標籤,用於計算 loss(shifted right)

logits_to_keep

控制只計算最後幾個 token 的 logits(節省顯存)

cache_position

當前 token 在緩存中的位置(用於增量解碼)

💡 支持 input_idsinputs_embeds 二選一,靈活性高。


🔁 數據流詳解

Step 1: 主幹模型前向傳播

outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    cache_position=cache_position,
    **kwargs,
)
  • 返回類型:BaseModelOutputWithPast
  • 包含:
  • last_hidden_state: 最終隱藏狀態 [B, S, D]
  • past_key_values: 更新後的 KV Cache(如果 use_cache=True
  • hidden_states, attentions: 可選中間輸出

Step 2: 提取最後隱藏層 & 計算 logits

hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
關鍵技巧:Partial Logits Computation
  • logits_to_keep=0 → 默認不提前算任何 logits?可能是 lazy 計算設計。
  • logits_to_keep=5 → 只計算最後 5 個 token 的輸出 logits。
  • 若傳入 tensor → 自定義哪些位置要計算。

目的:大幅減少顯存佔用,尤其在長序列生成或批處理時。

🌟 這是一種高級優化技術,在 vLLM、FlashAttention 中也有類似思想。


Step 3: 損失計算(僅訓練時)

loss = None
if labels is not None:
    loss = self.loss_function(
        logits=logits,
        labels=labels,
        vocab_size=self.config.vocab_size,
        **kwargs
    )
  • loss_function 通常是交叉熵損失(CrossEntropyLoss),但做了封裝以支持:
  • label smoothing
  • ignore_index=-100
  • 分佈式訓練下的 loss reduce
  • 注意:labels 是原始 input_ids 的右移版本(因果語言模型標準做法)。

Step 4: 返回結果

return CausalLMOutputWithPast(
    loss=loss,
    logits=logits,
    past_key_values=outputs.past_key_values,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
)

這是 HuggingFace 標準輸出結構,包含:

字段

用途

loss

標量損失值(訓練用)

logits

歸一化前的詞彙分數 [B, S', V](S’ 取決於 logits_to_keep

past_key_values

KV Cache,用於下一時刻生成

hidden_states / attentions

分析中間行為(可選)


🔄 整體架構圖示

Input (input_ids)
   │
   ↓
Qwen3Model (Transformer stack)
   │
   ↓
last_hidden_state [B, S, D]
   │
   ↓
lm_head (Linear): [B, S, D] → [B, S, V]
   │
   ├───→ logits ───┐
   │               ↓
   │           (optional) loss ← labels
   ↓
CausalLMOutputWithPast(loss, logits, past_key_values, ...)