文章目录
这篇博客详细记录了基于 Hugging Face Transformers库 SFT (Supervised Fine-Tuning,有监督微调) Qwen2-7B-Instruct 的代码和相关细节,适用于大模型微调入门实践的读者。
1. SFT 代码
code:
# sft Qwen2-7B-Instruct
import os
import datasets
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig, DataCollatorForSeq2Seq, Trainer, TrainingArguments
PaddingID = -100
def preprocess_inputs(examples, max_len=8192, overflow_strategy='truncate'):
"""
@function:预处理输入数据
examples:数据集
max_len:尽管 Qwen2-7B-Instruct 支持 131072 tokens,但最好不要设置为最大长度,否则显存占用将会非常大。
max_len 的设置可以通过统计数据集的 token 长度得到。具体方法:将所有数据输入到 qwen2 模型的 tokenizer,统计 tokenizer 的输出长度(最大,最小,平均)
overflow_strategy:'drop'表示丢弃,'truncate'表示截断
"""
# prompt_template:可以在 qwen2 的 huggingface 官方库的 demo 中使用 tokenizer.apply_chat_template 函数打印模型的 prompt template
# Qwen2-7B-Instruct huggingface 地址:https://huggingface.co/Qwen/Qwen2-7B-Instruct
prompt_template = '<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n'
system_prompt = "你是一个知识渊博的人,请根据问题做出全面且正确的回答。"
model_inputs = {
'input_ids': [], 'labels': [], 'input_len': [], 'output_len': []}
for i in range(len(examples['query'])):
prompt = prompt_template.format(
system_prompt=system_prompt,
user_prompt=examples['query'][i]
)
a_ids = tokenizer.encode(prompt)
b_ids = tokenizer.encode(f"{
examples['answer'][i]}", add_special_tokens=False) + [tokenizer.eos_token_id]
context_length = len(a_ids)
input_ids = a_ids + b_ids
if len(input_ids) > max_len and overflow_strategy == 'drop':
# 丢弃样本
input_ids = []
labels = []
else:
if max_len > len(input_ids):
"""
使用 -100 填充, 因为 torch.nn.CrossEntropyLoss 的 ignore_index=-100, 即 CrossEntropyLoss 会忽略标签为 -100 的值的 loss,只计算非填充部分的 loss
torch.nn.CrossEntropyLoss 官方文档:https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
"""
pad_length = max_len - len(input_ids)
labels = [PaddingID] * context_length + b_ids + [PaddingID] * pad_length
input_ids = input_ids + [tokenizer.pad_token_id] * pad_length