123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289 |
- """
- 基于 StateGraph 的、具备上下文感知能力的 React Agent 核心实现
- """
- import logging
- import json
- import pandas as pd
- from typing import List, Optional, Dict, Any, Tuple
- from contextlib import AsyncExitStack
- from langchain_openai import ChatOpenAI
- from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage
- from langgraph.graph import StateGraph, END
- from langgraph.prebuilt import ToolNode
- from redis.asyncio import Redis
- try:
- from langgraph.checkpoint.redis import AsyncRedisSaver
- except ImportError:
- AsyncRedisSaver = None
- # 从新模块导入配置、状态和工具
- from . import config
- from .state import AgentState
- from .sql_tools import sql_tools
- logger = logging.getLogger(__name__)
- class CustomReactAgent:
- """
- 一个使用 StateGraph 构建的、具备上下文感知和持久化能力的 Agent。
- """
- def __init__(self):
- """私有构造函数,请使用 create() 类方法来创建实例。"""
- self.llm = None
- self.tools = None
- self.agent_executor = None
- self.checkpointer = None
- self._exit_stack = None
- @classmethod
- async def create(cls):
- """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
- instance = cls()
- await instance._async_init()
- return instance
- async def _async_init(self):
- """异步初始化所有组件。"""
- logger.info("🚀 开始初始化 CustomReactAgent...")
- # 1. 初始化 LLM
- self.llm = ChatOpenAI(
- api_key=config.QWEN_API_KEY,
- base_url=config.QWEN_BASE_URL,
- model=config.QWEN_MODEL,
- temperature=0.1,
- model_kwargs={
- "extra_body": {
- "enable_thinking": False,
- "misc": {
- "ensure_ascii": False
- }
- }
- }
- )
- logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
- # 2. 绑定工具
- self.tools = sql_tools
- self.llm_with_tools = self.llm.bind_tools(self.tools)
- logger.info(f" 已绑定 {len(self.tools)} 个工具。")
- # 3. 初始化 Redis Checkpointer
- if config.REDIS_ENABLED and AsyncRedisSaver is not None:
- try:
- self._exit_stack = AsyncExitStack()
- checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
- self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
- await self.checkpointer.asetup()
- logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
- except Exception as e:
- logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
- if self._exit_stack:
- await self._exit_stack.aclose()
- self.checkpointer = None
- else:
- logger.warning(" Redis 持久化功能已禁用。")
- # 4. 构建 StateGraph
- self.agent_executor = self._create_graph()
- logger.info(" StateGraph 已构建并编译。")
- logger.info("✅ CustomReactAgent 初始化完成。")
- async def close(self):
- """清理资源,关闭 Redis 连接。"""
- if self._exit_stack:
- await self._exit_stack.aclose()
- self._exit_stack = None
- self.checkpointer = None
- logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
- def _create_graph(self):
- """定义并编译 StateGraph。"""
- builder = StateGraph(AgentState)
- # 定义节点
- builder.add_node("agent", self._agent_node)
- builder.add_node("prepare_tool_input", self._prepare_tool_input_node)
- builder.add_node("tools", ToolNode(self.tools))
- builder.add_node("update_state_after_tool", self._update_state_after_tool_node)
- builder.add_node("format_final_response", self._format_final_response_node)
- # 定义边
- builder.set_entry_point("agent")
- builder.add_conditional_edges(
- "agent",
- self._should_continue,
- {
- "continue": "prepare_tool_input",
- "end": "format_final_response"
- }
- )
- builder.add_edge("prepare_tool_input", "tools")
- builder.add_edge("tools", "update_state_after_tool")
- builder.add_edge("update_state_after_tool", "agent")
- builder.add_edge("format_final_response", END)
- # 编译图,并传入 checkpointer
- return builder.compile(checkpointer=self.checkpointer)
- def _should_continue(self, state: AgentState) -> str:
- """判断是继续调用工具还是结束。"""
- last_message = state["messages"][-1]
- if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
- return "end"
- return "continue"
- def _agent_node(self, state: AgentState) -> Dict[str, Any]:
- """Agent 节点:调用 LLM 进行思考和决策。"""
- logger.info(f"🧠 [Node] agent - Thread: {state['thread_id']}")
-
- messages_for_llm = list(state["messages"])
- if state.get("suggested_next_step"):
- instruction = f"基于之前的步骤,强烈建议你下一步执行 '{state['suggested_next_step']}' 操作。"
- # 为了避免污染历史,可以考虑不同的注入方式,但这里为了简单直接添加
- messages_for_llm.append(HumanMessage(content=instruction, name="system_instruction"))
- response = self.llm_with_tools.invoke(messages_for_llm)
- logger.info(f" LLM 返回: {response.pretty_print()}")
- return {"messages": [response]}
-
- def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
- """信息组装节点:为需要上下文的工具注入历史消息。"""
- logger.info(f"🛠️ [Node] prepare_tool_input - Thread: {state['thread_id']}")
-
- last_message = state["messages"][-1]
- if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
- return {}
- # 创建一个新的 AIMessage 来替换,避免直接修改 state 中的对象
- new_tool_calls = []
- for tool_call in last_message.tool_calls:
- if tool_call["name"] == "generate_sql":
- logger.info(" 检测到 generate_sql 调用,注入可序列化的历史消息。")
- # 复制一份以避免修改原始 tool_call
- modified_args = tool_call["args"].copy()
-
- # 将消息对象列表转换为可序列化的字典列表
- serializable_history = []
- for msg in state["messages"]:
- serializable_history.append({
- "type": msg.type,
- "content": msg.content
- })
-
- modified_args["history_messages"] = serializable_history
- new_tool_calls.append({
- "name": tool_call["name"],
- "args": modified_args,
- "id": tool_call["id"],
- })
- else:
- new_tool_calls.append(tool_call)
-
- # 用包含修改后参数的新消息替换掉原来的
- last_message.tool_calls = new_tool_calls
- return {"messages": [last_message]}
- def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
- """流程建议与错误处理节点:在工具执行后更新状态。"""
- logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
-
- last_tool_message = state['messages'][-1]
- tool_name = last_tool_message.name
- tool_output = last_tool_message.content
- next_step = None
- if tool_name == 'generate_sql':
- if "失败" in tool_output or "无法生成" in tool_output:
- next_step = 'answer_with_common_sense'
- logger.warning(f" generate_sql 失败,建议下一步: {next_step}")
- else:
- next_step = 'valid_sql'
- logger.info(f" generate_sql 成功,建议下一步: {next_step}")
-
- elif tool_name == 'valid_sql':
- if "失败" in tool_output:
- next_step = 'analyze_validation_error'
- logger.warning(f" valid_sql 失败,建议下一步: {next_step}")
- else:
- next_step = 'run_sql'
- logger.info(f" valid_sql 成功,建议下一步: {next_step}")
- elif tool_name == 'run_sql':
- next_step = 'summarize_final_answer'
- logger.info(f" run_sql 执行完毕,建议下一步: {next_step}")
- return {"suggested_next_step": next_step}
- def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
- """最终输出格式化节点(当前为占位符)。"""
- logger.info(f"🎨 [Node] format_final_response - Thread: {state['thread_id']} - 准备格式化最终输出...")
- # 这里可以添加一个标记,表示这是格式化后的输出
- last_message = state['messages'][-1]
- formatted_content = f"[Formatted Output]\n{last_message.content}"
- last_message.content = formatted_content
- return {"messages": [last_message]}
- async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
- """
- 处理用户聊天请求。
- """
- if not thread_id:
- thread_id = f"{user_id}:{pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')}"
- logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
- config = {"configurable": {"thread_id": thread_id}}
-
- # 定义输入
- inputs = {
- "messages": [HumanMessage(content=message)],
- "user_id": user_id,
- "thread_id": thread_id,
- "suggested_next_step": None, # 初始化建议
- }
- final_state = None
- try:
- logger.info(f"🔄 开始处理 - Thread: {thread_id}, User: {user_id}, Message: '{message}'")
- # 使用 ainvoke 来执行完整的图流程
- final_state = await self.agent_executor.ainvoke(inputs, config)
-
- if final_state and final_state.get("messages"):
- answer = final_state["messages"][-1].content
- logger.info(f"✅ 处理完成 - Thread: {thread_id}, Final Answer: '{answer}'")
- return {"success": True, "answer": answer, "thread_id": thread_id}
- else:
- logger.error(f"❌ 处理异常结束,最终状态为空 - Thread: {thread_id}")
- return {"success": False, "error": "Agent failed to produce a final answer.", "thread_id": thread_id}
- except Exception as e:
- logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
- return {"success": False, "error": str(e), "thread_id": thread_id}
- async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
- """从 checkpointer 获取指定线程的对话历史。"""
- if not self.checkpointer:
- return []
-
- config = {"configurable": {"thread_id": thread_id}}
- conversation_state = await self.checkpointer.get(config)
-
- if not conversation_state:
- return []
-
- history = []
- for msg in conversation_state['values'].get('messages', []):
- if isinstance(msg, HumanMessage):
- role = "human"
- elif isinstance(msg, ToolMessage):
- role = "tool"
- else: # AIMessage
- role = "ai"
-
- history.append({
- "type": role,
- "content": msg.content,
- "tool_calls": getattr(msg, 'tool_calls', None)
- })
- return history
|