讀完這篇文章,你將用監督微調(SFT)把一個 1.5B 規模的數學模型在 GSM8K 上的零樣本推理正確率從 1.56% → 62.9%,同時把輸出格式遵循率從 18.9% → 100%。我們將完整走通數據集下載、Prompt 架構、訓練配置和評估方法,所有代碼均來自本倉庫 alignment 文件夾,保證可復現與透明。

本文將深入剖析 llm-from-scratch 倉庫中 alignment 模塊,展示 SFT 的完整流程。

引言

大語言推理模型常見的兩個痛點:一是“答不對”,二是“答不規範”。前者意味着推理鏈條斷裂或遷移失效,後者則讓後續評估與系統對接步步驚心。

在這篇 How‑to 指南里,我們用 GSM8K 的標準數據與明確的評估規則,讓 Qwen/Qwen2.5-Math-1.5B 通過 SFT 後,能力遷移到不同的數據集上, 學會“思考→作答”的結構化輸出:不僅更常答對,也更懂規範。結果來自實測,正確率由 1.56% 提升到 62.9%,格式遵循由 18.9% 提升到 100%.


為什麼零樣本推理和“格式遵循”都很難?

  • 零樣本推理難在“遷移”:模型雖見過大量文本,但缺少對“分步算術、單位處理、等式化簡”的系統化經驗,容易在多步推理或細節規範上出錯。
  • 格式遵循難在“約束學習”:即使模型知道答案,若不按系統約定的輸出協議(例如必須有 <think>...</think> <answer>...</answer>),評估與下游解析都會失敗。
  • 兩者耦合:不遵循格式會直接“判零分”,即使答案正確;而缺少分步推理(think)又會影響最終答案(answer)的穩定性。

一個好比喻:把模型想象成一個聰明但散漫的學生。零樣本時,它能“蒙”對少數題;SFT 就像班主任的“規範化帶教”,教它先寫草稿(<think>)再交最終答案(<answer>),且必須按卷面格式來——這樣既能提高質量,也能讓閲卷更可靠。


核心概念與評估口徑

  • 監督微調(SFT):用標註樣本(這裏是 GSM8K 的“問題 + 推理過程 + 最終答案”)對模型做下一 token 預測訓練。通過標籤掩碼,只對“推理與答案”部分計算損失,指導模型輸出完整的思維鏈與答案。
  • Prompt 架構:我們採用 R1 風格模板,強制輸出 <think><answer> 標籤,保證格式可解析:
    • <think> ... </think>:推理過程
    • <answer> ... </answer>:最終答案
  • 訓練目標(Loss Target):標準自迴歸語言模型(next-token LM)損失,但只在“迴應(think+answer)”區間計算,避免模型學習到重複的“系統提示與問題”的 token 序列。
  • 評估指標:
    • “推理準確率”:按每條樣本的 答案是否正確(數學同值、LaTeX 等價、數值等價,詳見 grader)計分 0/1,最後取平均。
    • “格式遵循率”:按每條樣本 是否包含合法的 <think>...</think> <answer>...</answer> 標籤計分 0/1,最後取平均。

我們在 alignment/drgrpo_grader.py 中實現了嚴格的格式檢查與寬容但可靠的數學等價判斷(符號化、數值化與 LaTeX 解析的組合),是本文準確率與格式遵循的核心評估邏輯。


方案與架構

本文的流水線架構如下:從 GSM8K JSONL 到 Prompt 構造,再到 SFT 訓練與 vLLM 推理評估,最後彙總指標(準確率/格式遵循)。

flowchart LR
  A["GSM8K JSONL 數據集"] --> B["R1PromptTemplate 構造 Prompt"]
  B --> C["SFT 訓練 (next-token LM)"]
  C --> D["保存 Checkpoint"]
  D --> E["vLLM 推理評估"]
  E --> F["r1_zero_reward_fn 計算格式與正確性"]
  F --> G["輸出指標: 準確率 & 格式遵循"]

關鍵路徑對應的源碼均在 alignment 目錄下:dataset.py(數據加載)、r1_prompt.py(模板)、sft.py(訓練)、evaluate.py(評估)、drgrpo_grader.py(格式與答案打分)。


可復現配置與注意事項

  • 模型與精度:默認使用 Qwen/Qwen2.5-Math-1.5B,dtype 默認 bfloat16(見 alignment/args.py)。
  • 設備與顯存:alignment/sft.py 通過 accelerate 的 infer_auto_device_map 自動切分模型到多卡(--sft_device),並設置 --max_sft_gpu_memory_use(默認 31GiB/卡)。評估使用 vLLM 獨立進程(--eval_device)。
  • 隨機性:--seed 默認 42。不同環境(驅動、庫版本、顯存壓力)與隨機種子可能導致輕微波動。
  • 批次與累積:--batch_size--gradient_accumulation_steps 控制有效批次大小,訓練日誌會打印累計後的損失。
  • 提示模板:alignment/prompts/r1_zero.prompt 強制 <think>/<answer> 標籤,保證格式可檢。

