sync_agent.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. """
  2. 同步版本的React Agent - 解决Vector搜索异步冲突问题
  3. 基于原有CustomReactAgent,但使用完全同步的实现
  4. """
  5. import json
  6. import sys
  7. import os
  8. from pathlib import Path
  9. from typing import List, Optional, Dict, Any
  10. import redis
  11. # 添加项目根目录到sys.path
  12. try:
  13. project_root = Path(__file__).parent.parent
  14. if str(project_root) not in sys.path:
  15. sys.path.insert(0, str(project_root))
  16. except Exception as e:
  17. pass
  18. from core.logging import get_react_agent_logger
  19. from langchain_openai import ChatOpenAI
  20. from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
  21. from langgraph.graph import StateGraph, END
  22. from langgraph.prebuilt import ToolNode
  23. # 导入同步版本的依赖
  24. try:
  25. from . import config
  26. from .state import AgentState
  27. from .sql_tools import sql_tools
  28. except ImportError:
  29. import config
  30. from state import AgentState
  31. from sql_tools import sql_tools
  32. logger = get_react_agent_logger("SyncCustomReactAgent")
  33. class SyncCustomReactAgent:
  34. """
  35. 同步版本的React Agent
  36. 专门解决Vector搜索的异步事件循环冲突问题
  37. """
  38. def __init__(self):
  39. """私有构造函数,请使用 create() 类方法来创建实例。"""
  40. self.llm = None
  41. self.tools = None
  42. self.agent_executor = None
  43. self.checkpointer = None
  44. self.redis_client = None
  45. @classmethod
  46. def create(cls):
  47. """同步工厂方法,创建并初始化 SyncCustomReactAgent 实例。"""
  48. instance = cls()
  49. instance._sync_init()
  50. return instance
  51. def _sync_init(self):
  52. """同步初始化所有组件。"""
  53. logger.info("🚀 开始初始化 SyncCustomReactAgent...")
  54. # 1. 初始化同步Redis客户端(如果需要)
  55. try:
  56. self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
  57. self.redis_client.ping()
  58. logger.info(f" ✅ Redis连接成功: {config.REDIS_URL}")
  59. except Exception as e:
  60. logger.warning(f" ⚠️ Redis连接失败,将不使用checkpointer: {e}")
  61. self.redis_client = None
  62. # 2. 初始化 LLM(同步版本)
  63. self.llm = ChatOpenAI(
  64. api_key=config.QWEN_API_KEY,
  65. base_url=config.QWEN_BASE_URL,
  66. model=config.QWEN_MODEL,
  67. temperature=0.1,
  68. timeout=config.NETWORK_TIMEOUT,
  69. max_retries=0,
  70. streaming=False, # 关键:禁用流式处理
  71. extra_body={
  72. "enable_thinking": False, # 明确设置为False:非流式调用必须设为false
  73. "misc": {
  74. "ensure_ascii": False
  75. }
  76. }
  77. )
  78. logger.info(f" ✅ 同步LLM已初始化,模型: {config.QWEN_MODEL}")
  79. # 3. 绑定工具
  80. self.tools = sql_tools
  81. self.llm_with_tools = self.llm.bind_tools(self.tools)
  82. logger.info(f" ✅ 已绑定 {len(self.tools)} 个工具")
  83. # 4. 创建StateGraph(不使用checkpointer避免异步依赖)
  84. self.agent_executor = self._create_sync_graph()
  85. logger.info(" ✅ 同步StateGraph已创建")
  86. logger.info("✅ SyncCustomReactAgent 初始化完成")
  87. def _create_sync_graph(self):
  88. """创建同步的StateGraph"""
  89. graph = StateGraph(AgentState)
  90. # 添加同步节点
  91. graph.add_node("agent", self._sync_agent_node)
  92. graph.add_node("tools", ToolNode(self.tools))
  93. graph.add_node("prepare_tool_input", self._sync_prepare_tool_input_node)
  94. graph.add_node("update_state_after_tool", self._sync_update_state_after_tool_node)
  95. graph.add_node("format_final_response", self._sync_format_final_response_node)
  96. # 设置入口点
  97. graph.set_entry_point("agent")
  98. # 添加条件边
  99. graph.add_conditional_edges(
  100. "agent",
  101. self._sync_should_continue,
  102. {
  103. "tools": "prepare_tool_input",
  104. "end": "format_final_response"
  105. }
  106. )
  107. # 添加普通边
  108. graph.add_edge("prepare_tool_input", "tools")
  109. graph.add_edge("tools", "update_state_after_tool")
  110. graph.add_edge("update_state_after_tool", "agent")
  111. graph.add_edge("format_final_response", END)
  112. # 关键:使用同步编译,不传入checkpointer
  113. return graph.compile()
  114. def _sync_agent_node(self, state: AgentState) -> Dict[str, Any]:
  115. """同步Agent节点"""
  116. logger.info(f"🧠 [Sync Node] agent - Thread: {state.get('thread_id', 'unknown')}")
  117. messages_for_llm = state["messages"].copy()
  118. # 添加数据库范围提示词
  119. if isinstance(state["messages"][-1], HumanMessage):
  120. db_scope_prompt = self._get_database_scope_prompt()
  121. if db_scope_prompt:
  122. messages_for_llm.insert(0, SystemMessage(content=db_scope_prompt))
  123. logger.info(" ✅ 已添加数据库范围判断提示词")
  124. # 同步LLM调用
  125. response = self.llm_with_tools.invoke(messages_for_llm)
  126. return {"messages": [response]}
  127. def _sync_should_continue(self, state: AgentState):
  128. """同步条件判断"""
  129. messages = state["messages"]
  130. last_message = messages[-1]
  131. if not last_message.tool_calls:
  132. return "end"
  133. else:
  134. return "tools"
  135. def _sync_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
  136. """同步准备工具输入节点"""
  137. logger.info(f"🔧 [Sync Node] prepare_tool_input - Thread: {state.get('thread_id', 'unknown')}")
  138. last_message = state["messages"][-1]
  139. if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
  140. for tool_call in last_message.tool_calls:
  141. if tool_call.get('name') == 'generate_sql':
  142. # 注入历史消息
  143. history_messages = self._filter_and_format_history(state["messages"])
  144. if 'args' not in tool_call:
  145. tool_call['args'] = {}
  146. tool_call['args']['history_messages'] = history_messages
  147. logger.info(f" ✅ 为generate_sql注入了 {len(history_messages)} 条历史消息")
  148. return {"messages": [last_message]}
  149. def _sync_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
  150. """同步更新工具执行后的状态"""
  151. logger.info(f"📝 [Sync Node] update_state_after_tool - Thread: {state.get('thread_id', 'unknown')}")
  152. last_message = state["messages"][-1]
  153. tool_name = last_message.name
  154. tool_output = last_message.content
  155. next_step = None
  156. if tool_name == 'generate_sql':
  157. tool_output_lower = tool_output.lower()
  158. if "failed" in tool_output_lower or "无法生成" in tool_output_lower or "失败" in tool_output_lower:
  159. next_step = 'answer_with_common_sense'
  160. else:
  161. next_step = 'valid_sql'
  162. elif tool_name == 'valid_sql':
  163. if "失败" in tool_output:
  164. next_step = 'analyze_validation_error'
  165. else:
  166. next_step = 'run_sql'
  167. elif tool_name == 'run_sql':
  168. next_step = 'summarize_final_answer'
  169. logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
  170. return {"suggested_next_step": next_step}
  171. def _sync_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
  172. """同步格式化最终响应节点"""
  173. logger.info(f"📄 [Sync Node] format_final_response - Thread: {state.get('thread_id', 'unknown')}")
  174. messages = state["messages"]
  175. last_message = messages[-1]
  176. # 构建最终响应
  177. final_response = last_message.content
  178. logger.info(f" ✅ 最终响应已准备完成")
  179. return {"final_answer": final_response}
  180. def _filter_and_format_history(self, messages: list) -> list:
  181. """过滤和格式化历史消息"""
  182. clean_history = []
  183. for msg in messages[:-1]: # 排除最后一条消息
  184. if isinstance(msg, HumanMessage):
  185. clean_history.append({"type": "human", "content": msg.content})
  186. elif isinstance(msg, AIMessage):
  187. clean_content = msg.content if not hasattr(msg, 'tool_calls') or not msg.tool_calls else ""
  188. if clean_content.strip():
  189. clean_history.append({"type": "ai", "content": clean_content})
  190. return clean_history
  191. def _get_database_scope_prompt(self) -> str:
  192. """获取数据库范围判断提示词"""
  193. return """你是一个专门处理高速公路收费数据查询的AI助手。在回答用户问题时,请首先判断这个问题是否可以通过查询数据库来回答。
  194. 数据库包含以下类型的数据:
  195. - 服务区信息(名称、位置、档口数量等)
  196. - 收费站数据
  197. - 车流量统计
  198. - 业务数据分析
  199. 如果用户的问题与这些数据相关,请使用工具生成SQL查询。
  200. 如果问题与数据库内容无关(如常识性问题、天气、新闻等),请直接用你的知识回答,不要尝试生成SQL。"""
  201. def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
  202. """
  203. 同步聊天方法 - 关键:使用 graph.invoke() 而不是 ainvoke()
  204. """
  205. if thread_id is None:
  206. import uuid
  207. thread_id = str(uuid.uuid4())
  208. # 构建输入
  209. inputs = {
  210. "messages": [HumanMessage(content=message)],
  211. "user_id": user_id,
  212. "thread_id": thread_id,
  213. "suggested_next_step": None
  214. }
  215. # 构建运行配置(不使用checkpointer)
  216. run_config = {
  217. "recursion_limit": config.RECURSION_LIMIT,
  218. }
  219. try:
  220. logger.info(f"🚀 开始同步处理用户消息: {message[:50]}...")
  221. # 关键:使用同步的 invoke() 方法
  222. final_state = self.agent_executor.invoke(inputs, run_config)
  223. logger.info(f"🔍 Final state keys: {list(final_state.keys())}")
  224. # 提取答案
  225. if final_state["messages"]:
  226. answer = final_state["messages"][-1].content
  227. else:
  228. answer = "抱歉,无法处理您的请求。"
  229. # 提取SQL数据(如果有)
  230. sql_data = self._extract_latest_sql_data(final_state["messages"])
  231. logger.info(f"✅ 同步处理完成 - Final Answer: '{answer[:100]}...'")
  232. # 构建返回结果
  233. result = {
  234. "success": True,
  235. "answer": answer,
  236. "thread_id": thread_id
  237. }
  238. # 只有当存在SQL数据时才添加到返回结果中
  239. if sql_data:
  240. try:
  241. # 尝试解析SQL数据
  242. sql_parsed = json.loads(sql_data)
  243. # 检查数据格式:run_sql工具返回的是数组格式 [{"col1":"val1"}]
  244. if isinstance(sql_parsed, list):
  245. # 数组格式:直接作为records使用
  246. result["api_data"] = {
  247. "response": answer,
  248. "records": sql_parsed,
  249. "react_agent_meta": {
  250. "thread_id": thread_id,
  251. "agent_version": "sync_react_v1"
  252. }
  253. }
  254. elif isinstance(sql_parsed, dict):
  255. # 字典格式:按原逻辑处理
  256. result["api_data"] = {
  257. "response": answer,
  258. "sql": sql_parsed.get("sql", ""),
  259. "records": sql_parsed.get("records", []),
  260. "react_agent_meta": {
  261. "thread_id": thread_id,
  262. "agent_version": "sync_react_v1"
  263. }
  264. }
  265. else:
  266. logger.warning(f"SQL数据格式未知: {type(sql_parsed)}")
  267. raise ValueError("Unknown SQL data format")
  268. except (json.JSONDecodeError, AttributeError, ValueError) as e:
  269. logger.warning(f"SQL数据格式处理失败: {str(e)}, 跳过API数据构建")
  270. else:
  271. result["api_data"] = {
  272. "response": answer,
  273. "react_agent_meta": {
  274. "thread_id": thread_id,
  275. "agent_version": "sync_react_v1"
  276. }
  277. }
  278. return result
  279. except Exception as e:
  280. logger.error(f"❌ 同步处理失败: {str(e)}", exc_info=True)
  281. return {
  282. "success": False,
  283. "error": f"同步处理失败: {str(e)}",
  284. "thread_id": thread_id,
  285. "retry_suggested": True
  286. }
  287. def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
  288. """从消息历史中提取最近的run_sql执行结果(同步版本)"""
  289. logger.info("🔍 提取最新的SQL执行结果...")
  290. # 查找最后一个HumanMessage之后的SQL执行结果
  291. last_human_index = -1
  292. for i in range(len(messages) - 1, -1, -1):
  293. if isinstance(messages[i], HumanMessage):
  294. last_human_index = i
  295. break
  296. if last_human_index == -1:
  297. logger.info(" 未找到用户消息,跳过SQL数据提取")
  298. return None
  299. # 只在当前对话轮次中查找SQL结果
  300. current_conversation = messages[last_human_index:]
  301. logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")
  302. for msg in reversed(current_conversation):
  303. if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
  304. logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
  305. try:
  306. # 尝试解析JSON以验证格式
  307. parsed_data = json.loads(msg.content)
  308. # 重新序列化,确保中文字符正常显示
  309. formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
  310. logger.info(f" 已转换Unicode转义序列为中文字符")
  311. return formatted_content
  312. except json.JSONDecodeError:
  313. # 如果不是有效JSON,直接返回原内容
  314. logger.warning(f" SQL结果不是有效JSON格式,返回原始内容")
  315. return msg.content
  316. logger.info(" 当前对话轮次中未找到run_sql执行结果")
  317. return None