LLM - Qwen functioncalling 预测与训练

微信公众号:NLP万事通

环境准备

  • 本文使用cuda11.8,使用的python包如下
transformers==4.39.0
torch==2.0.1
deepspeed==0.14.4

主要代码

  1. 模型加载
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
device = "cuda" 
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
    "Qwen2-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen2-7B-Instruct")

2.封装调用qwen2的方法,支持输入user_prompt 或 ReAct模版

def get_model_result(prompt, messages=None):
    if 'im' not in prompt:
        if messages is None:
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        text = prompt
    print("text", text)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    

    generated_ids1 = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids1)
    ]
   

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

tokenizer.apply_chat_template 方法可以直接生成模版

<|im_start|>system 
You are a helpful assistant.<|im_end|> 
<|im_start|>user 
介绍下中国<|im_end|> 
<|im_start|>assistant

想生成ReAct模版,需要使用openai的接口传入Function,由于vllm部署原因,本文我们自己构建

2.1 测试

response = get_model_result("介绍下中国")

response

中国,全称为中华人民共和国,是一个位于亚洲东部、太平洋西岸的国家。以下是关于中国的几个重要方面:

1. **地理**:中国领土面积约为960万平方公里,仅次于俄罗斯和加拿大,是世界上第三大国家。它东临太平洋,与日本、韩国、朝鲜等国隔海相望;北与蒙古接壤,东北与俄罗斯相邻;西北部与哈萨克斯坦、吉尔吉斯斯坦、塔吉克斯坦、阿富汗、巴基斯坦等国为邻;西南部与印度、尼泊尔、不丹、缅甸、老挝、越南等国相邻。中国地形复杂多样,包括高原、平原、山地、丘陵和盆地等多种地貌。

2. **人口**:截至2023年,中国的人口总数超过14亿,是世界上人口最多的国家。人口分布不均,主要集中在东部沿海地区。

3. **文化**:中国文化源远流长,拥有超过5000年的历史,深受儒家、道家、佛教等哲学思想的影响。中国有着丰富的文化遗产,包括长城、故宫、兵马俑等世界著名的历史遗迹。此外,中国的文学、艺术、音乐、书法等领域也成就显著,对中国乃至世界文化产生了深远影响。

4. **经济**:改革开放以来,中国经济实现了快速增长,成为全球第二大经济体。中国是世界上最大的制造业国家之一,同时也是农产品生产大国。近年来,服务业在中国经济中的比重持续上升,成为经济增长的重要驱动力。中国还是世界上最大的出口国之一,并在科技创新、基础设施建设等方面取得了显著成就。

5. **政治**:中国实行社会主义制度,由共产党领导的多党合作和政治协商制度是中国特色的政治体制。中华人民共和国主席、国务院总理等高级官员由全国人民代表大会选举产生。中国坚持走和平发展道路,致力于维护世界和平、促进共同发展。

6. **社会**:中国在教育、医疗、科技、环保等领域不断进步,同时也在努力解决贫困问题,推动城乡一体化发展,提高人民生活水平。随着城镇化进程加快,城市化进程显著,但同时也面临诸如环境污染、交通拥堵等问题。

7. **国际地位**:作为联合国安理会常任理事国之一,中国在国际事务中发挥着重要作用。中国积极参与全球治理,推动构建人类命运共同体,倡导多边主义,促进国际合作与发展。

3.定义两个function,查询和操作数据库

"""
用JSON格式模拟数据库
"""
class CourseDatabase:
    def __init__(self):
        self.database = {
            "大模型技术实战":{
                "课时": 200,
                "每周更新次数": 3,
                "每次更新小时": 2
            },
             "机器学习实战":{
                "课时": 230,
                "每周更新次数": 2,
                "每次更新小时": 1.5
            },
            "深度学习实战":{
                "课时": 150,
                "每周更新次数": 1,
                "每次更新小时": 3
            },
            "AI数据分析":{
                "课时": 10,
                "每周更新次数": 1,
                "每次更新小时": 1
            },
        }
    def course_query(self, course_name):
        return self.database.get(course_name, "目前没有该课程信息")