實踐步驟:從數據到評估與訓練

  1. 下載 GSM8K(train/test)到本地 data 目錄(可根據你的項目目錄調整):
cd dataset
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

建議將文件移動/複製為倉庫約定路徑:data/gsm8k/train.jsonldata/gsm8k/test.jsonl(alignment/args.py 默認如此)。

  1. 微調前的零樣本評估(使用 vLLM + R1 模板 + 嚴格格式與答案打分):
uv run -m alignment.evaluate
  1. 執行 SFT 訓練並在測試集上評估(訓練中每個 epoch 結束都會評估並保存 checkpoint):
uv run -m alignment.sft

代碼講解:數據加載與樣本構造(包含 <think>/<answer> 標籤)

首先用 R1 模板將問題轉為 Prompt,將“推理過程 + 最終答案”打包為監督信號。alignment/dataset.py 與 r1_prompt.py 如下:

# alignment/dataset.py
from torch.utils.data import Dataset
import json
from .r1_prompt import R1PromptTemplate

class Gsm8kDataset(Dataset):
    def __init__(self, data_path: str, promt_template_path: str):
        template = R1PromptTemplate(promt_template_path)
        self.data = []
        self.label = []
        self.ground_truth = []
        with open(data_path, "r") as f:
            lines = f.readlines()
        for line in lines:
            qa = json.loads(line)
            question = qa["question"]
            answer_think = qa["answer"]
            think, answer = answer_think.split("####")
            think, answer = think.strip(), answer.strip()
            self.data.append(template.gen_prompt(question))
            self.label.append(template.gen_response(think, answer))
            self.ground_truth.append(answer)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx], self.ground_truth[idx]
# alignment/r1_prompt.py(關鍵:強制 <think>/<answer> 標籤)
class R1PromptTemplate:
    def __init__(self, template_path: os.PathLike):
        with open(template_path, "r") as f:
            self.template = f.read().strip()

    def gen_prompt(self, question: str) -> str:
        return self.template.replace(r"{question}", question)

    def gen_response(self, think: str, answer: str) -> str:
        return think + "</think>" + " <answer>" + answer + " </answer>"

模板文件 alignment/prompts/r1_zero.prompt:

A conversation between User and Assistant... <think> reasoning process here </think> <answer> answer here </answer>.
User: {question}
Assistant: <think>

這保證了訓練時模型學習“先寫 <think> 再寫 <answer>”,評估時也能穩定抽取並驗證答案。


代碼講解:SFT 訓練循環與標籤掩碼(只對迴應部分計算損失)

在 alignment/sft.py,我們將 prompt + completion 拼接,對 prompt 區間的 label 置為 -100,從而讓交叉熵只在“迴應(think + answer)”上回傳梯度:

# alignment/sft.py(節選)
model.train()
for epoch in range(args.epochs):
    for i, batch in enumerate(train_data_loader):
        prompts, completions, _ = batch

        full_texts = [p + c + tokenizer.eos_token for p, c in zip(prompts, completions)]
        inputs = tokenizer(
            full_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_seq_len,
        ).to(model.device)

        # 計算每個樣本的 prompt token 長度
        prompt_tokens = tokenizer(list(prompts), add_special_tokens=False)
        prompt_lengths = [len(ids) for ids in prompt_tokens.input_ids]

        labels = inputs.input_ids.clone()
        for idx in range(len(prompts)):
            prompt_len = prompt_lengths[idx]
            labels[idx, :prompt_len] = -100  # 對 prompt 部分不計 loss

        # Mask padding tokens
        labels[labels == tokenizer.pad_token_id] = -100

        outputs = model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss = loss / args.gradient_accumulation_steps
        loss.backward()

        if (i + 1) % args.gradient_accumulation_steps == 0:
            print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item() * args.gradient_accumulation_steps}")
            optimizer.step()
            optimizer.zero_grad()

    # 評估:將策略權重加載到 vLLM 並在測試集上跑評估
    load_policy_into_vllm_instance(model, eval_model)
    evaluate_math(
        eval_model,
        args.prompt_template_path,
        args.sft_test_data,
        args.batch_size,
        log_sample=True,
    )
    model.save_pretrained(args.checkpoint_path)

這段邏輯實現了標準的 SFT:用“迴應”作為監督目標,配合累積梯度與分卡策略,保證 1.5B 規模模型在可控顯存與吞吐下完成訓練。


代碼講解:評估與指標計算(準確率與格式遵循)

評估由 alignment/evaluate.py 驅動,調用 r1_zero_reward_fn 計算每條樣本的格式與答案得分:

# alignment/evaluate.py(核心)
from collections.abc import Callable
from vllm import LLM, SamplingParams
from .drgrpo_grader import r1_zero_reward_fn

