1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
| import os import asyncio from typing import Optional from contextlib import AsyncExitStack import json import sys from urllib.parse import urlparse
from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.sse import sse_client
from openai import OpenAI from dotenv import load_dotenv
# Populate os.environ from a local .env file (python-dotenv) so that the
# os.getenv() lookups in MCPClient (OPENAI_API_KEY, OPENAI_BASE_URL) can see it.
load_dotenv()
class MCPClient:
    """Interactive client that bridges an OpenAI chat model and an MCP server.

    Every tool the connected MCP server advertises is offered to the model.
    Tool calls the model makes are executed through the MCP session, and their
    results are fed back to the model until it answers in plain text.
    """

    # Script launched by connect_to_server() when it is given no path.
    # The MCP_DEFAULT_SERVER_SCRIPT environment variable overrides it.
    DEFAULT_SERVER_SCRIPT = "/exports/git/baidu-map-mcp/src/baidu-map/python/map.py"

    def __init__(self):
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # Connection settings come from the environment.
        self.openai = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
        )
        # OPENAI_MODEL overrides the default chat model.
        self.model = os.getenv("OPENAI_MODEL", "gpt-4o")

    def get_response(self, messages: list, tools: list, max_tokens: int = 1000):
        """Request one chat completion from the model.

        Args:
            messages: Conversation so far, in OpenAI chat-message format.
            tools: Tool definitions (see get_tools()). Left out of the request
                entirely when empty, because the API rejects an empty
                ``tools`` array.
            max_tokens: Completion token budget.

        Returns:
            The raw chat-completion response object.
        """
        # NOTE: self.openai is the synchronous client, so this call blocks the
        # event loop while the request is in flight.
        request = {"model": self.model, "max_tokens": max_tokens, "messages": messages}
        if tools:
            request["tools"] = tools
        return self.openai.chat.completions.create(**request)

    def _require_session(self):
        """Return the active MCP session, or raise if connect_to_server() was not called."""
        if self.session is None:
            raise RuntimeError(
                "Not connected to an MCP server; call connect_to_server() first."
            )
        return self.session

    async def get_tools(self):
        """List the server's tools in OpenAI function-calling format.

        Raises:
            RuntimeError: if no session is connected.
        """
        response = await self._require_session().list_tools()
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    # The API expects a string here; MCP allows a missing description.
                    "description": tool.description or "",
                    "parameters": tool.inputSchema,
                },
            }
            for tool in response.tools
        ]

    def _stdio_server_params(self, server_path: Optional[str]):
        """Build StdioServerParameters for a .py/.js script, or for the default server when None."""
        if server_path:
            is_python = server_path.endswith(".py")
            is_js = server_path.endswith(".js")
            if not (is_python or is_js):
                raise ValueError("服务器脚本必须是 .py 或 .js 文件")
            command = "python" if is_python else "node"
            print(f"正在启动服务器: {command} {server_path}")
            return StdioServerParameters(command=command, args=[server_path], env=None)

        print("正在启动默认 MCP 服务器…")
        script = os.getenv("MCP_DEFAULT_SERVER_SCRIPT", self.DEFAULT_SERVER_SCRIPT)
        return StdioServerParameters(
            command="uv",
            args=["run", "--with", "mcp[cli]", "mcp", "run", script],
            # "xxx" is only a placeholder fallback; set BAIDU_MAPS_API_KEY for real use.
            # NOTE(review): depending on the mcp version, a non-None env may replace
            # rather than extend the inherited environment. Confirm PATH reaches uv.
            env={"BAIDU_MAPS_API_KEY": os.getenv("BAIDU_MAPS_API_KEY", "xxx")},
        )

    async def connect_to_server(self, server_path: Optional[str] = None):
        """Connect to an MCP server and open a client session.

        Args:
            server_path: One of:
                1. An HTTP(S) URL: connect with the SSE client.
                2. A server script path ending in ``.py`` or ``.js``: launched
                   over stdio with ``python`` or ``node``.
                3. None: launch the default server (DEFAULT_SERVER_SCRIPT or
                   $MCP_DEFAULT_SERVER_SCRIPT) through ``uv run ... mcp run``.

        Raises:
            ValueError: if server_path is neither a URL nor a .py/.js file.
            Any other connection error is printed and then re-raised.
        """
        try:
            if server_path and urlparse(server_path).scheme in ("http", "https"):
                print(f"正在连接到 SSE 服务器: {server_path}")
                self.stdio, self.write = await self.exit_stack.enter_async_context(
                    sse_client(server_path)
                )
            else:
                server_params = self._stdio_server_params(server_path)
                self.stdio, self.write = await self.exit_stack.enter_async_context(
                    stdio_client(server_params)
                )

            self.session = await self.exit_stack.enter_async_context(
                ClientSession(self.stdio, self.write)
            )
            await self.session.initialize()

            response = await self.session.list_tools()
            print("\n连接到服务器,工具列表:", [tool.name for tool in response.tools])
            print("服务器初始化完成")
        except Exception as e:
            print(f"连接服务器时出错: {str(e)}")
            raise

    @staticmethod
    def _tool_result_text(result) -> str:
        """Flatten an MCP tool result into plain text for an OpenAI ``tool`` message."""
        parts = []
        for item in getattr(result, "content", None) or []:
            text = getattr(item, "text", None)
            # Non-text items (images, resources) fall back to their repr.
            parts.append(text if isinstance(text, str) else str(item))
        body = "\n".join(parts)
        if getattr(result, "isError", False):
            body = f"[tool error] {body}"
        return body

    async def process_query(self, query: str, max_tool_rounds: int = 5) -> str:
        """Answer one query with the model, executing any MCP tool calls it makes.

        Each query starts a fresh conversation. In each round the model may
        request tools. Every requested call is executed through the MCP session,
        and its result goes back to the model as a ``tool`` message. This repeats
        until the model replies without tool calls, or until ``max_tool_rounds``
        rounds of tool execution have been used.

        Returns:
            The model's text plus one ``[Calling tool ...]`` marker per executed
            tool call, joined by newlines.
        """
        messages = [{"role": "user", "content": query}]
        available_tools = await self.get_tools()
        final_text = []

        for round_no in range(max_tool_rounds + 1):
            message = self.get_response(messages, available_tools).choices[0].message
            if message.content:
                final_text.append(message.content)
            if not message.tool_calls:
                break
            if round_no == max_tool_rounds:
                final_text.append(f"[Stopped after {max_tool_rounds} rounds of tool calls]")
                break

            # The assistant turn that issued the calls must precede the
            # matching `tool` messages in the history.
            messages.append(
                {
                    "role": "assistant",
                    "content": message.content,
                    "tool_calls": [
                        {
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.function.name,
                                "arguments": call.function.arguments,
                            },
                        }
                        for call in message.tool_calls
                    ],
                }
            )
            for call in message.tool_calls:
                tool_name = call.function.name
                try:
                    tool_args = json.loads(call.function.arguments or "{}")
                except json.JSONDecodeError as exc:
                    # Report malformed arguments to the model instead of
                    # aborting the whole query.
                    content = f"Invalid JSON arguments for {tool_name}: {exc}"
                else:
                    print(f"准备调用工具: {tool_name}")
                    print(f"参数: {json.dumps(tool_args, ensure_ascii=False, indent=2)}")
                    result = await self.session.call_tool(tool_name, tool_args)
                    final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")
                    content = self._tool_result_text(result)
                # Every tool_call_id needs a matching `tool` reply.
                messages.append({"role": "tool", "tool_call_id": call.id, "content": content})

        return "\n".join(final_text)

    async def chat_loop(self):
        """Run an interactive chat loop. Queries do not share memory."""
        print("\nMCP Client 启动!")
        print("输入您的查询或 'quit' 退出.")

        while True:
            try:
                query = input("\nQuery: ").strip()
            except EOFError:
                # stdin is closed (Ctrl-D or end of piped input), and every
                # further input() call would fail the same way.
                break

            if query.lower() == "quit":
                break
            if not query:
                continue

            try:
                response = await self.process_query(query)
            except Exception as e:
                import traceback

                traceback.print_exc()
                print(f"\n错误: {str(e)}")
                continue
            print("\n" + response)

    async def cleanup(self):
        """Close every context entered on the exit stack and drop the stale session."""
        try:
            await self.exit_stack.aclose()
        finally:
            self.session = None
async def main():
    """Entry point: build an MCP client, connect it, then run the chat loop.

    The first command-line argument selects one of three modes:
        python client.py <url>                    # connect over SSE
        python client.py <path_to_server_script>  # launch a custom server script
        python client.py                          # launch the default server
    """
    extra_args = sys.argv[1:]
    target = extra_args[0] if extra_args else None

    client = MCPClient()
    try:
        await client.connect_to_server(target)
        await client.chat_loop()
    finally:
        # Always release the transports and session, even after an error.
        await client.cleanup()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run surfaces Ctrl-C as KeyboardInterrupt. Exit with the
        # conventional 130 status instead of printing a traceback.
        sys.exit(130)
|