讀完這篇文章,你將用監督微調(SFT)把一個 1.5B 規模的數學模型在 GSM8K 上的零樣本推理正確率從 1.56% → 62.9%,同時把輸出格式遵循率從 18.9% → 100%。我們將完整走通數據集下載、Prompt 架構、訓練配置和評估方法,所有代碼均來自本倉庫 alignment 文件夾,保證可復現與透明。
本文將深入剖析 llm-from-scratch 倉庫中 alignment 模塊,展示 SFT 的完整流程。
引言
大語言推理模型常見的兩個痛點:一是“答不對”,二是“答不規範”。前者意味着推理鏈條斷裂或遷移失效,後者則讓後續評估與系統對接步步驚心。
在這篇 How‑to 指南里,我們用 GSM8K 的標準數據與明確的評估規則,讓 Qwen/Qwen2.5-Math-1.5B 通過 SFT 後,能力遷移到不同的數據集上, 學會“思考→作答”的結構化輸出:不僅更常答對,也更懂規範。結果來自實測,正確率由 1.56% 提升到 62.9%,格式遵循由 18.9% 提升到 100%.
為什麼零樣本推理和“格式遵循”都很難?
- 零樣本推理難在“遷移”:模型雖見過大量文本,但缺少對“分步算術、單位處理、等式化簡”的系統化經驗,容易在多步推理或細節規範上出錯。
- 格式遵循難在“約束學習”:即使模型知道答案,若不按系統約定的輸出協議(例如必須有
<think>...</think> <answer>...</answer>),評估與下游解析都會失敗。 - 兩者耦合:不遵循格式會直接“判零分”,即使答案正確;而缺少分步推理(think)又會影響最終答案(answer)的穩定性。
一個好比喻:把模型想象成一個聰明但散漫的學生。零樣本時,它能“蒙”對少數題;SFT 就像班主任的“規範化帶教”,教它先寫草稿(<think>)再交最終答案(<answer>),且必須按卷面格式來——這樣既能提高質量,也能讓閲卷更可靠。
核心概念與評估口徑
- 監督微調(SFT):用標註樣本(這裏是 GSM8K 的“問題 + 推理過程 + 最終答案”)對模型做下一 token 預測訓練。通過標籤掩碼,只對“推理與答案”部分計算損失,指導模型輸出完整的思維鏈與答案。
- Prompt 架構:我們採用 R1 風格模板,強制輸出
<think>和<answer>標籤,保證格式可解析:<think> ... </think>:推理過程<answer> ... </answer>:最終答案
- 訓練目標(Loss Target):標準自迴歸語言模型(next-token LM)損失,但只在“迴應(think+answer)”區間計算,避免模型學習到重複的“系統提示與問題”的 token 序列。
- 評估指標:
- “推理準確率”:按每條樣本的 答案是否正確(數學同值、LaTeX 等價、數值等價,詳見 grader)計分 0/1,最後取平均。
- “格式遵循率”:按每條樣本 是否包含合法的
<think>...</think> <answer>...</answer>標籤計分 0/1,最後取平均。
我們在 alignment/drgrpo_grader.py 中實現了嚴格的格式檢查與寬容但可靠的數學等價判斷(符號化、數值化與 LaTeX 解析的組合),是本文準確率與格式遵循的核心評估邏輯。
方案與架構
本文的流水線架構如下:從 GSM8K JSONL 到 Prompt 構造,再到 SFT 訓練與 vLLM 推理評估,最後彙總指標(準確率/格式遵循)。
flowchart LR
A["GSM8K JSONL 數據集"] --> B["R1PromptTemplate 構造 Prompt"]
B --> C["SFT 訓練 (next-token LM)"]
C --> D["保存 Checkpoint"]
D --> E["vLLM 推理評估"]
E --> F["r1_zero_reward_fn 計算格式與正確性"]
F --> G["輸出指標: 準確率 & 格式遵循"]
關鍵路徑對應的源碼均在 alignment 目錄下:dataset.py(數據加載)、r1_prompt.py(模板)、sft.py(訓練)、evaluate.py(評估)、drgrpo_grader.py(格式與答案打分)。
可復現配置與注意事項
- 模型與精度:默認使用 Qwen/Qwen2.5-Math-1.5B,dtype 默認 bfloat16(見 alignment/args.py)。
- 設備與顯存:alignment/sft.py 通過 accelerate 的
infer_auto_device_map自動切分模型到多卡(--sft_device),並設置--max_sft_gpu_memory_use(默認 31GiB/卡)。評估使用 vLLM 獨立進程(--eval_device)。 - 隨機性:
--seed默認 42。不同環境(驅動、庫版本、顯存壓力)與隨機種子可能導致輕微波動。 - 批次與累積:
--batch_size與--gradient_accumulation_steps控制有效批次大小,訓練日誌會打印累計後的損失。 - 提示模板:
alignment/prompts/r1_zero.prompt強制<think>/<answer>標籤,保證格式可檢。
實踐步驟:從數據到評估與訓練
- 下載 GSM8K(train/test)到本地 data 目錄(可根據你的項目目錄調整):
cd dataset
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
建議將文件移動/複製為倉庫約定路徑:data/gsm8k/train.jsonl 與 data/gsm8k/test.jsonl(alignment/args.py 默認如此)。
- 微調前的零樣本評估(使用 vLLM + R1 模板 + 嚴格格式與答案打分):
uv run -m alignment.evaluate
- 執行 SFT 訓練並在測試集上評估(訓練中每個 epoch 結束都會評估並保存 checkpoint):
uv run -m alignment.sft
代碼講解:數據加載與樣本構造(包含 <think>/<answer> 標籤)
首先用 R1 模板將問題轉為 Prompt,將“推理過程 + 最終答案”打包為監督信號。alignment/dataset.py 與 r1_prompt.py 如下:
# alignment/dataset.py
from torch.utils.data import Dataset
import json
from .r1_prompt import R1PromptTemplate
class Gsm8kDataset(Dataset):
def __init__(self, data_path: str, promt_template_path: str):
template = R1PromptTemplate(promt_template_path)
self.data = []
self.label = []
self.ground_truth = []
with open(data_path, "r") as f:
lines = f.readlines()
for line in lines:
qa = json.loads(line)
question = qa["question"]
answer_think = qa["answer"]
think, answer = answer_think.split("####")
think, answer = think.strip(), answer.strip()
self.data.append(template.gen_prompt(question))
self.label.append(template.gen_response(think, answer))
self.ground_truth.append(answer)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.label[idx], self.ground_truth[idx]
# alignment/r1_prompt.py(關鍵:強制 <think>/<answer> 標籤)
class R1PromptTemplate:
def __init__(self, template_path: os.PathLike):
with open(template_path, "r") as f:
self.template = f.read().strip()
def gen_prompt(self, question: str) -> str:
return self.template.replace(r"{question}", question)
def gen_response(self, think: str, answer: str) -> str:
return think + "</think>" + " <answer>" + answer + " </answer>"
模板文件 alignment/prompts/r1_zero.prompt:
A conversation between User and Assistant... <think> reasoning process here </think> <answer> answer here </answer>.
User: {question}
Assistant: <think>
這保證了訓練時模型學習“先寫 <think> 再寫 <answer>”,評估時也能穩定抽取並驗證答案。
代碼講解:SFT 訓練循環與標籤掩碼(只對迴應部分計算損失)
在 alignment/sft.py,我們將 prompt + completion 拼接,對 prompt 區間的 label 置為 -100,從而讓交叉熵只在“迴應(think + answer)”上回傳梯度:
# alignment/sft.py(節選)
model.train()
for epoch in range(args.epochs):
for i, batch in enumerate(train_data_loader):
prompts, completions, _ = batch
full_texts = [p + c + tokenizer.eos_token for p, c in zip(prompts, completions)]
inputs = tokenizer(
full_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=args.max_seq_len,
).to(model.device)
# 計算每個樣本的 prompt token 長度
prompt_tokens = tokenizer(list(prompts), add_special_tokens=False)
prompt_lengths = [len(ids) for ids in prompt_tokens.input_ids]
labels = inputs.input_ids.clone()
for idx in range(len(prompts)):
prompt_len = prompt_lengths[idx]
labels[idx, :prompt_len] = -100 # 對 prompt 部分不計 loss
# Mask padding tokens
labels[labels == tokenizer.pad_token_id] = -100
outputs = model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
labels=labels,
)
loss = outputs.loss
loss = loss / args.gradient_accumulation_steps
loss.backward()
if (i + 1) % args.gradient_accumulation_steps == 0:
print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item() * args.gradient_accumulation_steps}")
optimizer.step()
optimizer.zero_grad()
# 評估:將策略權重加載到 vLLM 並在測試集上跑評估
load_policy_into_vllm_instance(model, eval_model)
evaluate_math(
eval_model,
args.prompt_template_path,
args.sft_test_data,
args.batch_size,
log_sample=True,
)
model.save_pretrained(args.checkpoint_path)
這段邏輯實現了標準的 SFT:用“迴應”作為監督目標,配合累積梯度與分卡策略,保證 1.5B 規模模型在可控顯存與吞吐下完成訓練。
代碼講解:評估與指標計算(準確率與格式遵循)
評估由 alignment/evaluate.py 驅動,調用 r1_zero_reward_fn 計算每條樣本的格式與答案得分:
# alignment/evaluate.py(核心)
from collections.abc import Callable
from vllm import LLM, SamplingParams
from .drgrpo_grader import r1_zero_reward_fn
def evaluate_vllm(vllm_model: LLM,
reward_fn: Callable[[str, str], dict[str, float]],
prompts: list[str],
ground_truths: list[str],
eval_sampling_params: SamplingParams,
log_sample: bool) -> dict:
outputs = vllm_model.generate(prompts, eval_sampling_params)
generated_texts = [output.outputs[0].text for output in outputs]
rewards = [reward_fn(generated_text, ground_truth)
for generated_text, ground_truth in zip(generated_texts, ground_truths)]
avg_format_rewards = sum([r["format_reward"] for r in rewards]) / len(prompts)
avg_answer_rewards = sum([r["answer_reward"] for r in rewards]) / len(prompts)
avg_all_rewards = sum([r["reward"] for r in rewards]) / len(prompts)
print(f"avg_format_rewards: {avg_format_rewards}")
print(f"avg_answer_rewards: {avg_answer_rewards}")
print(f"avg_all_rewards: {avg_all_rewards}")
return {
"avg_format_rewards": avg_format_rewards,
"avg_answer_rewards": avg_answer_rewards,
"avg_all_rewards": avg_all_rewards,
}
獎勵函數嚴格要求格式標籤,並對答案進行多重等價校驗:
# alignment/drgrpo_grader.py(節選)
def r1_zero_reward_fn(response, ground_truth, fast=True):
# 格式嚴格:必須包含 </think> <answer> ... </answer>
if "</think> <answer>" in response and "</answer>" in response:
model_answer = response.split("<answer>")[-1].replace("</answer>", "")
# 允許 \boxed{...} 的答案形式並抽取
if "\\boxed" in model_answer:
model_answer = extract_answer(model_answer)
if model_answer is None:
return {"format_reward": 1.0, "answer_reward": 0.0, "reward": 0.0}
# 字符串、數值、LaTeX、SymPy 等價綜合判斷
is_correct = grade(model_answer, str(ground_truth), fast)
if is_correct:
return {"format_reward": 1.0, "answer_reward": 1.0, "reward": 1.0}
else:
return {"format_reward": 1.0, "answer_reward": 0.0, "reward": 0.0}
else:
# 不符合格式:直接判 0 分(同時 answer_reward 也為 0)
return {"format_reward": 0.0, "answer_reward": 0.0, "reward": 0.0}
因此:
- 格式遵循率 = 所有樣本的
format_reward平均值。 - 推理準確率 = 所有樣本的
answer_reward平均值(在格式正確的前提下計算答案是否等價)。
數據讀取與 Ground Truth 解析(alignment/evaluate.py):
# alignment/evaluate.py(節選)
def get_gsm8k_test_data(test_data_path: os.PathLike) -> list[dict]:
data = []
with open(test_data_path, "r") as f:
for line in f.readlines():
obj = json.loads(line.strip())
ts = obj["answer"].split("####")
if len(ts) != 2:
print(f"invalid answer: {obj['answer']}")
continue
data.append({"question": obj["question"],
"think": ts[0].strip(),
"answer": ts[1].strip()})
return data
結果與討論
在同一評估口徑下,我們在 GSM8K 上得到如下對比(報告值):
| 指標 | 微調前(零樣本) | SFT 微調後 |
|---|---|---|
| 推理準確率 | 1.56% | 62.9% |
| 格式遵循率 | 18.9% | 100% |
説明:
- 上述結果是在 alignment.evaluate 與 alignment.sft 的評估框架下得到的報告值。不同硬件/軟件環境與隨機種子可能引起輕微差異。
- “格式遵循率”達到 100% 的關鍵是訓練時的模板約束與監督信號覆蓋
<think>/<answer>,並在獎勵函數中嚴格判定標籤存在性。
常見問題與優化建議
- 訓練不收斂或損失震盪:適當降低學習率(如 1e‑5 → 5e‑6)、增大
gradient_accumulation_steps,或提高max_seq_len以覆蓋完整推理鏈。 - 顯存不足:增加
--sft_device的卡數或調大--max_sft_gpu_memory_use,必要時裁剪序列長度或使用更小 batch。 - 評估速度慢:調小
SamplingParams的max_tokens;但過小的生成長度會影響完整<think>/<answer>輸出與正確率。 - 格式仍偶發不合規:檢查模板與數據是否一致(是否始終以
<think>開頭),並確保 SFT 標籤覆蓋足量樣本。
結論
通過對 Qwen/Qwen2.5-Math-1.5B 在 GSM8K 上的監督微調,我們實現了兩條主線的同步提升:
- 推理鏈條更穩,**零樣本正確率從 1.56% 提升到 62.9%**;
- 輸出規範更強,**格式遵循率從 18.9% 提升到 100%**。
這背後的關鍵是“結構化輸出模板 + 只對迴應部分計損失 + 嚴格但寬容的評估規則”。如果你也在做數學推理或其他需要“思考‑答案二段式”輸出的任務,強烈建議複用本文的架構與代碼。
行動號召:現在就下載 GSM8K,跑一遍 uv run -m alignment.evaluate 與 uv run -m alignment.sft,觀察你本地的改進幅度吧! 可以參考 llm-from-scratch 倉庫中 alignment 模塊,對照進行學習。
開放問題:在你的場景裏,是否遇到過“答案正確但格式導致系統解析失敗”的案例?你是如何設計模板與評估邏輯來避免它的?歡迎在評論區分享你的經驗與挑戰。