"""
定义数据库操作工具
"""
class CourseOperations:
    def __init__(self):
        self.db = CourseDatabase()
 
    def add_hours_to_course(self, param):
        course_name = param.get("course_name", "None")
        additional_hours = int(param.get("additional_hours", "None"))
        if course_name in self.db.database:
            self.db.database[course_name]['课时'] += additional_hours
            return f"课程 {course_name} 的课时已增加{additional_hours}小时。"
        else:
            return "课程不存在,无法添加课时"

3.1 function测试

course_ops = CourseOperations()
# 给某个课程增加课时
course_insert_test = course_ops.add_hours_to_course({"course_name": "大模型技术实战", "additional_hours":20 })

course_insert_test

课程 大模型技术实战 的课时已增加20小时。

3.2 定义大模型需要调用的工具库

  • 这是Qwen风格的function格式,需要严格按照此格式,否则后续容易报错
TOOLS = [
    {
        'name_for_human': '课程信息数据库',
        'name_for_model': 'CourseDatabase',
        'description_for_model': '课程信息数据库存储有各课程的详细信息,包括目前的上线课时,每周更新次数以及每次更新的小时数。通过输入课程名称,可以返回该课程的详细信息。',
        'parameters': [{
            'name': 'course_query',
            'description': '课程名称,所需查询信息的课程名称',
            'required': True,
            'schema': {
                'type': 'string'
            },
        }],
    },
        {
        'name_for_human': '课程操作工具',
        'name_for_model': 'CourseOperations',
        'description_for_model': '课程操作工具提供了对课程信息的添加操作,可以添加课程的详细信息,如每周更新次数,更新课时',
        'parameters': [{
            'name': 'add_hours_to_course',
            'description': '给指定的课程增加课时,需要课程名称和增加的课时数',
            'required': True,
            'schema': {
                'type': 'string',
                'properties': {
                    'course_name': {'type': 'string'},
                    'additional_hours': {'type': 'string'}
                },
                'required': ['course_name', 'additional_hours']
            },
        }],
    },
    # 其他工具的定义可以在这里继续添加
] 

3.3 将一个插件的关键信息拼接成一段文本的模板

TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters:{parameters}
"""
 
PROMPT_REACT = """
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the following questions as best you con. You have access to the following
{tool_descs}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: {query}<|im_end|>
<|im_start|>assistant
"""

4 生成function描述的方法

import json
def generate_action_prompt(query):
    """
    根据用户查询生成最终的动作提示字符串。
    函数内部直接引用全局变量 TOOLS, TOOL_DESC, 和 PROMPT_REACT.
    参数:
    - query: 用户的查询字符串。
    返回:
    - action_prompt: 格式化后的动作提示字符串。
    """
 
    tool_descs = []
    tool_names = []
 
    for info in TOOLS:
        tool_descs.append(
            TOOL_DESC.format(
                name_for_model = info['name_for_model'],
                name_for_human = info['name_for_human'],
                description_for_model = info['description_for_model'],
                parameters = json.dumps(info['parameters'], ensure_ascii=False),
            )
        )
        tool_names.append(info['name_for_model'])
 
    tool_descs_str = '\\n\\n'.join(tool_descs)
    tool_names_str = ','.join(tool_names)
 
    action_prompt = PROMPT_REACT.format(tool_descs=tool_descs_str, tool_names=tool_names_str, query=query)
    return action_prompt

4.1 测试function描述生成

query = "先帮我查询一下大模型技术实战这个课程目前更新了多少节,今晚我直播了一节新课,请你帮我更新一下"
prompt = generate_action_prompt(query)

prompt

"""
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the following questions as best you con. You have access to the following
CourseDatabase: Call this tool to interact with the 课程信息数据库 API. What is the 课程信息数据库 API useful for? 课程信息数据库存储有各课程的详细信息,包括目前的上线课时,每周更新次数以及每次更新的小时数。通过输入课程名称,可以返回该课程的详细信息。 Parameters:[{"name": "course_query", "description": "课程名称,所需查询信息的课程名称", "required": true, "schema": {"type": "string"}}]

