|
@@ -0,0 +1,190 @@
|
|
|
+"""
|
|
|
+基于 StateGraph 的、具备上下文感知能力的 React Agent 核心实现
|
|
|
+"""
|
|
|
+import logging
|
|
|
+import json
|
|
|
+import pandas as pd
|
|
|
+from typing import List, Optional, Dict, Any, Tuple
|
|
|
+from contextlib import AsyncExitStack
|
|
|
+
|
|
|
+from langchain_openai import ChatOpenAI
|
|
|
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
|
|
|
+from langgraph.graph import StateGraph, END
|
|
|
+from langgraph.prebuilt import ToolNode
|
|
|
+from redis.asyncio import Redis
|
|
|
+try:
|
|
|
+ from langgraph.checkpoint.redis import AsyncRedisSaver
|
|
|
+except ImportError:
|
|
|
+ AsyncRedisSaver = None
|
|
|
+
|
|
|
+# 从新模块导入配置、状态和工具
|
|
|
+from . import config
|
|
|
+from .state import AgentState
|
|
|
+from .sql_tools import sql_tools
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+class CustomReactAgent:
|
|
|
+ """
|
|
|
+ 一个使用 StateGraph 构建的、具备上下文感知和持久化能力的 Agent。
|
|
|
+ """
|
|
|
+ def __init__(self):
|
|
|
+ """私有构造函数,请使用 create() 类方法来创建实例。"""
|
|
|
+ self.llm = None
|
|
|
+ self.tools = None
|
|
|
+ self.agent_executor = None
|
|
|
+ self.checkpointer = None
|
|
|
+ self._exit_stack = None
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ async def create(cls):
|
|
|
+ """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
|
|
|
+ instance = cls()
|
|
|
+ await instance._async_init()
|
|
|
+ return instance
|
|
|
+
|
|
|
+ async def _async_init(self):
|
|
|
+ """异步初始化所有组件。"""
|
|
|
+ logger.info("🚀 开始初始化 CustomReactAgent...")
|
|
|
+
|
|
|
+ # 1. 初始化 LLM
|
|
|
+ self.llm = ChatOpenAI(
|
|
|
+ api_key=config.QWEN_API_KEY,
|
|
|
+ base_url=config.QWEN_BASE_URL,
|
|
|
+ model=config.QWEN_MODEL,
|
|
|
+ temperature=0.1,
|
|
|
+ extra_body={"enable_thinking": False}
|
|
|
+ )
|
|
|
+ logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
|
|
|
+
|
|
|
+ # 2. 绑定工具
|
|
|
+ self.tools = sql_tools
|
|
|
+ self.llm_with_tools = self.llm.bind_tools(self.tools)
|
|
|
+ logger.info(f" 已绑定 {len(self.tools)} 个工具。")
|
|
|
+
|
|
|
+ # 3. 初始化 Redis Checkpointer
|
|
|
+ if config.REDIS_ENABLED and AsyncRedisSaver is not None:
|
|
|
+ try:
|
|
|
+ self._exit_stack = AsyncExitStack()
|
|
|
+ checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
|
|
|
+ self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
|
|
|
+ await self.checkpointer.asetup()
|
|
|
+ logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
|
|
|
+ if self._exit_stack:
|
|
|
+ await self._exit_stack.aclose()
|
|
|
+ self.checkpointer = None
|
|
|
+ else:
|
|
|
+ logger.warning(" Redis 持久化功能已禁用。")
|
|
|
+
|
|
|
+ # 4. 构建 StateGraph
|
|
|
+ self.agent_executor = self._create_graph()
|
|
|
+ logger.info(" StateGraph 已构建并编译。")
|
|
|
+ logger.info("✅ CustomReactAgent 初始化完成。")
|
|
|
+
|
|
|
+ async def close(self):
|
|
|
+ """清理资源,关闭 Redis 连接。"""
|
|
|
+ if self._exit_stack:
|
|
|
+ await self._exit_stack.aclose()
|
|
|
+ self._exit_stack = None
|
|
|
+ self.checkpointer = None
|
|
|
+ logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
|
|
|
+
|
|
|
+ def _create_graph(self):
|
|
|
+ """定义并编译 StateGraph。"""
|
|
|
+ builder = StateGraph(AgentState)
|
|
|
+
|
|
|
+ # 定义节点
|
|
|
+ builder.add_node("agent", self._agent_node)
|
|
|
+ builder.add_node("tools", ToolNode(self.tools))
|
|
|
+
|
|
|
+ # 定义边
|
|
|
+ builder.set_entry_point("agent")
|
|
|
+ builder.add_conditional_edges(
|
|
|
+ "agent",
|
|
|
+ self._should_continue,
|
|
|
+ {"continue": "tools", "end": END}
|
|
|
+ )
|
|
|
+ builder.add_edge("tools", "agent")
|
|
|
+
|
|
|
+ # 编译图,并传入 checkpointer
|
|
|
+ return builder.compile(checkpointer=self.checkpointer)
|
|
|
+
|
|
|
+ def _should_continue(self, state: AgentState) -> str:
|
|
|
+ """判断是否需要继续调用工具。"""
|
|
|
+ last_message = state["messages"][-1]
|
|
|
+ if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
|
+ return "end"
|
|
|
+ return "continue"
|
|
|
+
|
|
|
+ def _agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """Agent 节点:调用 LLM 进行思考和决策。"""
|
|
|
+ logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
|
|
|
+
|
|
|
+ # LangChain 会自动从 state 中提取 messages 传递给 LLM
|
|
|
+ response = self.llm_with_tools.invoke(state["messages"])
|
|
|
+ logger.info(f" LLM 返回: {response.pretty_print()}")
|
|
|
+ return {"messages": [response]}
|
|
|
+
|
|
|
+ async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 处理用户聊天请求。
|
|
|
+ """
|
|
|
+ if not thread_id:
|
|
|
+ thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
|
|
|
+ logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
|
|
|
+
|
|
|
+ config = {"configurable": {"thread_id": thread_id}}
|
|
|
+
|
|
|
+ # 定义输入
|
|
|
+ inputs = {
|
|
|
+ "messages": [HumanMessage(content=message)],
|
|
|
+ "user_id": user_id,
|
|
|
+ "thread_id": thread_id,
|
|
|
+ }
|
|
|
+
|
|
|
+ final_state = None
|
|
|
+ try:
|
|
|
+ logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
|
|
|
+ # 使用 ainvoke 来执行完整的图流程
|
|
|
+ final_state = await self.agent_executor.ainvoke(inputs, config)
|
|
|
+
|
|
|
+ if final_state and final_state.get("messages"):
|
|
|
+ answer = final_state["messages"][-1].content
|
|
|
+ logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
|
|
|
+ return {"success": True, "answer": answer, "thread_id": thread_id}
|
|
|
+ else:
|
|
|
+ logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
|
|
|
+ return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
|
+ return {"success": False, "error": str(e), "thread_id": thread_id}
|
|
|
+
|
|
|
+ async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
|
|
|
+ """从 checkpointer 获取指定线程的对话历史。"""
|
|
|
+ if not self.checkpointer:
|
|
|
+ return []
|
|
|
+
|
|
|
+ config = {"configurable": {"thread_id": thread_id}}
|
|
|
+ conversation_state = await self.checkpointer.get(config)
|
|
|
+
|
|
|
+ if not conversation_state:
|
|
|
+ return []
|
|
|
+
|
|
|
+ history = []
|
|
|
+ for msg in conversation_state['values'].get('messages', []):
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
+ role = "human"
|
|
|
+ elif isinstance(msg, ToolMessage):
|
|
|
+ role = "tool"
|
|
|
+ else: # AIMessage
|
|
|
+ role = "ai"
|
|
|
+
|
|
|
+ history.append({
|
|
|
+ "type": role,
|
|
|
+ "content": msg.content,
|
|
|
+ "tool_calls": getattr(msg, 'tool_calls', None)
|
|
|
+ })
|
|
|
+ return history
|