|
@@ -1,382 +0,0 @@
|
|
-"""
|
|
|
|
-同步版本的React Agent - 解决Vector搜索异步冲突问题
|
|
|
|
-基于原有CustomReactAgent,但使用完全同步的实现
|
|
|
|
-"""
|
|
|
|
-import json
|
|
|
|
-import sys
|
|
|
|
-import os
|
|
|
|
-from pathlib import Path
|
|
|
|
-from typing import List, Optional, Dict, Any
|
|
|
|
-import redis
|
|
|
|
-
|
|
|
|
-# 添加项目根目录到sys.path
|
|
|
|
-try:
|
|
|
|
- project_root = Path(__file__).parent.parent
|
|
|
|
- if str(project_root) not in sys.path:
|
|
|
|
- sys.path.insert(0, str(project_root))
|
|
|
|
-except Exception as e:
|
|
|
|
- pass
|
|
|
|
-
|
|
|
|
-from core.logging import get_react_agent_logger
|
|
|
|
-from langchain_openai import ChatOpenAI
|
|
|
|
-from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
|
|
|
|
-from langgraph.graph import StateGraph, END
|
|
|
|
-from langgraph.prebuilt import ToolNode
|
|
|
|
-
|
|
|
|
-# 导入同步版本的依赖
|
|
|
|
-try:
|
|
|
|
- from . import config
|
|
|
|
- from .state import AgentState
|
|
|
|
- from .sql_tools import sql_tools
|
|
|
|
-except ImportError:
|
|
|
|
- import config
|
|
|
|
- from state import AgentState
|
|
|
|
- from sql_tools import sql_tools
|
|
|
|
-
|
|
|
|
-logger = get_react_agent_logger("SyncCustomReactAgent")
|
|
|
|
-
|
|
|
|
-class SyncCustomReactAgent:
|
|
|
|
- """
|
|
|
|
- 同步版本的React Agent
|
|
|
|
- 专门解决Vector搜索的异步事件循环冲突问题
|
|
|
|
- """
|
|
|
|
-
|
|
|
|
- def __init__(self):
|
|
|
|
- """私有构造函数,请使用 create() 类方法来创建实例。"""
|
|
|
|
- self.llm = None
|
|
|
|
- self.tools = None
|
|
|
|
- self.agent_executor = None
|
|
|
|
- self.checkpointer = None
|
|
|
|
- self.redis_client = None
|
|
|
|
-
|
|
|
|
- @classmethod
|
|
|
|
- def create(cls):
|
|
|
|
- """同步工厂方法,创建并初始化 SyncCustomReactAgent 实例。"""
|
|
|
|
- instance = cls()
|
|
|
|
- instance._sync_init()
|
|
|
|
- return instance
|
|
|
|
-
|
|
|
|
- def _sync_init(self):
|
|
|
|
- """同步初始化所有组件。"""
|
|
|
|
- logger.info("🚀 开始初始化 SyncCustomReactAgent...")
|
|
|
|
-
|
|
|
|
- # 1. 初始化同步Redis客户端(如果需要)
|
|
|
|
- try:
|
|
|
|
- self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
|
- self.redis_client.ping()
|
|
|
|
- logger.info(f" ✅ Redis连接成功: {config.REDIS_URL}")
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.warning(f" ⚠️ Redis连接失败,将不使用checkpointer: {e}")
|
|
|
|
- self.redis_client = None
|
|
|
|
-
|
|
|
|
- # 2. 初始化 LLM(同步版本)
|
|
|
|
- self.llm = ChatOpenAI(
|
|
|
|
- api_key=config.QWEN_API_KEY,
|
|
|
|
- base_url=config.QWEN_BASE_URL,
|
|
|
|
- model=config.QWEN_MODEL,
|
|
|
|
- temperature=0.1,
|
|
|
|
- timeout=config.NETWORK_TIMEOUT,
|
|
|
|
- max_retries=0,
|
|
|
|
- streaming=False, # 关键:禁用流式处理
|
|
|
|
- extra_body={
|
|
|
|
- "enable_thinking": False, # 明确设置为False:非流式调用必须设为false
|
|
|
|
- "misc": {
|
|
|
|
- "ensure_ascii": False
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
- logger.info(f" ✅ 同步LLM已初始化,模型: {config.QWEN_MODEL}")
|
|
|
|
-
|
|
|
|
- # 3. 绑定工具
|
|
|
|
- self.tools = sql_tools
|
|
|
|
- self.llm_with_tools = self.llm.bind_tools(self.tools)
|
|
|
|
- logger.info(f" ✅ 已绑定 {len(self.tools)} 个工具")
|
|
|
|
-
|
|
|
|
- # 4. 创建StateGraph(不使用checkpointer避免异步依赖)
|
|
|
|
- self.agent_executor = self._create_sync_graph()
|
|
|
|
- logger.info(" ✅ 同步StateGraph已创建")
|
|
|
|
-
|
|
|
|
- logger.info("✅ SyncCustomReactAgent 初始化完成")
|
|
|
|
-
|
|
|
|
- def _create_sync_graph(self):
|
|
|
|
- """创建同步的StateGraph"""
|
|
|
|
- graph = StateGraph(AgentState)
|
|
|
|
-
|
|
|
|
- # 添加同步节点
|
|
|
|
- graph.add_node("agent", self._sync_agent_node)
|
|
|
|
- graph.add_node("tools", ToolNode(self.tools))
|
|
|
|
- graph.add_node("prepare_tool_input", self._sync_prepare_tool_input_node)
|
|
|
|
- graph.add_node("update_state_after_tool", self._sync_update_state_after_tool_node)
|
|
|
|
- graph.add_node("format_final_response", self._sync_format_final_response_node)
|
|
|
|
-
|
|
|
|
- # 设置入口点
|
|
|
|
- graph.set_entry_point("agent")
|
|
|
|
-
|
|
|
|
- # 添加条件边
|
|
|
|
- graph.add_conditional_edges(
|
|
|
|
- "agent",
|
|
|
|
- self._sync_should_continue,
|
|
|
|
- {
|
|
|
|
- "tools": "prepare_tool_input",
|
|
|
|
- "end": "format_final_response"
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- # 添加普通边
|
|
|
|
- graph.add_edge("prepare_tool_input", "tools")
|
|
|
|
- graph.add_edge("tools", "update_state_after_tool")
|
|
|
|
- graph.add_edge("update_state_after_tool", "agent")
|
|
|
|
- graph.add_edge("format_final_response", END)
|
|
|
|
-
|
|
|
|
- # 关键:使用同步编译,不传入checkpointer
|
|
|
|
- return graph.compile()
|
|
|
|
-
|
|
|
|
- def _sync_agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
- """同步Agent节点"""
|
|
|
|
- logger.info(f"🧠 [Sync Node] agent - Thread: {state.get('thread_id', 'unknown')}")
|
|
|
|
-
|
|
|
|
- messages_for_llm = state["messages"].copy()
|
|
|
|
-
|
|
|
|
- # 添加数据库范围提示词
|
|
|
|
- if isinstance(state["messages"][-1], HumanMessage):
|
|
|
|
- db_scope_prompt = self._get_database_scope_prompt()
|
|
|
|
- if db_scope_prompt:
|
|
|
|
- messages_for_llm.insert(0, SystemMessage(content=db_scope_prompt))
|
|
|
|
- logger.info(" ✅ 已添加数据库范围判断提示词")
|
|
|
|
-
|
|
|
|
- # 同步LLM调用
|
|
|
|
- response = self.llm_with_tools.invoke(messages_for_llm)
|
|
|
|
-
|
|
|
|
- return {"messages": [response]}
|
|
|
|
-
|
|
|
|
- def _sync_should_continue(self, state: AgentState):
|
|
|
|
- """同步条件判断"""
|
|
|
|
- messages = state["messages"]
|
|
|
|
- last_message = messages[-1]
|
|
|
|
-
|
|
|
|
- if not last_message.tool_calls:
|
|
|
|
- return "end"
|
|
|
|
- else:
|
|
|
|
- return "tools"
|
|
|
|
-
|
|
|
|
- def _sync_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
- """同步准备工具输入节点"""
|
|
|
|
- logger.info(f"🔧 [Sync Node] prepare_tool_input - Thread: {state.get('thread_id', 'unknown')}")
|
|
|
|
-
|
|
|
|
- last_message = state["messages"][-1]
|
|
|
|
-
|
|
|
|
- if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
|
|
|
- for tool_call in last_message.tool_calls:
|
|
|
|
- if tool_call.get('name') == 'generate_sql':
|
|
|
|
- # 注入历史消息
|
|
|
|
- history_messages = self._filter_and_format_history(state["messages"])
|
|
|
|
- if 'args' not in tool_call:
|
|
|
|
- tool_call['args'] = {}
|
|
|
|
- tool_call['args']['history_messages'] = history_messages
|
|
|
|
- logger.info(f" ✅ 为generate_sql注入了 {len(history_messages)} 条历史消息")
|
|
|
|
-
|
|
|
|
- return {"messages": [last_message]}
|
|
|
|
-
|
|
|
|
- def _sync_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
- """同步更新工具执行后的状态"""
|
|
|
|
- logger.info(f"📝 [Sync Node] update_state_after_tool - Thread: {state.get('thread_id', 'unknown')}")
|
|
|
|
-
|
|
|
|
- last_message = state["messages"][-1]
|
|
|
|
- tool_name = last_message.name
|
|
|
|
- tool_output = last_message.content
|
|
|
|
- next_step = None
|
|
|
|
-
|
|
|
|
- if tool_name == 'generate_sql':
|
|
|
|
- tool_output_lower = tool_output.lower()
|
|
|
|
- if "failed" in tool_output_lower or "无法生成" in tool_output_lower or "失败" in tool_output_lower:
|
|
|
|
- next_step = 'answer_with_common_sense'
|
|
|
|
- else:
|
|
|
|
- next_step = 'valid_sql'
|
|
|
|
- elif tool_name == 'valid_sql':
|
|
|
|
- if "失败" in tool_output:
|
|
|
|
- next_step = 'analyze_validation_error'
|
|
|
|
- else:
|
|
|
|
- next_step = 'run_sql'
|
|
|
|
- elif tool_name == 'run_sql':
|
|
|
|
- next_step = 'summarize_final_answer'
|
|
|
|
-
|
|
|
|
- logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
|
|
|
|
- return {"suggested_next_step": next_step}
|
|
|
|
-
|
|
|
|
- def _sync_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
- """同步格式化最终响应节点"""
|
|
|
|
- logger.info(f"📄 [Sync Node] format_final_response - Thread: {state.get('thread_id', 'unknown')}")
|
|
|
|
-
|
|
|
|
- messages = state["messages"]
|
|
|
|
- last_message = messages[-1]
|
|
|
|
-
|
|
|
|
- # 构建最终响应
|
|
|
|
- final_response = last_message.content
|
|
|
|
-
|
|
|
|
- logger.info(f" ✅ 最终响应已准备完成")
|
|
|
|
- return {"final_answer": final_response}
|
|
|
|
-
|
|
|
|
- def _filter_and_format_history(self, messages: list) -> list:
|
|
|
|
- """过滤和格式化历史消息"""
|
|
|
|
- clean_history = []
|
|
|
|
- for msg in messages[:-1]: # 排除最后一条消息
|
|
|
|
- if isinstance(msg, HumanMessage):
|
|
|
|
- clean_history.append({"type": "human", "content": msg.content})
|
|
|
|
- elif isinstance(msg, AIMessage):
|
|
|
|
- clean_content = msg.content if not hasattr(msg, 'tool_calls') or not msg.tool_calls else ""
|
|
|
|
- if clean_content.strip():
|
|
|
|
- clean_history.append({"type": "ai", "content": clean_content})
|
|
|
|
-
|
|
|
|
- return clean_history
|
|
|
|
-
|
|
|
|
- def _get_database_scope_prompt(self) -> str:
|
|
|
|
- """获取数据库范围判断提示词"""
|
|
|
|
- return """你是一个专门处理高速公路收费数据查询的AI助手。在回答用户问题时,请首先判断这个问题是否可以通过查询数据库来回答。
|
|
|
|
-
|
|
|
|
-数据库包含以下类型的数据:
|
|
|
|
-- 服务区信息(名称、位置、档口数量等)
|
|
|
|
-- 收费站数据
|
|
|
|
-- 车流量统计
|
|
|
|
-- 业务数据分析
|
|
|
|
-
|
|
|
|
-如果用户的问题与这些数据相关,请使用工具生成SQL查询。
|
|
|
|
-如果问题与数据库内容无关(如常识性问题、天气、新闻等),请直接用你的知识回答,不要尝试生成SQL。"""
|
|
|
|
-
|
|
|
|
- def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
|
|
|
|
- """
|
|
|
|
- 同步聊天方法 - 关键:使用 graph.invoke() 而不是 ainvoke()
|
|
|
|
- """
|
|
|
|
- if thread_id is None:
|
|
|
|
- import uuid
|
|
|
|
- thread_id = str(uuid.uuid4())
|
|
|
|
-
|
|
|
|
- # 构建输入
|
|
|
|
- inputs = {
|
|
|
|
- "messages": [HumanMessage(content=message)],
|
|
|
|
- "user_id": user_id,
|
|
|
|
- "thread_id": thread_id,
|
|
|
|
- "suggested_next_step": None
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- # 构建运行配置(不使用checkpointer)
|
|
|
|
- run_config = {
|
|
|
|
- "recursion_limit": config.RECURSION_LIMIT,
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- logger.info(f"🚀 开始同步处理用户消息: {message[:50]}...")
|
|
|
|
-
|
|
|
|
- # 关键:使用同步的 invoke() 方法
|
|
|
|
- final_state = self.agent_executor.invoke(inputs, run_config)
|
|
|
|
-
|
|
|
|
- logger.info(f"🔍 Final state keys: {list(final_state.keys())}")
|
|
|
|
-
|
|
|
|
- # 提取答案
|
|
|
|
- if final_state["messages"]:
|
|
|
|
- answer = final_state["messages"][-1].content
|
|
|
|
- else:
|
|
|
|
- answer = "抱歉,无法处理您的请求。"
|
|
|
|
-
|
|
|
|
- # 提取SQL数据(如果有)
|
|
|
|
- sql_data = self._extract_latest_sql_data(final_state["messages"])
|
|
|
|
-
|
|
|
|
- logger.info(f"✅ 同步处理完成 - Final Answer: '{answer[:100]}...'")
|
|
|
|
-
|
|
|
|
- # 构建返回结果
|
|
|
|
- result = {
|
|
|
|
- "success": True,
|
|
|
|
- "answer": answer,
|
|
|
|
- "thread_id": thread_id
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- # 只有当存在SQL数据时才添加到返回结果中
|
|
|
|
- if sql_data:
|
|
|
|
- try:
|
|
|
|
- # 尝试解析SQL数据
|
|
|
|
- sql_parsed = json.loads(sql_data)
|
|
|
|
-
|
|
|
|
- # 检查数据格式:run_sql工具返回的是数组格式 [{"col1":"val1"}]
|
|
|
|
- if isinstance(sql_parsed, list):
|
|
|
|
- # 数组格式:直接作为records使用
|
|
|
|
- result["api_data"] = {
|
|
|
|
- "response": answer,
|
|
|
|
- "records": sql_parsed,
|
|
|
|
- "react_agent_meta": {
|
|
|
|
- "thread_id": thread_id,
|
|
|
|
- "agent_version": "sync_react_v1"
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- elif isinstance(sql_parsed, dict):
|
|
|
|
- # 字典格式:按原逻辑处理
|
|
|
|
- result["api_data"] = {
|
|
|
|
- "response": answer,
|
|
|
|
- "sql": sql_parsed.get("sql", ""),
|
|
|
|
- "records": sql_parsed.get("records", []),
|
|
|
|
- "react_agent_meta": {
|
|
|
|
- "thread_id": thread_id,
|
|
|
|
- "agent_version": "sync_react_v1"
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- else:
|
|
|
|
- logger.warning(f"SQL数据格式未知: {type(sql_parsed)}")
|
|
|
|
- raise ValueError("Unknown SQL data format")
|
|
|
|
-
|
|
|
|
- except (json.JSONDecodeError, AttributeError, ValueError) as e:
|
|
|
|
- logger.warning(f"SQL数据格式处理失败: {str(e)}, 跳过API数据构建")
|
|
|
|
- else:
|
|
|
|
- result["api_data"] = {
|
|
|
|
- "response": answer,
|
|
|
|
- "react_agent_meta": {
|
|
|
|
- "thread_id": thread_id,
|
|
|
|
- "agent_version": "sync_react_v1"
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return result
|
|
|
|
-
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"❌ 同步处理失败: {str(e)}", exc_info=True)
|
|
|
|
- return {
|
|
|
|
- "success": False,
|
|
|
|
- "error": f"同步处理失败: {str(e)}",
|
|
|
|
- "thread_id": thread_id,
|
|
|
|
- "retry_suggested": True
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
|
|
|
|
- """从消息历史中提取最近的run_sql执行结果(同步版本)"""
|
|
|
|
- logger.info("🔍 提取最新的SQL执行结果...")
|
|
|
|
-
|
|
|
|
- # 查找最后一个HumanMessage之后的SQL执行结果
|
|
|
|
- last_human_index = -1
|
|
|
|
- for i in range(len(messages) - 1, -1, -1):
|
|
|
|
- if isinstance(messages[i], HumanMessage):
|
|
|
|
- last_human_index = i
|
|
|
|
- break
|
|
|
|
-
|
|
|
|
- if last_human_index == -1:
|
|
|
|
- logger.info(" 未找到用户消息,跳过SQL数据提取")
|
|
|
|
- return None
|
|
|
|
-
|
|
|
|
- # 只在当前对话轮次中查找SQL结果
|
|
|
|
- current_conversation = messages[last_human_index:]
|
|
|
|
- logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")
|
|
|
|
-
|
|
|
|
- for msg in reversed(current_conversation):
|
|
|
|
- if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
|
|
|
|
- logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- # 尝试解析JSON以验证格式
|
|
|
|
- parsed_data = json.loads(msg.content)
|
|
|
|
- # 重新序列化,确保中文字符正常显示
|
|
|
|
- formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
|
|
|
|
- logger.info(f" 已转换Unicode转义序列为中文字符")
|
|
|
|
- return formatted_content
|
|
|
|
- except json.JSONDecodeError:
|
|
|
|
- # 如果不是有效JSON,直接返回原内容
|
|
|
|
- logger.warning(f" SQL结果不是有效JSON格式,返回原始内容")
|
|
|
|
- return msg.content
|
|
|
|
-
|
|
|
|
- logger.info(" 当前对话轮次中未找到run_sql执行结果")
|
|
|
|
- return None
|
|
|