def evaluate_vllm(vllm_model: LLM,
                  reward_fn: Callable[[str, str], dict[str, float]],
                  prompts: list[str],
                  ground_truths: list[str],
                  eval_sampling_params: SamplingParams,
                  log_sample: bool) -> dict:
    outputs = vllm_model.generate(prompts, eval_sampling_params)
    generated_texts = [output.outputs[0].text for output in outputs]
    rewards = [reward_fn(generated_text, ground_truth)
               for generated_text, ground_truth in zip(generated_texts, ground_truths)]

    avg_format_rewards = sum([r["format_reward"] for r in rewards]) / len(prompts)
    avg_answer_rewards = sum([r["answer_reward"] for r in rewards]) / len(prompts)
    avg_all_rewards = sum([r["reward"] for r in rewards]) / len(prompts)
    print(f"avg_format_rewards: {avg_format_rewards}")
    print(f"avg_answer_rewards: {avg_answer_rewards}")
    print(f"avg_all_rewards: {avg_all_rewards}")
    return {
        "avg_format_rewards": avg_format_rewards,
        "avg_answer_rewards": avg_answer_rewards,
        "avg_all_rewards": avg_all_rewards,
    }

獎勵函數嚴格要求格式標籤,並對答案進行多重等價校驗:

# alignment/drgrpo_grader.py(節選)
def r1_zero_reward_fn(response, ground_truth, fast=True):
    # 格式嚴格:必須包含 </think> <answer> ... </answer>
    if "</think> <answer>" in response and "</answer>" in response:
        model_answer = response.split("<answer>")[-1].replace("</answer>", "")
        # 允許 \boxed{...} 的答案形式並抽取
        if "\\boxed" in model_answer:
            model_answer = extract_answer(model_answer)
            if model_answer is None:
                return {"format_reward": 1.0, "answer_reward": 0.0, "reward": 0.0}
        # 字符串、數值、LaTeX、SymPy 等價綜合判斷
        is_correct = grade(model_answer, str(ground_truth), fast)
        if is_correct:
            return {"format_reward": 1.0, "answer_reward": 1.0, "reward": 1.0}
        else:
            return {"format_reward": 1.0, "answer_reward": 0.0, "reward": 0.0}
    else:
        # 不符合格式:直接判 0 分(同時 answer_reward 也為 0)
        return {"format_reward": 0.0, "answer_reward": 0.0, "reward": 0.0}

因此:

  • 格式遵循率 = 所有樣本的 format_reward 平均值。
  • 推理準確率 = 所有樣本的 answer_reward 平均值(在格式正確的前提下計算答案是否等價)。

數據讀取與 Ground Truth 解析(alignment/evaluate.py):

# alignment/evaluate.py(節選)
def get_gsm8k_test_data(test_data_path: os.PathLike) -> list[dict]:
    data = []
    with open(test_data_path, "r") as f:
        for line in f.readlines():
            obj = json.loads(line.strip())
            ts = obj["answer"].split("####")
            if len(ts) != 2:
                print(f"invalid answer: {obj['answer']}")
                continue
            data.append({"question": obj["question"],
                         "think": ts[0].strip(),
                         "answer": ts[1].strip()})
    return data

結果與討論

在同一評估口徑下,我們在 GSM8K 上得到如下對比(報告值):

指標 微調前(零樣本) SFT 微調後
推理準確率 1.56% 62.9%
格式遵循率 18.9% 100%

説明:

  • 上述結果是在 alignment.evaluate 與 alignment.sft 的評估框架下得到的報告值。不同硬件/軟件環境與隨機種子可能引起輕微差異。
  • “格式遵循率”達到 100% 的關鍵是訓練時的模板約束與監督信號覆蓋 <think>/<answer>,並在獎勵函數中嚴格判定標籤存在性。

常見問題與優化建議

  • 訓練不收斂或損失震盪:適當降低學習率(如 1e‑5 → 5e‑6)、增大 gradient_accumulation_steps,或提高 max_seq_len 以覆蓋完整推理鏈。
  • 顯存不足:增加 --sft_device 的卡數或調大 --max_sft_gpu_memory_use,必要時裁剪序列長度或使用更小 batch。
  • 評估速度慢:調小 SamplingParamsmax_tokens;但過小的生成長度會影響完整 <think>/<answer> 輸出與正確率。
  • 格式仍偶發不合規:檢查模板與數據是否一致(是否始終以 <think> 開頭),並確保 SFT 標籤覆蓋足量樣本。

結論

通過對 Qwen/Qwen2.5-Math-1.5BGSM8K 上的監督微調,我們實現了兩條主線的同步提升:

  • 推理鏈條更穩,**零樣本正確率從 1.56% 提升到 62.9%**;
  • 輸出規範更強,**格式遵循率從 18.9% 提升到 100%**。

這背後的關鍵是“結構化輸出模板 + 只對迴應部分計損失 + 嚴格但寬容的評估規則”。如果你也在做數學推理或其他需要“思考‑答案二段式”輸出的任務,強烈建議複用本文的架構與代碼。

行動號召:現在就下載 GSM8K,跑一遍 uv run -m alignment.evaluateuv run -m alignment.sft,觀察你本地的改進幅度吧! 可以參考 llm-from-scratch 倉庫中 alignment 模塊,對照進行學習。

開放問題:在你的場景裏,是否遇到過“答案正確但格式導致系統解析失敗”的案例?你是如何設計模板與評估邏輯來避免它的?歡迎在評論區分享你的經驗與挑戰。