CourseOperations: Call this tool to interact with the 课程操作工具 API. What is the 课程操作工具 API useful for? 课程操作工具提供了对课程信息的添加操作,可以添加课程的详细信息,如每周更新次数,更新课时 Parameters:[{"name": "add_hours_to_course", "description": "给指定的课程增加课时,需要课程名称和增加的课时数", "required": true, "schema": {"type": "string", "properties": {"course_name": {"type": "string"}, "additional_hours": {"type": "string"}}, "required": ["course_name", "additional_hours"]}}]

Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [CourseDatabase,CourseOperations]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: 先帮我查询一下大模型技术实战这个课程目前更新了多少节,今晚我直播了一节新课,请你帮我更新一下<|im_end|>
<|im_start|>assistant
"""

5 todo: 插入stopwords,Qwen会在Observation之后出现幻觉,希望在Observation处停止,去调用call。

  • 用vllm调用可以传入stop_words,但是cuda11.8没有安装成功,这里直接去掉Observation后的内容
custom_stop_words = ["Observation:", "Observation:\\n", "..."]
stop_words_ids = [tokenizer.add_tokens([word]) for word in custom_stop_words]
encoded_word = [tokenizer.encode(word, add_special_tokens=False) for word in custom_stop_words]
encoded_stop_words_ids = sum(encoded_word, [])

encoded_stop_words_ids

[151646, 151647, 1112]

6 加入ReAct的prompt请求Qwen模型

response = get_model_result(prompt)

response

Thought: 需要使用CourseDatabase API查询课程信息,并使用CourseOperations API更新课程信息。
Action: CourseDatabase
Action Input: {"course_query": "大模型技术实战"}
Observation:

6.1 解析模型的ReAct输出文本提取名称及其参数。

def parse_plugin_action(text: str):
    """
    解析模型的ReAct输出文本提取名称及其参数。
    参数:
    - text: 模型ReAct提示的输出文本
    返回值:
    - action_name: 要调用的动作(方法)名称。
    - action_arguments: 动作(方法)的参数。
    """
    # 查找“Action:”和“Action Input:”的最后出现位置
    action_index = text.find('\\nAction:')
    action_input_index = text.find('\\nAction Input:')
    observation_index = text.find('\\nObservation:')
 
    # 如果文本中有“Action:”和“Action Input:”
    if 0 <= action_index < action_input_index:
        if observation_index < action_input_index:
            text = text.rstrip() + '\\nObservation:'
            observation_index = text.rfind('\\nObservation:')
 
    # 确保文本中同时存在“Action:”和“Action Input:”
    if 0 <= action_index < action_input_index < observation_index:
        # 提取“Action:”和“Action Input:”之间的文本为动作名称
        action_name = text[action_index + len('\\nAction:'):action_input_index].strip()
        # 提取“Action Input:”之后的文本为动作参数
        action_arguments = text[action_input_index + len('\\nAction Input:'):observation_index].strip()
        return action_name, action_arguments
 
    # 如果没有找到符合条件的文本,返回空字符串
    return '', ''

6.2 根据模型的ReAct输出执行相应的插件调用,并返回调用结果。

import json
import ipdb
def execute_plugin_from_react_output(response):
    """
    根据模型的ReAct输出执行相应的插件调用,并返回调用结果。
    参数:
    - response: 模型的ReAct输出字符串。
    返回:
    - result_dict: 包括状态码和插件调用结果的字典。
    """
    # 从模型的ReAct输出中提取函数名称及函数入参
    # ipdb.set_trace()
    plugin_configuration = parse_plugin_action(response)
    first_config_line = plugin_configuration[1:][0].split('\\n')[0]
    config_parameters = json.loads(first_config_line)
    result_dict = {"status_code": 200}
 
    for k, v in config_parameters.items():
        for TOOL in TOOLS:
            if k in TOOL["parameters"][0]['name']:
                # 通过eval函数执行存储在字符串中的python表达式,并返回表达式计算结果。其执行过程实质上是实例化类
                tool_instance = eval(TOOL["name_for_model"])()
                # 然后通过getattr函数传递对象和字符串形式的属性或方法名来动态的访问该属性和方法h
                tool_func = getattr(tool_instance, k)
                # 这一步实际上执行的过程就是:course_db,course_query('大模型技术实战')
                tool_result = tool_func(v)
                result_dict["result"] = tool_result
                return result_dict
 
    result_dict["status_code"] = 404
    result_dict["result"] = "未找到匹配的插件配置"
    return result_dict

-------------------------------------------------
测试:
tool_result = execute_plugin_from_react_output(response)
print(tool_result)

output:
{'status_code': 200, 'result': {'课时': 200, '每周更新次数': 3, '每次更新小时': 2}}

6.2.1 调用function测试

tool_result = execute_plugin_from_react_output(response)

tool_result

{'status_code': 200, 'result': {'课时': 200, '每周更新次数': 3, '每次更新小时': 2}}

6.3 生成新的prompt

response = response + " " + str(tool_result)
prompt = prompt + "\\n" + response

prompt

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Answer the following questions as best you con. You have access to the following
CourseDatabase: Call this tool to interact with the 课程信息数据库 API. What is the 课程信息数据库 API useful for? 课程信息数据库存储有各课程的详细信息,包括目前的上线课时,每周更新次数以及每次更新的小时数。通过输入课程名称,可以返回该课程的详细信息。 Parameters:[{"name": "course_query", "description": "课程名称,所需查询信息的课程名称", "required": true, "schema": {"type": "string"}}]

CourseOperations: Call this tool to interact with the 课程操作工具 API. What is the 课程操作工具 API useful for? 课程操作工具提供了对课程信息的添加操作,可以添加课程的详细信息,如每周更新次数,更新课时 Parameters:[{"name": "add_hours_to_course", "description": "给指定的课程增加课时,需要课程名称和增加的课时数", "required": true, "schema": {"type": "string", "properties": {"course_name": {"type": "string"}, "additional_hours": {"type": "string"}}, "required": ["course_name", "additional_hours"]}}]

Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [CourseDatabase,CourseOperations]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: 先帮我查询一下大模型技术实战这个课程目前更新了多少节,今晚我直播了一节新课,请你帮我更新一下<|im_end|>
<|im_start|>assistant

Thought: 需要使用 CourseDatabase API 查询课程信息,然后使用 CourseOperations API 更新课程信息。
Action: CourseDatabase
Action Input: {"course_query": "大模型技术实战"}
Observation: {'status_code': 200, 'result': {'课时': 200, '每周更新次数': 3, '每次更新小时': 2}}

6.4 再次请求Qwen

response = get_model_result(prompt)

response

Thought: 课程当前更新了200节,每周更新3次,每次更新2小时。
Action: CourseOperations
Action Input: {"add_hours_to_course": {"course_name": "大模型技术实战", "additional_hours": "1"}}
Observation: 

6.5 请求function

tool_result = execute_plugin_from_react_output(response)

tool_result

{'status_code': 200, 'result': '课程 大模型技术实战的课时已增加1小时。'

6.6 再次请求Qwen

response = response + " " + str(tool_result)
prompt = prompt + "\\n" + response
response = get_model_result(prompt)

response

Thought: 已成功更新课程“大模型技术实战”的课时。
Final Answer: 大模型技术实战这个课程目前更新了201节。您今晚直播的新课已经成功更新到课程中。

最后将Final Answer:后面的内容提取出来回复给用户作为输出

结论

      至此,我们已经知道了ReAct的构造方式及function的调度流程。Function的训练数据也同样遵循ReAct模版的构建方式,至于,和chatML模版的配比需要小伙伴自己尝试。

<think>我们正在讨论如何微调 deepseek-r1-distill-qwen模型。根据引用[1]和引用[2],我们可以使用LLaMA-Factory来微调这个模型,并且可以使用vllm进行推理。微调完成后,我们需要合并LoRA权重才能用vllm进行推理。步骤概述: 1.准备环境:安装必要的库,包括LLaMA-Factory和vllm2.准备数据集:根据任务需求准备训练和验证数据集。 3.配置微调参数:使用LLaMA-Factory提供的配置模板,设置模型路径、数据集路径、训练参数等。 4.开始微调:运行训练脚本。 5.合并权重:微调完成后,使用提供的脚本合并LoRA权重到基座模型中。6.使用vllm进行推理:加载合并后的模型进行推理服务。 详细步骤:###1.环境准备首先,确保已经安装了LLaMA-Factory和vllm。如果没有,可以通过以下命令安装:```bash#安装LLaMA-Factorygit clonehttps://github.com/hiyouga/LLaMA-Factory.gitcd LLaMA-Factorypipinstall -e .[metrics] #安装vllmpip installvllm``` ###2.准备数据集LLaMA-Factory支持多种数据集格式,如alpaca、json等。假设我们有一个名为`dataset.json`的文件,格式如下(以alpaca格式为例):```json[{"instruction":"解释深度学习。","input": "","output":"深度学习是机器学习的一个分支..."},...] ```将数据集放在合适的目录下,例如`data/dataset.json`。###3.配置微调参数LLaMA-Factory提供了训练脚本`src/train_bash.py`。我们可以通过命令行参数或配置文件来设置训练参数。以下是一个示例命令:```bashCUDA_VISIBLE_DEVICES=0 pythonsrc/train_bash.py\--stagesft \--do_train \--model_name_or_path/workspace/luoshiyong/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \--dataset dataset\--templatedefault \--finetuning_typelora\--lora_targetq_proj,v_proj\--output_dir saves/deepseek/lora/sft \--overwrite_cache \--per_device_train_batch_size4\--gradient_accumulation_steps4\--lr_scheduler_type cosine\--logging_steps10\--save_steps1000 \--learning_rate5e-5 \--num_train_epochs3.0 \--plot_loss\--fp16``` 参数说明:- `--stagesft`:使用监督微调。- `--model_name_or_path`:基座模型路径。- `--dataset`:数据集名称(对应`data`目录下的文件名,不带后缀)。 -`--template`:模板,这里使用default。 -`--finetuning_typelora`:使用LoRA进行微调。 -`--lora_target`:指定要应用LoRA的模块,这里以QProj和VProj为例(具体目标模块需根据模型结构确定,对于DeepSeek-R1-Distill-Qwen-7B,可能需要参考其模型结构)。- `--output_dir`:输出目录,用于保存微调后的LoRA权重。 -其他训练参数如批次大小、学习率等可根据实际情况调整。注意:`lora_target`的设置需要根据模型结构来确定。对于DeepSeek-R1-Distill-Qwen-7B,我们可以查看其模型结构或参考类似模型(如Qwen)的设置。常见的设置是`q_proj, v_proj`,但具体可能需要尝试或查阅文档。 ###4.开始微调运行上述命令开始训练训练过程中会输出日志,包括损失值等。 ###5.合并权重训练完成后,LoRA权重保存在`output_dir`指定的目录中。为了使用vllm进行推理,我们需要将LoRA权重合并到基座模型中。使用引用[2]中提供的脚本: ```pythonfrom peft importPeftModelfrom transformersimport AutoModelForCausalLM#加载基座模型base_model =AutoModelForCausalLM.from_pretrained("/workspace/luoshiyong/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")#加载LoRA适配器并合并merged_model= PeftModel.from_pretrained(base_model, "/workspace/luoshiyong/LLaMA-Factory/saves/deepseek/lora/sft") merged_model =merged_model.merge_and_unload()#保存合并后的模型merged_model.save_pretrained("/workspace/luoshiyong/deepseek-ai/lora_sft") ```注意:合并后的模型将保存为一个完整的模型,可以直接加载。 ###6.使用vllm进行推理合并后的模型可以使用vllm进行推理服务。启动服务命令如下: ```bashpython -m vllm.entrypoints.openai.api_server \--model /workspace/luoshiyong/deepseek-ai/lora_sft\--tensor-parallel-size2\--max-model-len32768``` 参数说明:- `--model`:合并后的模型路径。 -`--tensor-parallel-size`:张量并行数,根据GPU数量设置。- `--max-model-len`:最大模型长度。然后就可以通过OpenAIAPI的格式进行调用。 ###注意事项-微调需要大量计算资源,确保有足够的GPU内存。 -在微调前,建议先在小批量数据上测试,确保配置正确。-对于`lora_target`的设置,如果不确定,可以尝试使用`all`(如果LLaMA-Factory支持)或者参考类似模型的设置。 以上步骤提供了一个完整的微调流程。根据具体任务,可能需要调整训练参数和数据集。§§相关问题§§1.如何为DeepSeek-R1-Distill-Qwen-7B模型准备数据集? 2.微调DeepSeek-R1-Distill-Qwen-7B模型时,如何选择合适的学习率和批次大小?3.除了LoRA,还有哪些微调方法适用于DeepSeek-R1-Distill-Qwen-7B模型? 4.如何评估微调后的DeepSeek-R1-Distill-Qwen-7B模型性能?</think>### 微调 DeepSeek-R1-Distill-Qwen-7B 模型指南 #### 1. 环境准备 - 安装核心库: ```bash pip install vllm transformers peft accelerate datasets ``` - 确保 GPU 环境(建议 NVIDIA A100 80GB 或更高) - 准备模型权重: ```python from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") ``` #### 2. 数据准备(示例) ```python from datasets import load_dataset dataset = load_dataset("json", data_files="your_data.json") # 自定义数据集格式 dataset = dataset.map(lambda x: {"text": f"指令: {x['instruction']}\n响应: {x['output']}"}) ``` #### 3. LoRA 微调(使用 LLaMA-Factory) ```bash # 克隆仓库 git clone https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory ``` ```bash # 启动微调 (单卡示例) CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --stage sft \ --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \ --do_train \ --dataset your_dataset \ --template default \ --finetuning_type lora \ --lora_target q_proj,v_proj \ --output_dir saves/lora_sft \ --per_device_train_batch_size 4 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --logging_steps 10 \ --save_steps 1000 \ --learning_rate 5e-5 \ --num_train_epochs 3.0 \ --fp16 ``` **关键参数说明**: - `lora_target`:指定适配层(根据模型结构调整) - `per_device_train_batch_size`:根据 GPU 显存调整 - `learning_rate`:推荐 1e-5 到 5e-5 #### 4. 权重合并(微调后必需步骤) ```python from peft import PeftModel from transformers import AutoModelForCausalLM base_model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") merged_model = PeftModel.from_pretrained(base_model, "saves/lora_sft") merged_model = merged_model.merge_and_unload() # 合并LoRA权重 merged_model.save_pretrained("merged_model") # 保存完整模型[^2] ``` #### 5. 使用 vLLM 部署 ```bash python -m vllm.entrypoints.openai.api_server \ --model merged_model \ --tensor-parallel-size 2 \ --max-model-len 32768 \ --enforce-eager ``` **部署参数**: - `tensor-parallel-size`:GPU 并行数量 - `max-model-len`:最大上下文长度(需匹配训练长度) - `enforce-eager`:优化内存使用[^1] #### 6. 验证微调效果 ```python from vllm import LLM llm = LLM("merged_model") outputs = llm.generate(["你的测试指令"]) print(outputs[0].outputs[0].text) ``` ### 注意事项 1. **显存优化**: - 使用 `fp16` 或 `bf16` 精度 - 启用梯度检查点:`--gradient_checkpointing` 2. **数据格式**: - 确保指令-响应对格式统一 - 推荐使用 Alpaca 格式 3. **灾难性遗忘**: - 添加 10-20% 原始预训练数据 - 控制学习率不超过 5e-5 > 微调效果取决于数据质量和超参数调整,建议先用 5% 数据做测试训练[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值