agent.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
"""
Core implementation of a context-aware ReAct-style agent built on LangGraph's StateGraph.
"""
import logging
import json
import pandas as pd
from typing import List, Optional, Dict, Any, Tuple
from contextlib import AsyncExitStack
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from redis.asyncio import Redis

# The Redis checkpointer is optional: when the package cannot be imported the
# name is bound to None so callers can test for availability at runtime.
try:
    from langgraph.checkpoint.redis import AsyncRedisSaver
except ImportError:
    AsyncRedisSaver = None

# Import configuration, the graph state schema and the tools from the project's own modules.
from . import config
from .state import AgentState
from .sql_tools import sql_tools

logger = logging.getLogger(__name__)
  23. class CustomReactAgent:
  24. """
  25. 一个使用 StateGraph 构建的、具备上下文感知和持久化能力的 Agent。
  26. """
  27. def __init__(self):
  28. """私有构造函数,请使用 create() 类方法来创建实例。"""
  29. self.llm = None
  30. self.tools = None
  31. self.agent_executor = None
  32. self.checkpointer = None
  33. self._exit_stack = None
  34. @classmethod
  35. async def create(cls):
  36. """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
  37. instance = cls()
  38. await instance._async_init()
  39. return instance
  40. async def _async_init(self):
  41. """异步初始化所有组件。"""
  42. logger.info("🚀 开始初始化 CustomReactAgent...")
  43. # 1. 初始化 LLM
  44. self.llm = ChatOpenAI(
  45. api_key=config.QWEN_API_KEY,
  46. base_url=config.QWEN_BASE_URL,
  47. model=config.QWEN_MODEL,
  48. temperature=0.1,
  49. extra_body={"enable_thinking": False}
  50. )
  51. logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
  52. # 2. 绑定工具
  53. self.tools = sql_tools
  54. self.llm_with_tools = self.llm.bind_tools(self.tools)
  55. logger.info(f" 已绑定 {len(self.tools)} 个工具。")
  56. # 3. 初始化 Redis Checkpointer
  57. if config.REDIS_ENABLED and AsyncRedisSaver is not None:
  58. try:
  59. self._exit_stack = AsyncExitStack()
  60. checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
  61. self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
  62. await self.checkpointer.asetup()
  63. logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
  64. except Exception as e:
  65. logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
  66. if self._exit_stack:
  67. await self._exit_stack.aclose()
  68. self.checkpointer = None
  69. else:
  70. logger.warning(" Redis 持久化功能已禁用。")
  71. # 4. 构建 StateGraph
  72. self.agent_executor = self._create_graph()
  73. logger.info(" StateGraph 已构建并编译。")
  74. logger.info("✅ CustomReactAgent 初始化完成。")
  75. async def close(self):
  76. """清理资源,关闭 Redis 连接。"""
  77. if self._exit_stack:
  78. await self._exit_stack.aclose()
  79. self._exit_stack = None
  80. self.checkpointer = None
  81. logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
  82. def _create_graph(self):
  83. """定义并编译 StateGraph。"""
  84. builder = StateGraph(AgentState)
  85. # 定义节点
  86. builder.add_node("agent", self._agent_node)
  87. builder.add_node("tools", ToolNode(self.tools))
  88. # 定义边
  89. builder.set_entry_point("agent")
  90. builder.add_conditional_edges(
  91. "agent",
  92. self._should_continue,
  93. {"continue": "tools", "end": END}
  94. )
  95. builder.add_edge("tools", "agent")
  96. # 编译图,并传入 checkpointer
  97. return builder.compile(checkpointer=self.checkpointer)
  98. def _should_continue(self, state: AgentState) -> str:
  99. """判断是否需要继续调用工具。"""
  100. last_message = state["messages"][-1]
  101. if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
  102. return "end"
  103. return "continue"
  104. def _agent_node(self, state: AgentState) -> Dict[str, Any]:
  105. """Agent 节点:调用 LLM 进行思考和决策。"""
  106. logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
  107. # LangChain 会自动从 state 中提取 messages 传递给 LLM
  108. response = self.llm_with_tools.invoke(state["messages"])
  109. logger.info(f" LLM 返回: {response.pretty_print()}")
  110. return {"messages": [response]}
  111. async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
  112. """
  113. 处理用户聊天请求。
  114. """
  115. if not thread_id:
  116. thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
  117. logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
  118. config = {"configurable": {"thread_id": thread_id}}
  119. # 定义输入
  120. inputs = {
  121. "messages": [HumanMessage(content=message)],
  122. "user_id": user_id,
  123. "thread_id": thread_id,
  124. }
  125. final_state = None
  126. try:
  127. logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
  128. # 使用 ainvoke 来执行完整的图流程
  129. final_state = await self.agent_executor.ainvoke(inputs, config)
  130. if final_state and final_state.get("messages"):
  131. answer = final_state["messages"][-1].content
  132. logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
  133. return {"success": True, "answer": answer, "thread_id": thread_id}
  134. else:
  135. logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
  136. return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
  137. except Exception as e:
  138. logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
  139. return {"success": False, "error": str(e), "thread_id": thread_id}
  140. async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
  141. """从 checkpointer 获取指定线程的对话历史。"""
  142. if not self.checkpointer:
  143. return []
  144. config = {"configurable": {"thread_id": thread_id}}
  145. conversation_state = await self.checkpointer.get(config)
  146. if not conversation_state:
  147. return []
  148. history = []
  149. for msg in conversation_state['values'].get('messages', []):
  150. if isinstance(msg, HumanMessage):
  151. role = "human"
  152. elif isinstance(msg, ToolMessage):
  153. role = "tool"
  154. else: # AIMessage
  155. role = "ai"
  156. history.append({
  157. "type": role,
  158. "content": msg.content,
  159. "tool_calls": getattr(msg, 'tool_calls', None)
  160. })
  161. return history