前言
之前发布的博客LangGraph开发Agent智能体应用【NL2SQL】-CSDN博客,留了一个问题,对于相对复杂的sql(leetcode中等难度的sql题),gpt4o就力不从心了。这篇文章来讲一下优化
什么是few-shot
使用这些少量的、调整后的样本对预训练模型进行微调
其实就是给LLM少量示例
关于few-shot的研究:
实现few-shot的方式
1.prompt:最简单的当然是在prompt上写几个例子,作为上下文,当LLM被问到类似的问题的时候,就会参照你的上下文中的例子。
2.RAG:如果你觉得上下文的token数量有限,不可能吧所有例子写在prompt中,可以通过RAG的形式,把各种场景的sql案例做成Wiki文档,通过嵌入模型转换成向量表示,存储在向量数据库中,用户提问的时候通过向量召回策略找到相应的知识作为上下文,同样也可以实现优化。
我们一般在测试环境用prompt优化,在生产环境用prompt+RAG的方式
如果在测试中,确认了prompt能实现优化,那在生成环境中只是对应的加了一层向量化召回操作而已,所以本文也只讲prompt优化的操作案例
PS:使用LangChain实现RAG,这篇文章中有完整代码:LangChain开发LLM应用【入门指南】_langchain 开发社区-CSDN博客
代码:用prompt实现few-shot优化
PS:下文代码,是对LangGraph开发Agent智能体应用【NL2SQL】-CSDN博客的改进优化,可能存在重复内容。
第一步:定义工具集合
LangChain 和 LangGraph是打通的(准确的说,LangGraph是LangChain生态的高级框架)
所以我们可以直接使用LangChain的工具集 SQLDatabaseToolkit
如果你愿意深入看看源码,就知道这个工具集里有四个工具:
执行sql:QuerySQLDataBaseTool
查看表详情:InfoSQLDatabaseTool
sql语法检查:QuerySQLCheckerTool
查看所有表:ListSQLDatabaseTool
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase
# 数据库连接信息
username = 'root'
password = 'password'
host = 'hostname'
port = '3306'
database = 'test'
engine = create_engine(f'mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}')
db = SQLDatabase(engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(temperature=0))
context = toolkit.get_context()
tools = toolkit.get_tools()
第二步:定义LLM节点,并加入到图中
让LLM绑定工具,一定要绑定,就像你需要告诉LLM,可以使用哪些工具,LLM才会生成调用计划
prompt优化内容:
1.首先告诉agent它的定位是一个SQL编码助手
2.按照 问题、思路、答案 给他相应提示(我试了很多方式,这种方式效果最好,问题部分要包含表的DDL最佳)
3.告诉agent,你希望的输出形式
格式如下:(当然还有优化空间,期待你自己尝试更丰富的提示语)
"system",
"""你是一名精通 SQL 的编码助理。\n
这是参考文档: \n ------- \n
# 如何找到各部门最高工资的员工
## 问题:
{...}
## 思路:
{...}
## 答案:
{...}
\n ------- \n
根据以上提供的文档,作为参考生成sql查询数据以回复用户的问题 \n
用中文回复,并且最后以表格形式输出。\n
以下是用户问题:""",
代码如下:
from typing import Annotated
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages,AnyMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.prompts import ChatPromptTemplate
class State(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
graph_builder = StateGraph(State)
# expt_llm = "gpt-4-1106-preview"
expt_llm = "gpt-4o"
llm = ChatOpenAI(temperature=0, model=expt_llm)
class Assistant:
def __init__(self, runnable: Runnable):
self.runnable = runnable
def __call__(self, state: State, config: RunnableConfig):
while True:
passenger_id = config.get("passenger_id", None)
state = {**state, "user_info": passenger_id}
result = self.runnable.invoke(state)
# If the LLM happens to return an empty response, we will re-prompt it
# for an actual response.
if not result.tool_calls and (
not result.content
or isinstance(result.content, list)
and not result.content[0].get("text")
):
messages = state["messages"] + [("user", "Respond with a real output.")]
state = {**state, "messages": messages}
else:
break