# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/trl.
import concurrent.futures
import inspect
import os
import re
import time
from collections import defaultdict, deque
from concurrent.futures import Future
from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
from dataclasses import asdict, dataclass, field
from functools import partial
from math import ceil
from queue import Queue
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import datasets
import torch
import torch.nn as nn
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from packaging import version
from torch.nn import ModuleList
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, TrainerCallback
from transformers.trainer import Trainer
from trl import GRPOTrainer as HFGRPOTrainer
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.models import prepare_deepspeed
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd
from swift.llm import (InferRequest, MultiModelKeys, RequestConfig, RolloutInferRequest, RowPreprocessor, Template,
                       get_model_arch, to_device)
from swift.llm.infer.protocol import ChatCompletionResponse
from swift.llm.model.utils import get_llm_model
from swift.llm.template.template_inputs import StdTemplateInputs
from swift.plugin import loss_scale_map, multi_turns, orms, rm_plugins
from swift.plugin.multi_turn import MultiTurnScheduler
from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_device, get_logger, is_vllm_available,
                         is_wandb_available, seed_worker, unwrap_model_for_generation)
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin
from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge
from .vllm_client import VLLMClient
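
# Remove TRL's `__init__` and `log` from the base class so that the Swift mixin/Trainer
# implementations take effect in the method resolution order.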
del HFGRPOTrainer.__init__
del HFGRPOTrainer.log

logger = get_logger()
if is_wandb_available():
    import wandb

InputsType = List[Dict[str, Union[torch.Tensor, Any]]]
# tuple: (messages, finish_reason)
OutputsType = List[Tuple[List[Dict], str]]


class GRPOCallback(TrainerCallback):
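    """Callback used for asynchronous rollout generation.

    At the start of training it points the trainer at the training queue and
    prefetches the first batch of rollouts from the train dataloader.
    """
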
    def __init__(self, trainer):
        self.trainer = trainer

    # Bind the training queue and prefetch the first batch of rollouts.
    def on_train_begin(self, args, state, control, **kwargs):
        self.trainer.queue = self.trainer.train_queue
        train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
        self.trainer._prefetch(train_dataloader)


@dataclass
class DataCache:
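    """Simple container pairing prefetched inputs with the corresponding generation outputs."""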
    inputs: List[Dict] = field(default_factory=list)
    outputs: List[Dict] = field(default_factory=list)


def identity_data_collator(features):
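    """Return the list of features unchanged (no collation is performed here)."""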
    return features


class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
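    """GRPO trainer built on top of TRL's implementation.

    Rewards may come from plugin reward functions (registered in `orms`), from
    reward models (wrapped by reward-model plugins), or both; rollouts can be
    generated with vLLM, optionally across multiple turns.
    """

    # Single-worker pool shared across instances; used to run generation asynchronously.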
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 reward_model: Optional[List[Union[PreTrainedModel, nn.Module]]] = None,
                 reward_funcs: Optional[List[Union[str, Callable]]] = None,
                 *_args,
                 **kwargs):
        from swift.trainers.rlhf_arguments import GRPOConfig
        args: GRPOConfig = kwargs['args']
        self.args = args
        # for async generate
        self.train_queue = Queue()
        self.eval_queue = Queue()

        self.processing_class = kwargs.get('template').tokenizer

        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
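
        # Resolve string reward functions via the `orms` plugin registry, forwarding any
        # constructor arguments that are also present on `args`.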
        if reward_funcs:
            for i, reward_func in enumerate(reward_funcs):
                if reward_func in orms:
                    reward_func_class = orms[reward_func]
                    reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
                    reward_func_kwargs = {
                        key: getattr(args, key)
                        for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
                    }
                    if 'tokenizer' in reward_func_args:
                        reward_func_kwargs['tokenizer'] = self.processing_class
                    reward_funcs[i] = reward_func_class(**reward_func_kwargs)
                elif not callable(reward_func):
                    raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin')

        self.reward_funcs = reward_funcs
        self.reward_func_names = []
        for reward_func in reward_funcs:
            if inspect.isfunction(reward_func):
                reward_func_name = reward_func.__name__
            else:
                reward_func_name = reward_func.__class__.__name__
            self.reward_func_names.append(reward_func_name)
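
        # Reward models are wrapped by their reward-model plugin and appended to
        # `reward_funcs`, so they are scored alongside the function-based rewards.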
        self.reward_model_plugins = [None] * len(self.reward_funcs)

        if reward_model is not None:
            reward_template = kwargs.pop('reward_template')
            reward_plugins = args.reward_model_plugin
            if reward_plugins is None:
                reward_plugins = ['default'] * len(reward_model)
            assert len(reward_plugins) == len(reward_model), (
                f"The number of 'reward_model_plugin' ({len(reward_plugins)}) does not match "
                f"the number of 'reward_model' ({len(reward_model)}). "
                "Please provide a corresponding 'reward_model_plugin' for each 'reward_model'.")
            for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_template):
                # Set the encoding mode to 'train' (see details in Template.encode).
                # Set max_length to None to disable truncation, as the input length has already been truncated earlier.
                rm_template.set_mode('train')
                rm_template.max_length = None
                if rm_plugin not in rm_plugins:
                    raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.llm.plugin')
                self.reward_model_plugins.append(rm_plugins[rm_plugin](model=rm, template=rm_template))
                self.reward_funcs.append(rm)
                self.reward_func_names.append(rm.config._name_or_path.split('/')[-1])

        if not self.reward_funcs:
            raise ValueError('You must specify reward_funcs or reward_model')

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward '
                                 f'functions ({len(reward_funcs)})')
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
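
        # Optional multi-turn rollout scheduler: either the name of a scheduler registered
        # in `multi_turns` or a ready-made MultiTurnScheduler instance.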
        self.multi_turn_scheduler = None
        if self.args.multi_turn_scheduler:
            if isinstance(self.args.multi_turn_scheduler, str):
                assert self.args.multi_turn_scheduler in multi_turns
                multi_turn_scheduler = multi_turns[self.args.multi_turn_scheduler](max_turns=self.args.max_turns)
                self.multi_turn_scheduler: MultiTurnScheduler = multi_turn_scheduler
            else:
                assert isinstance(self.args.multi_turn_scheduler, MultiTurnScheduler)
                self.multi_turn_scheduler: MultiTurnScheduler = self.args.multi_turn_scheduler
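
        # Sampling and vLLM rollout configuration copied from the training arguments.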
        self.num_generations = args.num_generations
        self.temperature = args.temperature
        self.vllm_mode = args.vllm_mode
        self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization  # only applies to colocation