python import asyncio import os from openai import OpenAI from dotenv import load_dotenv from contextlib import AsyncExitStack from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client import json import aiohttp # 用于云端服务的 HTTP/WebSocket 连接 # 加载 .env 文件 load_dotenv() class MCPClient: def __init__(self): """初始化 MCP 客户端""" self.exit_stack = AsyncExitStack() self.api_key = os.getenv("API_KEY") # 读取 OpenAI API Key self.base_url = os.getenv("BASE_URL") # 读取 BASE URL self.model = os.getenv("MODEL") # 读取 model if not self.api_key: raise ValueError("未找到 API KEY. 请在 .env 文件中配置 API_KEY") self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) self.sessions = {} # 存储多个服务端会话 self.tools_map = {} # 工具映射:工具名称 -> 服务端 ID async def load_servers_from_config(self, config_path: str): """ 从配置文件加载 MCP 服务 :param config_path: 配置文件路径 """ with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) servers = config.get("mcpServers", {}) for server_id, server_config in servers.items(): service_type = server_config.get("type", "local") if service_type in ["sse", "websocket"]: # 云端服务 await self.connect_to_cloud_server(server_id, server_config["url"], service_type) elif service_type == "local": # 本地服务 command = server_config["command"] args = server_config.get("args", []) env = server_config.get("env", None) await self.connect_to_local_server(server_id, command, args, env) async def connect_to_cloud_server(self, server_id: str, url: str, service_type: str): """ 连接到云端 MCP 服务 :param server_id: 服务端标识符 :param url: 云端服务的 URL :param service_type: 服务类型(sse 或 websocket) """ if server_id in self.sessions: raise ValueError(f"服务端 {server_id} 已经连接") if service_type == "websocket": # 使用 aiohttp 建立 WebSocket 连接 session = aiohttp.ClientSession() ws = await session.ws_connect(url) self.sessions[server_id] = {"session": ws, "type": "cloud-websocket"} print(f"已连接到云端 WebSocket 服务: {server_id}") elif service_type == "sse": # 使用 aiohttp 建立 SSE 连接 session = aiohttp.ClientSession() response = await session.get(url) if response.status != 200: raise ValueError(f"无法连接到 SSE 服务 {server_id},状态码: {response.status}") self.sessions[server_id] = {"session": response, "type": "cloud-sse"} print(f"已连接到云端 SSE 服务: {server_id}") else: raise ValueError(f"未知的服务类型: {service_type}") async def connect_to_local_server(self, server_id: str, command: str, args: list, env: dict): """ 连接到本地 MCP 服务 :param server_id: 服务端标识符 :param command: 本地服务的启动命令 :param args: 启动命令的参数 :param env: 环境变量 """ if server_id in self.sessions: raise ValueError(f"服务端 {server_id} 已经连接") print(f"正在连接本地服务: {server_id}, 命令: {command}, 参数: {args}, 环境变量: {env}") server_params = StdioServerParameters(command=command, args=args, env=env) # 启动 MCP 服务器并建立通信 stdio_transport = await self.exit_stack.enter_async_context( stdio_client(server_params)) stdio, write = stdio_transport session = await self.exit_stack.enter_async_context( ClientSession(stdio, write)) await session.initialize() self.sessions[server_id] = {"session": session, "type": "local"} print(f"已连接到本地 MCP 服务: {server_id}") async def handle_sse_messages(self, server_id: str): """ 处理 SSE 消息流 :param server_id: 服务端标识符 """ session_info = self.sessions.get(server_id) if not session_info or session_info["type"] != "cloud-sse": raise ValueError(f"服务端 {server_id} 不是有效的 SSE 服务") response = session_info["session"] try: async for line in response.content: # SSE 消息以换行符分隔 if line: message = line.decode("utf-8").strip() print(f"[SSE Message from {server_id}]: {message}") except Exception as e: print(f"SSE 消息处理错误: {e}") async def list_tools(self): """列出所有服务端的工具""" if not self.sessions: print("没有已连接的服务端") return print("已连接的服务端工具列表:") for server_id, session_info in self.sessions.items(): if session_info["type"] == "local": session = session_info["session"] response = await session.list_tools() for tool in response.tools: self.tools_map[tool.name] = server_id print(f"工具: {tool.name}, 来源服务端: {server_id}") async def process_query(self, query: str) -> str: """ 调用大模型处理用户查询,并根据返回的 tools 列表调用对应工具。 支持多次工具调用,直到所有工具调用完成。 """ messages = [{"role": "user", "content": query}] # 构建统一的工具列表 available_tools = [] for tool_name, server_id in self.tools_map.items(): session = self.sessions[server_id]["session"] response = await session.list_tools() for tool in response.tools: if tool.name == tool_name: available_tools.append({ "type": "function", "function": { "name": tool.name, "description": tool.description, "input_schema": tool.inputSchema } }) # print('整合的服务端工具列表:', available_tools) # 循环处理工具调用 while True: # 请求 OpenAI 模型处理 response = self.client.chat.completions.create( model=self.model, messages=messages, tools=available_tools ) # 处理返回的内容 content = response.choices[0] if content.finish_reason == "tool_calls": # 执行工具调用 for tool_call in content.message.tool_calls: tool_name = tool_call.function.name tool_args = json.loads(tool_call.function.arguments) # 根据工具名称找到对应的服务端 server_id = self.tools_map.get(tool_name) if not server_id: raise ValueError(f"未找到工具 {tool_name} 对应的服务端") session = self.sessions[server_id]["session"] result = await session.call_tool(tool_name, tool_args) print(f"\n\n[Calling tool {tool_name} on server {server_id} with args {tool_args}]\n\n") print(f"[Tool {tool_name} Result]: {result.content[0].text}") # 将工具调用的结果添加到 messages 中 messages.append({ "role": "tool", "content": result.content[0].text, "tool_call_id": tool_call.id, }) else: # 如果没有工具调用,返回最终的回复 return content.message.content async def call_tool(self, server_id: str, tool_name: str, tool_args: dict): """ 调用指定服务端的工具 :param server_id: 服务端标识符 :param tool_name: 工具名称 :param tool_args: 工具参数 """ session_info = self.sessions.get(server_id) if not session_info: raise ValueError(f"服务端 {server_id} 未连接") session_type = session_info["type"] if session_type == "cloud-websocket": # WebSocket 工具调用 ws = session_info["session"] await ws.send_json({"tool": tool_name, "args": tool_args}) response = await ws.receive_json() return response elif session_type == "cloud-sse": # SSE 工具调用(通过 HTTP 请求触发) response = await session_info["session"].post( f"{session_info['url']}/{tool_name}", json=tool_args ) if response.status != 200: raise ValueError(f"SSE 工具调用失败,状态码: {response.status}") return await response.json() elif session_type == "local": # 本地工具调用 session = session_info["session"] return await session.call_tool(tool_name, tool_args) else: raise ValueError(f"未知的会话类型: {session_type}") async def chat_loop(self): """运行交互式聊天循环""" print("MCP 客户端已启动!输入 'exit' 退出") while True: try: query = input("问: ").strip() if query.lower() == 'exit': break response = await self.process_query(query) print(f"AI回复: {response}") except Exception as e: print(f"发生错误: {str(e)}") async def clean(self): """清理所有资源""" await self.exit_stack.aclose() self.sessions.clear() self.tools_map.clear() async def main(): # 启动并初始化 MCP 客户端 client = MCPClient() try: # 从配置文件加载 MCP 服务 await client.load_servers_from_config("mcp_servers.json") # 列出 MCP 服务器上的工具 await client.list_tools() # 运行交互式聊天循环,处理用户对话 await client.chat_loop() finally: # 清理资源 await client.clean() if __name__ == "__main__": asyncio.run(main()) |