TRL 的 `PPOTrainer` 实现存在一个很严重的问题:它的 `model` 和 `reward_model` 两个参数所使用的分词器必须相同。否则要么直接报错,要么奖励模型拿到的是它无法正确解读的 token id,算出来的奖励没有意义。
之前已经提过,`PPOTrainer` 要求训练数据(`train_dataset` 参数)必须包含 `input_ids` 字段,这一点与 `SFTTrainer`、`DPOTrainer`、`GRPOTrainer` 都不同。查了一下源码(`trl/trainer/ppo_trainer.py`),发现只有 `PPOTrainer` 重写了父类 `transformers.Trainer` 的 `.train` 方法,其他三个都是直接继承。按这个思路,它们三个理论上应该适配相同的数据集字段,即 `text`,或者 `input` + `target`,或者 `prompt` + `completion`。但这就很神奇了:`DPOTrainer` 也没有重写 `.train` 方法,可它需要的字段显然与其他两个不同。其实这并不矛盾。数据集需要哪些字段,是由各个 Trainer 自己的数据预处理、data collator 和 `compute_loss` 等逻辑决定的,而不是由是否重写 `.train` 决定的。
简单看一下 `PPOTrainer` 的 `.train` 方法:
def train(self):
    """Run the full PPO training loop (TRL ``PPOTrainer.train``).

    Every iteration of the outer ``for update`` loop:

    1. Takes the next batch from ``self.dataloader`` (cycled forever) and
       reads its ``"input_ids"`` field as the prompt token ids (``queries``),
       so the training dataset must provide an ``input_ids`` column.
    2. Samples responses from ``model.policy`` and computes their log-probs
       under the policy and under the reference policy.
    3. Scores the responses with the value model
       (``accelerator.unwrap_model(model).value_model``) and with
       ``self.reward_model``.
    4. Builds per-token rewards (KL penalty plus the scalar score) and
       GAE advantages / returns.
    5. Runs ``args.num_ppo_epochs`` epochs of clipped policy/value updates
       over shuffled mini-batches and micro-batches.
    6. Logs metrics, steps the LR scheduler, fires callbacks and may save a
       checkpoint.

    NOTE(review): the policy, reference policy, value model and reward model
    are all fed the same token ids. Only ``processing_class`` (a single
    tokenizer) supplies the pad/EOS ids. Nothing here decodes and
    re-tokenizes the text for ``reward_model``, so the reward model must
    share the policy's tokenizer/vocabulary.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    ref_policy = self.ref_model
    reward_model = self.reward_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    # Cycle over the dataloader forever so next() never runs out; the number
    # of updates is bounded by args.num_total_batches instead.
    def repeat_generator():
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    # Sampling config for rollouts. The +1e-7 keeps the temperature strictly
    # positive; top_k=0.0 / top_p=1.0 presumably disable top-k / nucleus
    # filtering, so this is plain temperature sampling.
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    # Per-(PPO epoch, mini-batch, micro-batch) statistics; they are averaged
    # into the logged metrics at the end of each update.
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    vf_loss_stats = torch.zeros(stats_shape, device=device)
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = args.num_total_batches
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    # (a value < 1 is interpreted as a fraction of max_steps).
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model
        self.model_wrapped = self.model

    for update in range(1, args.num_total_batches + 1):
        # `episode` counts samples processed so far (args.batch_size per update).
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        # ---- Rollout phase: everything in this block runs without gradients ----
        with torch.no_grad():
            # The dataset must provide an "input_ids" field; presumably these are
            # prompt token ids produced by processing_class — verify upstream.
            queries = data["input_ids"].to(device)
            # Prompt length: response tokens start at this column of query_responses.
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []
            values = []
            # Sample responses from the policy for the whole local batch.
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model.policy,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            # Score the rollouts in chunks of args.local_rollout_forward_batch_size.
            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                # NOTE(review): unlike ref_logits below, these generation logits are
                # not divided by the temperature here — presumably batch_generation
                # already returns temperature-scaled scores; verify.
                logprob = selective_log_softmax(logits, response)
                del logits
                torch.cuda.empty_cache()

                # Reference-policy log-probs. Without a separate ref model,
                # self.null_ref_context() presumably disables adapters (e.g. PEFT)
                # so model.policy can act as the reference — verify.
                if ref_policy is None:
                    with self.null_ref_context():
                        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                else:
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                # Logits at positions [context_length - 1, -1) predict the response tokens.
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                # Same temperature as used for sampling.
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                torch.cuda.empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                # Index of the last non-pad response token, assuming first_true_indices
                # returns the position of the first pad token in each row.
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                unwrapped_value_model = accelerator.unwrap_model(model).value_model
                # The value model sees the *untruncated* query_response; the slice keeps
                # one value per response token.
                full_value, _, _ = get_reward(
                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                )
                value = full_value[:, context_length - 1 : -1].squeeze(-1)
                # The reward model sees the *truncated* sequence. Its input ids come from
                # processing_class, i.e. the policy's tokenizer (no re-tokenization).
                _, score, _ = get_reward(
                    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )

                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)
                values.append(value)
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            values = torch.cat(values, 0)
            del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
            torch.cuda.empty_cache()
            gc.collect()

            # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            # (The check below looks for processing_class.eos_token_id.)
            contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            # padding_mask is True for response positions after the last valid token
            # (index > sequence_length); log-probs there are replaced by INVALID_LOGPROB.
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
            # padding_mask_p1 keeps one extra position (index sequence_length + 1);
            # values beyond it are zeroed.
            sequence_lengths_p1 = sequence_lengths + 1
            padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
            values = torch.masked_fill(values, padding_mask_p1, 0)

            # 4. compute rewards
            # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
            logr = ref_logprobs - logprobs
            kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
            # Every token gets -kl_coef * KL; the scalar score is added at a single
            # position per row: sequence_length + 1 if it is still inside the
            # response, otherwise sequence_length.
            non_score_reward = -args.kl_coef * kl
            rewards = non_score_reward.clone()
            actual_start = torch.arange(rewards.size(0), device=rewards.device)
            actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
            rewards[[actual_start, actual_end]] += scores

            # 5. whiten rewards
            # (optional; shift_mean=False presumably rescales without re-centering — verify masked_whiten)
            if args.whiten_rewards:
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

            # 6. compute advantages and returns
            # Generalized Advantage Estimation, scanning response positions backwards:
            #   delta_t = r_t + gamma * V_{t+1} - V_t   (V after the last position is 0)
            #   A_t     = delta_t + gamma * lam * A_{t+1}
            lastgaelam = 0
            advantages_reversed = []
            gen_length = responses.shape[1]
            for t in reversed(range(gen_length)):
                nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                lastgaelam = delta + args.gamma * args.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], axis=1)
            # Returns are the value-function targets; advantages are then whitened
            # over non-padding tokens and zeroed at padding.
            returns = advantages + values
            advantages = masked_whiten(advantages, ~padding_mask)
            advantages = torch.masked_fill(advantages, padding_mask, 0)
            torch.cuda.empty_cache()

        # ---- Optimization phase ----
        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        # The shuffled local batch is split into mini-batches of
        # args.local_mini_batch_size, each split into micro-batches of
        # args.per_device_train_batch_size.
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]
                        mb_return = returns[micro_batch_inds]
                        mb_values = values[micro_batch_inds]

                        # Recompute log-probs and values with the current parameters
                        # on the stored rollouts.
                        output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )
                        vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                        vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                        # Clipped value loss: keep new values within cliprange_value of the
                        # rollout values and take the larger (worse) squared error.
                        vpredclipped = torch.clamp(
                            vpred,
                            mb_values - args.cliprange_value,
                            mb_values + args.cliprange_value,
                        )
                        vf_losses1 = torch.square(vpred - mb_return)
                        vf_losses2 = torch.square(vpredclipped - mb_return)
                        vf_loss_max = torch.max(vf_losses1, vf_losses2)
                        vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                        vf_clipfrac = masked_mean(
                            (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                        )
                        # Clipped surrogate policy loss (PPO-clip) on the probability ratio
                        # new/old, averaged over non-padding tokens.
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                        loss = pg_loss + args.vf_coef * vf_loss
                        accelerator.backward(loss)
                        # NOTE(review): step/zero_grad are called on every micro-batch;
                        # inside accelerator.accumulate they presumably only take effect
                        # at gradient-accumulation boundaries — confirm with Accelerate docs.
                        optimizer.step()
                        optimizer.zero_grad()
                        with torch.no_grad():
                            pg_clipfrac = masked_mean(
                                (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                            )
                            # Entropy of softmax(logits): logsumexp(z) - sum(p * z) = -sum(p * log p).
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            # 0.5 * (log ratio)^2: a cheap approximation of KL(old || new).
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                            vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                vf_clipfrac
                            )
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1
                # del everything and empty cache
                # fmt: off
                del (
                    output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                    vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                    pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                    mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                torch.cuda.empty_cache()

        # ---- Logging for this update ----
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            # Sampled entropy estimate (sum of -log pi over response tokens).
            # NOTE(review): padding positions hold INVALID_LOGPROB and also enter this sum.
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.sum(1).mean()
            rlhf_reward = mean_non_score_reward + scores.mean()
            # Episodes (samples) per second since training started.
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            # NOTE(review): unlike the metrics above, this count is not gathered across
            # processes and uses the untruncated `responses`.
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
            self.state.global_step += 1
            self.log(metrics)

        self.lr_scheduler.step()
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
        torch.cuda.empty_cache()
        gc.collect()

        # Periodically generate sample completions, every self.sample_generations_freq updates.
        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)
            torch.cuda.empty_cache()
        del (
            query_responses,
            responses,
            postprocessed_responses,
            logprobs,
            ref_logprobs,
            values,
            sequence_lengths,
            contain_eos_token,
            sequence_lengths_p1,
            response_idxs,
            padding_mask,
            padding_mask_p1,
            rewards,
            actual_start,
            actual_end,
            advantages,
            returns,
        )
        torch.cuda.empty_cache()

    # HF trainer specifics
    # End of training: fire on_train_end and optionally save a final checkpoint.
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        self._save_checkpoint(model, trial=None, metrics=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
逻辑很清晰(这里可以明显看到 `queries = data["input_ids"].to(device)`,即要求数据集包含 `input_ids` 字段):
- 首先在 policy 上做生成,即采样得到 `query_responses`:

  ```python
  query_responses, logitss = batch_generation(
      unwrapped_model.policy,
      queries,
      args.local_rollout_forward_batch_size,
      processing_class.pad_token_id,
      generation_config,
  )
  ```
- 然后把 `query_response`(以及在 stop token 处截断后的 `postprocessed_query_response`)拿去计算价值和奖励:

  ```python
  full_value, _, _ = get_reward(
      unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
  )
  value = full_value[:, context_length - 1 : -1].squeeze(-1)
  _, score, _ = get_reward(
      reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
  )
  ```
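问题就出在这儿:reward model 和 policy 使用的是同一个 `processing_class`(即分词器)。`get_reward` 直接把 policy 分词器产生的 token id 喂给 `reward_model`(价值模型 `value_model` 也一样),中间没有任何重新分词的步骤。这个问题很难修改,除非先把 `query_responses` 解码还原成文本,再用 `reward_model` 自己的分词器重新分一次词;否则这里就是强制要求二者的分词器相同。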
在 PPO 官方示例中,使用的策略模型是 `EleutherAI/pythia-1b-deduped`,而 `PPOConfig` 默认的奖励模型是 `EleutherAI/pythia-160m`。这两个模型的分词器刚好相同,所以没有出问题。
如果现在想用其他策略模型(即换一个 LLM 来训练),就必须找到与它使用相同分词器(通常即相同基座)的奖励模型。比如对于 Qwen 系列,TRL 提供了一个基于 Qwen 基座的奖励模型:`trl-lib/Qwen2-0.5B-Reward`(https://huggingface.co/trl-lib/Qwen2-0.5B-Reward)。
然后推荐一本XAI的好书:Explainable AI with Python,电子版挂在下面了:
通过网盘分享的文件:explainable-ai-with-python.pdf
链接: https://pan.baidu.com/s/13nS8mNMhif62o0F3cG0X5A?pwd=avdu 提取码: avdu 复制这段内容后打开百度网盘手机App,操作更方便哦
这本书对 XAI 的概括很专业。之前提过的那篇 ICLR 上做 Learning Dynamics 的工作(arXiv:2407.10490),我一直觉得它没有做完。它也是从 MNIST 入手,开头明明说要研究“学习了一个新样本之后,模型对旧样本的预测会受到什么影响”,但做到后面却变成了“学习了一个新样本之后,模型对这个新样本本身的预测有何变化”。这就很奇怪,前后说不通。
这本书在介绍 DeepLIFT 时,给出了一个非常好的 MNIST 解释案例:利用这种可视化的图,说明图片中每个像素对各个预测类别标签的贡献,很有说服力。下面的代码使用的是 SHAP 的 `DeepExplainer`(即 DeepSHAP,基于 DeepLIFT 的思想):
# -*- coding: utf8 -*-
from __future__ import print_function
import shape
import numpy
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras import backend as K
# DeepShap using DeepExplainer
# ... include code from https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
batch_size = 128
num_classes = 10
epochs = 1
# input image dimensions
img_rows, img_cols = 28, 28
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == "channel_first":
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)
# convert class vectors to binary class metrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
# model.add(Dropout(.25))
model.add(Flatten())
model.add(Dense(128, activation="relu"))
model.add(Dropout(.5))
model.add(Dense(num_classes, activation="softmax"))
model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adadelta(), metrics=["accuracy"])
model.fix(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
# ----------------------------------------------------------------------
# DeepShap using DeepExplainer
# select a set of background examples to take an expectation over
background = x_train[numpy.random.choice(x_train.shape[0], 100, replace=False)]
# explain predictions of the model on four images
e = shap.DeepExplainer(model, background)
# ... or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[1:5])
# plot the feature attributions
shap.image_plot(shap_values, -x_test[1:5])