简述
如何让模型具备更强的数学推理能力,成为当前研究的关键方向之一。GSM8K(Grade School Math 8K)是一个被广泛采用的数学问答数据集,专注于 小学数学题目的逐步推理能力评估。该数据集包含高质量、逐步求解格式的问答对,尤其适合检验模型的 算术能力 与 多步逻辑推理能力。
与此同时,传统的 SFT(Supervised Fine-Tuning)与 RLHF(Reinforcement Learning from Human Feedback)方法虽然取得了一定成效,但在训练效率、显存开销以及 reward 机制设计等方面仍存在挑战。为了更高效地对大型语言模型进行强化学习训练,GRPO(Guided Reward Pretraining with Offloading)作为一种新型方法应运而生。GRPO 引入了一种基于 KL 控制的奖励引导机制,并结合 vLLM 高效采样与参数 Offloading、FSDP 等分布式优化技术,使得在大模型训练中实现更好的 reward guidance 和显存管理成为可能。
将 GRPO 应用于 GSM8K 数据集的训练任务具有以下重要意义:
-
验证 GRPO 对数学推理任务的效果提升能力
GSM8K 是标准的数学推理 benchmark,通过该数据集可以清晰评估 GRPO 相较于传统 RLHF/SFT 在数学任务上的提升程度,尤其是解题准确率与 reasoning trace 合理性。 -
检验 reward guidance 对多步推理行为的引导能力
GSM8K 题目具有明显的多步推理结构,GRPO 提供的 KL reward 和 baseline policy 对模型生成路径的引导,能够验证其在“过程引导”而非“答案对错”上的奖励设计优势。 -
测试大模型在资源受限条件下的训练效率
GRPO 设计了低资源显存消耗的训练流程(通过 offload、vLLM、FSDP),GSM8K 作为较小规模数据集,能够快速迭代实验,验证在资源敏感环境下的训练吞吐与稳定性。 -
推动可解释推理能力的发展
GSM8K 的 step-by-step 格式天然适合生成过程评估,而 GRPO 的引导训练有望提升模型的“可解释中间推理”,为 LLM 的可控生成奠定基础。在 GSM8K 上使用 GRPO 所训练出的模型,未来可拓展到更复杂的数学任务与真实教育场景中。
数据集选择
GSM8K ?
- GSM8K 是一个面向小学数学水平的问题数据集,旨在测试模型的数学推理能力。数据集中的每个条目包含一个问题(例如:“一件衬衫原价 20 美元,打 25% 折扣后多少钱?”)和答案(例如:“15”)。
- 代码中通过 Parquet 格式从本地路径加载数据集(/opt/chenrui/grpo/dataset/gsm8k/main/test-00000-of-00001.parquet),并将其解析为问题(Q)和答案(A)的键值对列表。
为什么选择 GSM8K?
- 任务明确:GSM8K 的问题设计简单但需要逻辑推理,适合测试和提升模型在数学推理任务上的能力。问题通常涉及基本的算术运算(如加减乘除、百分比),这对语言模型的理解和推理能力提出了挑战。
- 答案可验证:每个问题都有明确的正确答案(通常是数字或简单表达式),便于设计奖励函数来评估模型输出的正确性。
- 适中难度:GSM8K 的问题难度适中,既不过于简单(避免模型直接记忆答案),也不过于复杂(避免训练难度过高),适合作为微调任务的数据集。
- 广泛应用:GSM8K 是评估语言模型数学推理能力的标准数据集,广泛用于学术研究和模型性能对比,具有代表性。
代码中的处理
QAs = [{'Q':x, 'A':y.split('####')[-1].strip()} for x,y in zip(dataset['question'], dataset['answer'])]
- 数据集被加载并解析为 QAs 列表,每个元素包含问题(Q)和标准答案(A)。
- 标准答案通过正则表达式提取最后一个数字(如 15),用于后续奖励计算。
GRPO 训练过程
GRPO(Gradient-based Reward Policy Optimization)是一种基于奖励的强化学习方法,用于优化语言模型的生成策略。它通过比较模型生成答案的奖励和参考答案的分布,调整模型参数,使其更倾向于生成高质量的回答。
代码中的 GRPO 实现
def generate_mode(num=10, rank=0, gen_timeout=999999):
if rank == 0: print('enter generate mode')
tic = time.time()
for ii in range(num):
if time.time() - tic > gen_timeout: break
inputs = random.sample(QAs, 1)
prompts_text, rewards, answers = gen_samples(inputs)
if prompts_text is None: continue
xdata = make_bytes_list([
json.dumps({"Q": prompts_text[0], "As": answers}).encode(),
tensor_to_bytes(torch.tensor(rewards, dtype=torch.float32))])
requests.post(f"{ref_server}/upload", data=xdata)
if rank != 0: continue
print('gen:', rewards)
if rank == 0: print('exit generate mode')
print(f'{rank}: {time.time()-tic:.3f}s')
- 数据生成(generate_mode):
- 从 GSM8K 数据集中随机抽取问题,生成多个答案(num_pre_Q=8 个答案/问题)。
- 使用自定义奖励函数(reward_correct 和 reward_format)评估生成答案:
- reward_correct:检查答案的最后一个数字是否与标准答案一致(通过 math_verify 模块验证)。
- reward_format:检查答案是否符合指定格式(即包含 <think> 和 <answer> 标签)。
- 生成的数据(问题、答案、奖励)通过 HTTP 请求发送到参考服务器(ref_server)存储。
def GRPO_step(batch):
prompt_length = batch['plen']
inputs = batch['inputs'].to(engine.device)
advantages = batch['rewards'].to(engine.device)
logits = engine(inputs).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = inputs[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
per_token_logps = get_per_token_logps(logits, input_ids)
per_token_logps = per_token_logps[:,prompt_length-1:]
ref_per_token_logps = batch['refs'].to(engine.device)
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
completion_mask = (inputs[:, prompt_length:] != tokenizer.pad_token_id).int()
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
return loss
- 训练循环:
- 从参考服务器获取批量数据(get_batch),包括输入 token、奖励和参考对数概率。
- 对每个微批次(micro_batch_size=4)计算 GRPO 损失:
- 计算模型生成的对数概率(per_token_logps)。
- 与参考对数概率(ref_per_token_logps)比较,计算 KL 散度(per_token_kl)作为正则化项。
- 使用奖励(advantages)加权损失,考虑 KL 散度的惩罚(beta=0.04)。
- 损失公式为:loss = -[exp(logps - logps.detach()) * advantages - beta * KL],其中 KL 散度防止模型偏离过远。
- 通过 DeepSpeed 进行反向传播和参数更新。
- 每隔 gen_steps=30 步触发数据生成,每隔 save_steps=200 步保存模型。
设计
流程设计
-
处理微批次:将数据划分为可并行处理的小 batch。
-
计算 GRPO 损失:按照 GRPO 策略进行训练。
-
获取 logits:模型输出每个 token 的预测分布。
-
计算 KL 散度与 token-level loss:通过对比当前策略与参考策略(baseline)的输出,获取每个 token 的损失。
-
计算最终损失:将 token 级损失加权合并,应用生成 mask 控制梯度传播范围。
-
-
反向传播和参数更新:进行一次优化步骤,更新模型参数。
GRPO 训练过程的高明之处
a. 奖励函数设计的巧妙性
- 双重奖励机制:
- 正确性奖励(reward_correct):通过 math_verify 模块验证答案的数学正确性,确保模型生成的结果与标准答案一致。这直接针对 GSM8K 的数学推理任务,鼓励模型学习正确的推理过程。
- 格式奖励(reward_format):要求答案遵循 <think> 和 <answer> 标签的结构,鼓励模型生成结构化的输出,便于用户理解推理过程。
- 综合奖励(正确性 + 格式)平衡了任务目标和输出可读性,避免模型生成无意义的正确答案或格式混乱的回答。
- 奖励数值设计:
- 正确性奖励:正确为 +1,错误为 -1。
- 格式奖励:正确格式为 +1.25,错误为 -1。这种设计赋予格式稍高的权重,强调结构化输出的重要性,同时保持正确性的核心地位。
b. GRPO 的优化策略
- KL 散度正则化:
- GRPO 通过 KL 散度(per_token_kl)惩罚模型生成分布与参考分布的偏离,防止模型在优化奖励时过度偏离原始预训练模型的语义分布。这确保模型在提升数学推理能力的同时,保留语言生成的一般能力。
- 参数 beta=0.04 控制正则化强度,值较小表明更注重奖励优化,同时不过分限制模型的探索空间。
- 高效梯度计算:
- 使用 fast_log_softmax_gather 函数优化对数概率计算,减少内存占用,提升训练效率。
- 损失函数中通过 completion_mask 忽略填充 token 的影响,确保只对生成部分计算损失,提高计算精度。
模型训练
训练代码
import torch.distributed
from transformers import AutoTokenizer, AutoModelForCausalLM
import json, os, shutil, re, random, pickle, sys, time, traceback
import torch
import torch.nn as nn
import numpy as np
import torch.distributed as dist
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
model_path = "/opt/chenrui/qwq32b/base_model/qwen2-7b"
beta = 0.04
num_pre_Q = 8
Q_batch_size = 1
micro_batch_size = 4
all_steps = 1000
max_prompt_length = 400
gen_steps = 30
save_steps = 200
ds_config = {
"train_micro_batch_size_per_gpu": micro_batch_size,
"gradient_accumulation_steps": 4,
"optimizer": {
"type": "AdamW",
"params": {"lr": 1e-6}
},
"bf16": {"enabled": True},
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 2e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 2e8,
"contiguous_gradients": True,
"stage3_gather_16bit_weights_on_model_save": True,
"offload_optimizer": {"device": "cpu"}
}
}
# 移除 ref_server 相关模块,定义本地缓存文件
cache_file = "data_cache.pkl"
def get_batch():
dist.barrier() # 同步所有进程
try:
if not os.path.exists(cache_file):
print(f"Rank {torch.distributed.get_rank()}: Cache file {cache_file} not found")
return None
with open(cache_file, "rb") as f:
data_cache = pickle.load(f)
if not data_cache:
print(f"Rank {torch.distributed.get_rank()}: Cache is empty")
return None
data = data_cache.pop(0) # 取出第一条数据
with open(cache_file, "wb") as f:
pickle.dump(data_cache, f) # 保存剩余数据
return data
except Exception as e:
print(f"Rank {torch.distributed.get_rank()}: Failed to load cache: {str(e)}")
return None
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path,
torch_dtype=torch.bfloat16, _attn_implementation="sdpa")
gen_model = model
from datasets import load_dataset
dataset = load_dataset(
"parquet",
data_files="/opt/chenrui/grpo/dataset/gsm8k/main/test-00000-of-00001.parquet",
split="train"
)
QAs = [{'Q': x, 'A': y.split('####')[-1].strip()} for x, y in zip(dataset['question'], dataset['answer'])]
from transformers import GenerationConfig
generation_config = GenerationConfig(
max_new_tokens=512,
do_sample=True, temperature=0.9,
num_return_sequences=num_pre_Q,
pad_token_id=tokenizer.pad_token_id,
)
system_prompt = """You are a helpful assistant. A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""
def gen_answers(prompts):
tip_text = []
for x in prompts:
tip_text.append(tokenizer.apply_chat_template([
{"role": "system", "content": system_prompt},
{"role": "user", "content": x}], tokenize=False, add_generation_prompt=True))
tip_inputs = tokenizer(tip_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False)
prompt_length = tip_inputs["input_ids"].shape[-1]
if prompt_length > max_prompt_length: return []
tip_inputs = {k: v.to(model.device) for k, v in tip_inputs.items()}
with torch.inference_mode():
tip_completion_ids = gen_model.generate(**tip_inputs, generation_config=generation_config)
completion_ids = tip_completion_ids[:, prompt_length:]
answers = [tokenizer.decode(x).replace('<|endoftext|>', '') for x in completion_ids]
return answers
from math_verify import parse, verify, ExprExtractionConfig
def reward_correct(item, answer):
pattern = r'\d+\.\d+|\d+/\d+|\d+'
nums = re.findall(pattern, answer)
if len(nums) == 0: return -1.0
lastnum = nums[-1]
ans = parse(lastnum, extraction_config=[ExprExtractionConfig()])
ground_truth = parse(item["A"], extraction_config=[ExprExtractionConfig()])
return 1 if verify(ans, ground_truth) else -1
def reward_format(item, answer):
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
return 1.25 if re.match(pattern, answer, re.DOTALL | re.VERBOSE) else -1
def gen_samples(inputs):
prompts = [x["Q"] for x in inputs]
answers = gen_answers(prompts)
if len(answers) == 0: return None, None, None
rewards = []
for i, inp in enumerate(inputs):
for a in answers[i * num_pre_Q:(i + 1) * num_pre_Q]:
rewards.append(reward_correct(inp, a) + reward_format(inp, a))
prompts_text = [tokenizer.apply_chat_template([
{"role": "system", "content": system_prompt},
{"role": "user", "content": x}], tokenize=False, add_generation_prompt=True) for x in prompts]
return prompts_text, rewards, answers
def generate_mode(num=10, rank=0, gen_timeout=999999):
if rank == 0: print('enter generate mode')
tic = time.time()
data_cache = []
if os.path.exists(cache_file):
with open(cache_file, "rb") as f:
data_cache = pickle.load(f)
for ii in range(num):
if time.time() - tic > gen_timeout: break
inputs = random.sample(QAs, 1)
prompts_text, rewards, answers = gen_samples(inputs)
if prompts_text is None: continue
if rank == 0:
# 为每个样本生成输入 token 和伪参考概率
input_ids = \
tokenizer([prompts_text[0] + a for a in answers], return_tensors="pt", padding=True, padding_side="left")[
"input_ids"]
data_cache.append({
"plen": len(tokenizer(prompts_text[0], return_tensors="pt")["input_ids"][0]),
"inputs": input_ids,
"rewards": torch.tensor(rewards, dtype=torch.float32),
"refs": torch.zeros_like(input_ids, dtype=torch.float32) # 伪参考概率
})
print('gen:', rewards)
if rank == 0:
with open(cache_file, "wb") as f:
pickle.dump(data_cache, f)
print('exit generate mode')
dist.barrier()
print(f'{rank}: {time.time() - tic:.3f}s')
if 'genonly' in sys.argv:
model.to('cuda')
generate_mode(999999)
sys.exit()
import deepspeed
engine, optimizer, _, _ = deepspeed.initialize(config=ds_config, model=model,
model_parameters=model.parameters())
gen_model = engine
def get_per_token_logps(logits, input_ids):
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
from kernel.ce_kernel import fast_log_softmax_gather
get_per_token_logps = fast_log_softmax_gather
def GRPO_step(batch):
prompt_length = batch['plen']
inputs = batch['inputs'].to(engine.device)
advantages = batch['rewards'].to(engine.device)
logits = engine(inputs).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit
input_ids = inputs[:, 1:] # (B, L-1), exclude the first input ID
per_token_logps = get_per_token_logps(logits, input_ids)
per_token_logps = per_token_logps[:, prompt_length - 1:]
ref_per_token_logps = batch['refs'].to(engine.device)
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
completion_mask = (inputs[:, prompt_length:] != tokenizer.pad_token_id).int()
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
return loss
from tqdm import tqdm
progress = range(1, all_steps + 1)
if torch.distributed.get_rank() == 0: progress = tqdm(progress)
for step in progress:
batch = get_batch()
while batch is None:
generate_mode(rank=torch.distributed.get_rank())
batch = get_batch()
if step % gen_steps == 0:
generate_mode(rank=torch.distributed.get_rank(), gen_timeout=100)
B = batch['inputs'].shape[0]
assert B % micro_batch_size == 0
for sub in range(0, B, micro_batch_size):
sub_batch = {'plen': batch['plen']}
sub_batch['inputs'] = batch['inputs'][sub:sub + micro_batch_size]
sub_batch['rewards'] = batch['rewards'][sub:sub + micro_batch_size]
sub_batch['refs'] = batch['refs'][sub:sub + micro_batch_size]
try:
loss = GRPO_step(sub_batch)
engine.backward(loss)
engine.step()
except:
traceback.print_exc()
print('batch:', sub_batch['inputs'].shape)
torch.distributed.destroy_process_group()
sys.exit()
if torch.distributed.get_rank() == 0:
progress.set_description(f"Loss: {loss.item():.6f}")
if step % save_steps == 0:
dist.barrier()
if torch.distributed.get_rank() == 0:
print('saving model')
save_name = f"./step_{step}"
state_dict = engine.module.state_dict()
state_dict = type(state_dict)({k: v.cpu() for k, v in state_dict.items()})
engine.module.save_pretrained(save_name, state_dict=state_dict)
tokenizer.save_pretrained(save_name)
dist.barrier()
训练过程
总结
- 数据集选择的针对性:
- GSM8K 的数学推理任务明确且可验证,适合测试和提升模型的逻辑推理能力。
- 数据集的适中难度和标准答案格式便于奖励函数设计和模型评估。
- GRPO 训练的创新性:
- 结合正确性和格式的双重奖励,平衡任务目标和输出质量。
- KL 散度正则化防止模型过度偏离原始分布,保持生成能力。
- 动态数据生成和在线训练相结合,高效利用计算资源。
- 分布式训练的高效性:
- DeepSpeed 优化(ZeRO-1、通信重叠等)显著降低显存和通信开销,适合大型模型训练。
- 梯度累积和微批次处理提升训练稳定性,适应不同硬件环境。
通俗来讲通过给模型“出题”(GSM8K 数据集)、“批改作业”(奖励函数)和“指导改进”(GRPO 优化),让模型学会如何正确且清晰地解决数学问题。它的聪明之处在于:
- 选了一个简单但有挑战性的“教材”(GSM8K),让模型既能学到东西又不至于太难。
- 用奖励机制“表扬”正确答案和规范格式,“批评”错误或乱七八糟的回答。
- 通过 GRPO 让模型在学习新知识时不忘“老本行”(语言生成能力)。
- 借助 DeepSpeed 就像请了个“高效助教”,让大模型在多台电脑上快速学习。
- 定期“考试”并保存成绩(模型检查点),确保学习成果不丢失。