【代码详解】大模型微调入门:SFT Qwen2-7B,基于 Hugging Face Transformers库

这篇博客详细记录了基于 Hugging Face Transformers库 SFT (Supervised Fine-Tuning,有监督微调) Qwen2-7B-Instruct 的代码和相关细节,适用于大模型微调入门实践的读者。

1. SFT 代码

code:

# sft Qwen2-7B-Instruct

import os
import datasets
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig, DataCollatorForSeq2Seq, Trainer, TrainingArguments

PaddingID = -100

def preprocess_inputs(examples, max_len=8192, overflow_strategy='truncate'):
    """
    @function:预处理输入数据
    examples:数据集
    max_len:尽管 Qwen2-7B-Instruct 支持 131072 tokens,但最好不要设置为最大长度,否则显存占用将会非常大。
            max_len 的设置可以通过统计数据集的 token 长度得到。具体方法:将所有数据输入到 qwen2 模型的 tokenizer,统计 tokenizer 的输出长度(最大,最小,平均)
    overflow_strategy:'drop'表示丢弃,'truncate'表示截断
    """

    # prompt_template:可以在 qwen2 的 huggingface 官方库的 demo 中使用 tokenizer.apply_chat_template 函数打印模型的 prompt template
    # Qwen2-7B-Instruct huggingface 地址:https://huggingface.co/Qwen/Qwen2-7B-Instruct
    prompt_template = '<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n'
    system_prompt = "你是一个知识渊博的人,请根据问题做出全面且正确的回答。"

    model_inputs = {
   
   'input_ids': [], 'labels': [], 'input_len': [], 'output_len': []}
    for i in range(len(examples['query'])):
        prompt = prompt_template.format(
                    system_prompt=system_prompt,
                    user_prompt=examples['query'][i]
                )
        a_ids = tokenizer.encode(prompt)
        b_ids = tokenizer.encode(f"{
     
     examples['answer'][i]}", add_special_tokens=False) + [tokenizer.eos_token_id]
        context_length = len(a_ids)
        input_ids = a_ids + b_ids

        if len(input_ids) > max_len and overflow_strategy == 'drop':
            # 丢弃样本
            input_ids = []
            labels = []
        else:
            if max_len > len(input_ids):
                """
                使用 -100 填充, 因为 torch.nn.CrossEntropyLoss 的 ignore_index=-100, 即 CrossEntropyLoss 会忽略标签为 -100 的值的 loss,只计算非填充部分的 loss
                torch.nn.CrossEntropyLoss 官方文档:https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
                """
                pad_length = max_len - len(input_ids)
                labels = [PaddingID] * context_length + b_ids + [PaddingID] * pad_length
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_length
         
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ctrl A_ctrl C_ctrl V

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值