# Source listing: agent.py.backup (12 KB)

  """
  Core implementation of a context-aware ReAct agent built on a StateGraph.
  """
  import logging
  import json
  import pandas as pd
  from typing import List, Optional, Dict, Any, Tuple
  from contextlib import AsyncExitStack

  from langchain_openai import ChatOpenAI
  from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage
  from langgraph.graph import StateGraph, END
  from langgraph.prebuilt import ToolNode
  from redis.asyncio import Redis

  # The Redis checkpointer is an optional dependency: when the import fails the
  # name is bound to None so later code can test `AsyncRedisSaver is not None`.
  try:
      from langgraph.checkpoint.redis import AsyncRedisSaver
  except ImportError:
      AsyncRedisSaver = None

  # Configuration, state schema and tools are imported from sibling modules.
  from . import config
  from .state import AgentState
  from .sql_tools import sql_tools
  from langchain_core.runnables import RunnablePassthrough

  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)
  class CustomReactAgent:
      """
      Context-aware agent built on a StateGraph, with optional persistence.

      Graph wiring (see ``_create_graph``)::

          agent --"continue"--> prepare_tool_input --> tools
                --> update_state_after_tool --> agent
          agent --"end"--> format_final_response --> END

      Create instances with ``await CustomReactAgent.create()``; the plain
      constructor only sets placeholder attributes.
      """

      def __init__(self):
          """Set placeholder attributes; use ``create()`` to get a usable instance."""
          self.llm = None
          self.llm_with_tools = None
          self.tools = None
          self.agent_executor = None
          self.checkpointer = None
          self._exit_stack = None

      @classmethod
      async def create(cls):
          """Asynchronous factory: build an instance and initialise its components."""
          instance = cls()
          await instance._async_init()
          return instance

      async def _async_init(self):
          """Initialise the LLM, the tools, the optional Redis checkpointer and the graph.

          A failure while setting up the Redis checkpointer is logged and the
          agent continues without persistence.
          """
          logger.info("🚀 开始初始化 CustomReactAgent...")

          # 1. LLM client.
          # NOTE(review): some langchain-openai releases expect `extra_body` as a
          # direct argument instead of inside `model_kwargs` — confirm against
          # the pinned version.
          self.llm = ChatOpenAI(
              api_key=config.QWEN_API_KEY,
              base_url=config.QWEN_BASE_URL,
              model=config.QWEN_MODEL,
              temperature=0.1,
              model_kwargs={
                  "extra_body": {
                      "enable_thinking": False,
                      "misc": {
                          "ensure_ascii": False
                      }
                  }
              }
          )
          logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")

          # 2. Bind the SQL tools to the model.
          self.tools = sql_tools
          self.llm_with_tools = self.llm.bind_tools(self.tools)
          logger.info(f" 已绑定 {len(self.tools)} 个工具。")

          # 3. Redis checkpointer (only when enabled and the package is importable).
          if config.REDIS_ENABLED and AsyncRedisSaver is not None:
              try:
                  self._exit_stack = AsyncExitStack()
                  checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
                  self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
                  await self.checkpointer.asetup()
                  logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
              except Exception as e:
                  logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
                  if self._exit_stack is not None:
                      await self._exit_stack.aclose()
                      # Drop the closed stack so close() does not touch it again.
                      self._exit_stack = None
                  self.checkpointer = None
          else:
              logger.warning(" Redis 持久化功能已禁用。")

          # 4. Build and compile the graph.
          self.agent_executor = self._create_graph()
          logger.info(" StateGraph 已构建并编译。")
          logger.info("✅ CustomReactAgent 初始化完成。")

      async def close(self):
          """Release resources held by the exit stack (the Redis checkpointer)."""
          if self._exit_stack is not None:
              await self._exit_stack.aclose()
              self._exit_stack = None
              self.checkpointer = None
              logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")

      def _create_graph(self):
          """Define the nodes and edges of the StateGraph and compile it.

          Returns:
              The compiled graph, using ``self.checkpointer`` (may be ``None``).
          """
          builder = StateGraph(AgentState)

          # Nodes.
          builder.add_node("agent", self._agent_node)
          builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
          builder.add_node("tools", ToolNode(self.tools))
          builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
          builder.add_node("format_final_response", self._format_final_response_node)

          # Edges: the agent either requests tools or hands over to formatting.
          builder.set_entry_point("agent")
          builder.add_conditional_edges(
              "agent",
              self._should_continue,
              {
                  "continue": "prepare_tool_input",
                  "end": "format_final_response"
              }
          )
          builder.add_edge("prepare_tool_input", "tools")
          builder.add_edge("tools", "update_state_after_tool")
          builder.add_edge("update_state_after_tool", "agent")
          builder.add_edge("format_final_response", END)

          return builder.compile(checkpointer=self.checkpointer)

      def _should_continue(self, state: AgentState) -> str:
          """Return ``"continue"`` when the last message requests tool calls, else ``"end"``."""
          last_message = state["messages"][-1]
          if getattr(last_message, "tool_calls", None):
              return "continue"
          return "end"

      def _agent_node(self, state: AgentState) -> Dict[str, Any]:
          """Agent node: call the tool-bound LLM and return its reply.

          When ``suggested_next_step`` is set in the state, a hint naming that
          tool is appended as a ``SystemMessage`` to the messages sent to the
          LLM; the hint is not part of the returned update.
          """
          logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")

          messages_for_llm = list(state["messages"])
          if state.get("suggested_next_step"):
              instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
              messages_for_llm.append(SystemMessage(content=instruction))

          response = self.llm_with_tools.invoke(messages_for_llm)
          # pretty_repr() returns the rendered text; pretty_print() prints to
          # stdout and returns None, which would log the literal "None".
          logger.info(f" LLM Response: {response.pretty_repr()}")

          # Only the new message is returned.
          return {"messages": [response]}

      def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
          """Inject the conversation history into every ``generate_sql`` tool call.

          The history is passed as a list of ``{"type", "content"}`` dicts under
          the ``history_messages`` argument. Other tool calls are left untouched.
          The last message is updated in place and returned.
          """
          logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")

          last_message = state["messages"][-1]
          if not getattr(last_message, "tool_calls", None):
              return {"messages": [last_message]}

          new_tool_calls = []
          for tool_call in last_message.tool_calls:
              if tool_call["name"] != "generate_sql":
                  new_tool_calls.append(tool_call)
                  continue

              logger.info(" 检测到 generate_sql 调用,注入历史消息。")
              # Copy the arguments so the original dict is not modified.
              modified_args = tool_call["args"].copy()
              # Reduce message objects to plain, serialisable dicts.
              serializable_history = [
                  {"type": msg.type, "content": msg.content}
                  for msg in state["messages"]
              ]
              modified_args["history_messages"] = serializable_history
              logger.info(f" 注入了 {len(serializable_history)} 条历史消息")
              # Keep every key of the original call (not just name/args/id) and
              # override only the arguments.
              new_tool_calls.append({**tool_call, "args": modified_args})

          last_message.tool_calls = new_tool_calls
          return {"messages": [last_message]}

      def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
          """Derive ``suggested_next_step`` from the last tool message.

          After a ``generate_sql`` call the injected ``history_messages``
          argument is also blanked out.
          NOTE(review): the original indentation is lost; the clean-up is placed
          so it runs for both successful and failed ``generate_sql`` results —
          confirm.
          """
          logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")

          last_tool_message = state['messages'][-1]
          tool_name = last_tool_message.name
          raw_output = last_tool_message.content
          # The substring checks below need text; content may not be a str.
          tool_output = raw_output if isinstance(raw_output, str) else str(raw_output)

          next_step = None
          if tool_name == 'generate_sql':
              if "失败" in tool_output or "无法生成" in tool_output:
                  next_step = 'answer_with_common_sense'
              else:
                  next_step = 'valid_sql'
              # Blank out the history that was injected into generate_sql calls.
              self._clear_history_messages_parameter(state['messages'])
          elif tool_name == 'valid_sql':
              if "失败" in tool_output:
                  next_step = 'analyze_validation_error'
              else:
                  next_step = 'run_sql'
          elif tool_name == 'run_sql':
              next_step = 'summarize_final_answer'

          logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
          return {"suggested_next_step": next_step}

      def _clear_history_messages_parameter(self, messages: List[BaseMessage]) -> None:
          """Set ``history_messages`` to ``""`` on every ``generate_sql`` tool call.

          The tool-call argument dicts are modified in place.
          """
          for message in messages:
              for tool_call in getattr(message, "tool_calls", None) or []:
                  if tool_call["name"] == "generate_sql" and "history_messages" in tool_call["args"]:
                      tool_call["args"]["history_messages"] = ""
                      logger.info(" 已将 generate_sql 的 history_messages 设置为空字符串")

      def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
          """Final node: prefix the last message's content with a format marker.

          The last message is modified in place and returned.
          """
          logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']}")
          last_message = state['messages'][-1]
          last_message.content = f"[Formatted Output]\n{last_message.content}"
          return {"messages": [last_message]}

      async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
          """Handle one user chat request.

          Args:
              message: The user's message text.
              user_id: Identifier of the user; also used to build a new thread id.
              thread_id: Existing conversation thread; when empty a new id of the
                  form ``"<user_id>:<timestamp>"`` is generated.

          Returns:
              ``{"success": True, "answer", "thread_id"}`` on success, or
              ``{"success": False, "error", "thread_id"}`` when the run raised.
          """
          if not thread_id:
              thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
              logger.info(f"🆕 新建会话,Thread ID: {thread_id}")

          # Named run_config so the imported `config` module is not shadowed.
          run_config = {
              "configurable": {
                  "thread_id": thread_id,
              }
          }
          inputs = {
              "messages": [HumanMessage(content=message)],
              "user_id": user_id,
              "thread_id": thread_id,
              "suggested_next_step": None,
          }

          try:
              final_state = await self.agent_executor.ainvoke(inputs, run_config)
              answer = final_state["messages"][-1].content
              logger.info(f"✅ 处理完成 - Final Answer: '{answer}'")
              return {"success": True, "answer": answer, "thread_id": thread_id}
          except Exception as e:
              # Top-level boundary: report the failure to the caller instead of raising.
              logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
              return {"success": False, "error": str(e), "thread_id": thread_id}

      async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
          """Return the stored conversation of a thread from the checkpointer.

          Args:
              thread_id: Conversation thread to look up.

          Returns:
              A list of ``{"type", "content", "tool_calls"}`` dicts, where type
              is ``"human"``, ``"tool"`` or ``"ai"``; empty when no checkpointer
              is configured or nothing is stored for the thread.
          """
          if not self.checkpointer:
              return []

          # Named run_config so the imported `config` module is not shadowed.
          run_config = {"configurable": {"thread_id": thread_id}}
          # aget() is the awaitable accessor; get() is the synchronous one.
          checkpoint = await self.checkpointer.aget(run_config)
          if not checkpoint:
              return []

          # Checkpoints keep the state under "channel_values"; "values" is the
          # key the previous code read and is kept as a fallback.
          stored = checkpoint.get("channel_values") or checkpoint.get("values") or {}

          history = []
          for msg in stored.get('messages', []):
              if isinstance(msg, HumanMessage):
                  role = "human"
              elif isinstance(msg, ToolMessage):
                  role = "tool"
              else:
                  # Everything else is reported as an AI message.
                  role = "ai"
              history.append({
                  "type": role,
                  "content": msg.content,
                  "tool_calls": getattr(msg, 'tool_calls', None)
              })
          return history
  251. return history