【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码结构解析

【复现DeepSeek-R1之Open R1实战】系列博文链接:
【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)
【复现DeepSeek-R1之Open R1实战】系列2:没有卡也能训模型!Colab跑OpenR1(附源码)
【复现DeepSeek-R1之Open R1实战】系列3:基础知识介绍
【复现DeepSeek-R1之Open R1实战】系列4:跑通GRPO!
【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析
【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码结构解析
【复现DeepSeek-R1之Open R1实战】系列7:GRPO原理介绍、训练流程和源码深度解析
【复现DeepSeek-R1之Open R1实战】系列8:混合精度训练、DeepSpeed、vLLM和LightEval介绍
【复现DeepSeek-R1之Open R1实战】系列9:有趣的现象——GRPO训练过程Loss从0开始慢慢变大

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

  • reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
        },
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000

  • repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3

  • repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think><answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
    }
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

  1. 注册表定义
  • accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。
  • format: 使用 format_reward 函数评估模型输出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。
  • cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括:
    • min_value_wrong: 错误情况下的最小值。
    • max_value_wrong: 错误情况下的最大值。
    • min_value_correct: 正确情况下的最小值。
    • max_value_correct: 正确情况下的最大值。
    • max_len: 最大长度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大惩罚值。
  • length: 使用 len_reward 函数评估模型输出的长度。
  1. 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs

4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards
  • completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。
  • solution (list): 真实答案的列表。
  • kwargs: 其他可选参数(在本函数中未使用)。
  1. 提取补全内容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。
  2. 初始化奖励列表

    rewards = []
    
  3. 遍历每个补全和对应的真实答案

    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
    
    • 使用 zip 函数将 contentssolution 配对。
    • 对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。
  4. 处理解析结果

    if len(gold_parsed) != 0:
        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
    
    • 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content
    • 使用 LatexExtractionConfigNormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
  5. 计算奖励

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。
    • 如果两者匹配,则返回 1.0,否则返回 0.0
  6. 处理无法解析的情况

    else:
        reward = 1.0
        print("Failed to parse gold solution: ", sol)
    
    • 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。
  7. 添加奖励到列表

    rewards.append(reward)
    
  8. 返回所有奖励

    return rewards
    

4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify(
    gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, 
    target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, 
    float_rounding: int=6,
    numeric_precision: int=15,
    strict: bool=True,
    timeout_seconds: int=3
) -> bool:
  • gold: 参考或正确的表达式,可以是单个 SymPy 表达式(BasicMatrixBase)、字符串或这些类型的列表。
  • target: 需要验证的表达式,类型同 gold
  • float_rounding: 浮点数舍入的小数位数,默认为 6。
  • numeric_precision: 数值比较时考虑的小数位数,默认为 15。
  • strict: 是否启用严格比较模式,默认为 True
    • 在严格模式下:变量很重要,集合不可与元组比较。
    • 在非严格模式下:变量按位置匹配,集合可与元组比较。
  • timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
  1. 定义内部比较函数 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:
        ...
    
    • 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds
    • 比较两个表达式:
      • 如果两者都是 SymPy 表达式(BasicMatrixBase),则调用 sympy_expr_eq 进行比较。
      • 如果两者都是字符串,则进行简单的字符串比较。
  2. 定义包装函数 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):
        try:
            return compare_single_extraction(g, t)
        except Exception as e:
            logger.exception(f"Error comparing {g} and {t}")
            return False
    
    • 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。
  3. 处理输入列表

    if not isinstance(gold, list):
        gold = [gold]
    if not isinstance(target, list):
        target = [target]
    
    • 如果 goldtarget 不是列表,则将其转换为单元素列表,以便统一处理。
  4. 组合所有可能的比较

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 组合。
    • 对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True

4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think><answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
  • completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。
  • kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
  1. 正则表达式模式定义

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。
    • .*? 是非贪婪匹配,确保尽可能少地匹配字符。
    • \s* 匹配零个或多个空白字符(包括换行符)。
    • re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
  2. 提取完成内容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents
  3. 匹配正则表达式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。
    • matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。
  4. 生成奖励分数

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0

示例代码:

completions = [
    [{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],
    [{"content": "<think>This is reasoning.</think>"}],
    [{"content": "<answer>This is answer.</answer>"}],
    [{"content": "This does not match."}]
]

reward_scores = format_reward(completions)
print(reward_scores)  # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

  • 第一个完成内容完全匹配正则表达式,因此得分为 1.0
  • 后三个完成内容不符合要求,因此得分均为 0.0

4.4 将数据集格式化为对话形式

# Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

  • 输入example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。
  • 输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表:
    • 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。
    • 第二条消息是用户消息,内容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。
  • 移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
    )

篇幅有限,训练过程的代码我们放到下一篇博文详细解读, 博文链接:【复现DeepSeek-R1之Open R1实战】系列7:GRPO原理介绍、训练流程和源码深度解析

### 复现 Open-R1 项目的指南 #### 环境准备 为了成功复现 Hugging Face 的 Open-R1 项目,需配置合适的硬件环境。建议使用至少两个计算节点,每节点应装备有8块NVIDIA H100 GPU 显卡[^4]。 #### 获取源码与资源 访问 GitHub 上官方维护的 Open-R1 仓库地址 `https://github.com/huggingface/open-r1` 下载最新版本代码库。此存储库包含了构建和运行该项目所需的一切文件以及详细的文档说明[^1]。 #### 数据集准备 依据描述,Open-R1 使用了特定的数据生成方法来创建用于训练的数据集。这包括但不限于GRPO (Generative Reasoning Process Optimization) 实现及相应数据生成器的设计[^2]。对于初学者来说,可以从现有公开可用的小规模样例入手练习,在熟悉整个流程后再逐步扩展到更大更复杂的真实场景应用上去。 #### 训练过程概述 采用 Qwen2.5-1.5B 模型作为基础架构,并在此之上集成多种先进的开源工具和技术栈完成最终系统的搭建工作。值得注意的是,初期可以通过提供少量精心设计好的示范案例来进行预热(即所谓的“冷启动”),之后再利用这些初步结果进一步指导后续迭代优化的方向[^3]。 ```bash # 假设已经克隆了GitHub上的open-r1仓库 cd path/to/your/cloned/repo pip install -r requirements.txt python train.py --config_path ./configs/default.yaml ``` 上述命令展示了安装依赖项并执行默认设置下的训练脚本的方式;实际操作时可根据个人需求调整参数配置以适应不同任务目标的要求。 #### 验证效果 最后一步是对所得到的结果进行全面评估测试,确保其性能指标达到预期水平。考虑到该研究方向尚处于快速发展阶段,鼓励参与者积极贡献自己的见解和改进建议给社区,共同推动领域内技术创新与发展进步。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值