集成SwanLab与HuggingFace TRL:跟踪与优化强化学习实验

TRL (Transformers Reinforcement Learning,用强化学习训练Transformers模型) 是一个领先的Python库,旨在通过监督微调(SFT)、近端策略优化(PPO)和直接偏好优化(DPO)等先进技术,对基础模型进行训练后优化。TRL 建立在 🤗 Transformers 生态系统之上,支持多种模型架构和模态,并且能够在各种硬件配置上进行扩展。

logo

你可以使用Trl快速进行模型训练,同时使用SwanLab进行实验跟踪与可视化。

Demo

1. 引入SwanLabCallback

from swanlab.integration.transformers import SwanLabCallback

SwanLabCallback是适配于Transformers的日志记录类。

SwanLabCallback可以定义的参数有:

  • project、experiment_name、description 等与 swanlab.init 效果一致的参数, 用于SwanLab项目的初始化。
  • 你也可以在外部通过swanlab.init创建项目,集成会将实验记录到你在外部创建的项目中。

2. 传入Trainer

from swanlab.integration.transformers import SwanLabCallback
from trl import SFTConfig, SFTTrainer

...

# 实例化SwanLabCallback
swanlab_callback = SwanLabCallback(project="trl-visualization")

trainer = SFTTrainer(
    ...
    # 传入callbacks参数
    callbacks=[swanlab_callback],
)

trainer.train()

3. 完整案例代码

使用Qwen2.5-0.5B-Instruct模型,使用Capybara数据集进行SFT训练:

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
from swanlab.integration.transformers import SwanLabCallback

dataset = load_dataset("trl-lib/Capybara", split="train")

swanlab_callback = SwanLabCallback(
    project="trl-visualization",
    experiment_name="Qwen2.5-0.5B-SFT",
    description="测试使用trl框架sft训练"
)

training_args = SFTConfig(
    output_dir="Qwen/Qwen2.5-0.5B-SFT",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    logging_steps=20,
    learning_rate=2e-5,
    )

trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
    callbacks=[swanlab_callback]
)

trainer.train()

DPO、GRPO、PPO等同理,只需要将SwanLabCallback传入对应的Trainer即可。

4. GUI效果展示

超参数自动记录:

ig-hf-trl-gui-2

指标记录:

ig-hf-trl-gui-1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值