123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- """
- 同步版本的React Agent - 解决Vector搜索异步冲突问题
- 基于原有CustomReactAgent,但使用完全同步的实现
- """
- import json
- import sys
- import os
- from pathlib import Path
- from typing import List, Optional, Dict, Any
- import redis
- # 添加项目根目录到sys.path
- try:
- project_root = Path(__file__).parent.parent
- if str(project_root) not in sys.path:
- sys.path.insert(0, str(project_root))
- except Exception as e:
- pass
- from core.logging import get_react_agent_logger
- from langchain_openai import ChatOpenAI
- from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
- from langgraph.graph import StateGraph, END
- from langgraph.prebuilt import ToolNode
- # 导入同步版本的依赖
- try:
- from . import config
- from .state import AgentState
- from .sql_tools import sql_tools
- except ImportError:
- import config
- from state import AgentState
- from sql_tools import sql_tools
- logger = get_react_agent_logger("SyncCustomReactAgent")
- class SyncCustomReactAgent:
- """
- 同步版本的React Agent
- 专门解决Vector搜索的异步事件循环冲突问题
- """
-
- def __init__(self):
- """私有构造函数,请使用 create() 类方法来创建实例。"""
- self.llm = None
- self.tools = None
- self.agent_executor = None
- self.checkpointer = None
- self.redis_client = None
- @classmethod
- def create(cls):
- """同步工厂方法,创建并初始化 SyncCustomReactAgent 实例。"""
- instance = cls()
- instance._sync_init()
- return instance
- def _sync_init(self):
- """同步初始化所有组件。"""
- logger.info("🚀 开始初始化 SyncCustomReactAgent...")
- # 1. 初始化同步Redis客户端(如果需要)
- try:
- self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
- self.redis_client.ping()
- logger.info(f" ✅ Redis连接成功: {config.REDIS_URL}")
- except Exception as e:
- logger.warning(f" ⚠️ Redis连接失败,将不使用checkpointer: {e}")
- self.redis_client = None
- # 2. 初始化 LLM(同步版本)
- self.llm = ChatOpenAI(
- api_key=config.QWEN_API_KEY,
- base_url=config.QWEN_BASE_URL,
- model=config.QWEN_MODEL,
- temperature=0.1,
- timeout=config.NETWORK_TIMEOUT,
- max_retries=0,
- streaming=False, # 关键:禁用流式处理
- extra_body={
- "enable_thinking": False, # 明确设置为False:非流式调用必须设为false
- "misc": {
- "ensure_ascii": False
- }
- }
- )
- logger.info(f" ✅ 同步LLM已初始化,模型: {config.QWEN_MODEL}")
- # 3. 绑定工具
- self.tools = sql_tools
- self.llm_with_tools = self.llm.bind_tools(self.tools)
- logger.info(f" ✅ 已绑定 {len(self.tools)} 个工具")
- # 4. 创建StateGraph(不使用checkpointer避免异步依赖)
- self.agent_executor = self._create_sync_graph()
- logger.info(" ✅ 同步StateGraph已创建")
- logger.info("✅ SyncCustomReactAgent 初始化完成")
- def _create_sync_graph(self):
- """创建同步的StateGraph"""
- graph = StateGraph(AgentState)
-
- # 添加同步节点
- graph.add_node("agent", self._sync_agent_node)
- graph.add_node("tools", ToolNode(self.tools))
- graph.add_node("prepare_tool_input", self._sync_prepare_tool_input_node)
- graph.add_node("update_state_after_tool", self._sync_update_state_after_tool_node)
- graph.add_node("format_final_response", self._sync_format_final_response_node)
- # 设置入口点
- graph.set_entry_point("agent")
- # 添加条件边
- graph.add_conditional_edges(
- "agent",
- self._sync_should_continue,
- {
- "tools": "prepare_tool_input",
- "end": "format_final_response"
- }
- )
- # 添加普通边
- graph.add_edge("prepare_tool_input", "tools")
- graph.add_edge("tools", "update_state_after_tool")
- graph.add_edge("update_state_after_tool", "agent")
- graph.add_edge("format_final_response", END)
- # 关键:使用同步编译,不传入checkpointer
- return graph.compile()
- def _sync_agent_node(self, state: AgentState) -> Dict[str, Any]:
- """同步Agent节点"""
- logger.info(f"🧠 [Sync Node] agent - Thread: {state.get('thread_id', 'unknown')}")
-
- messages_for_llm = state["messages"].copy()
-
- # 添加数据库范围提示词
- if isinstance(state["messages"][-1], HumanMessage):
- db_scope_prompt = self._get_database_scope_prompt()
- if db_scope_prompt:
- messages_for_llm.insert(0, SystemMessage(content=db_scope_prompt))
- logger.info(" ✅ 已添加数据库范围判断提示词")
- # 同步LLM调用
- response = self.llm_with_tools.invoke(messages_for_llm)
-
- return {"messages": [response]}
- def _sync_should_continue(self, state: AgentState):
- """同步条件判断"""
- messages = state["messages"]
- last_message = messages[-1]
-
- if not last_message.tool_calls:
- return "end"
- else:
- return "tools"
- def _sync_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
- """同步准备工具输入节点"""
- logger.info(f"🔧 [Sync Node] prepare_tool_input - Thread: {state.get('thread_id', 'unknown')}")
-
- last_message = state["messages"][-1]
-
- if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
- for tool_call in last_message.tool_calls:
- if tool_call.get('name') == 'generate_sql':
- # 注入历史消息
- history_messages = self._filter_and_format_history(state["messages"])
- if 'args' not in tool_call:
- tool_call['args'] = {}
- tool_call['args']['history_messages'] = history_messages
- logger.info(f" ✅ 为generate_sql注入了 {len(history_messages)} 条历史消息")
- return {"messages": [last_message]}
- def _sync_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
- """同步更新工具执行后的状态"""
- logger.info(f"📝 [Sync Node] update_state_after_tool - Thread: {state.get('thread_id', 'unknown')}")
-
- last_message = state["messages"][-1]
- tool_name = last_message.name
- tool_output = last_message.content
- next_step = None
- if tool_name == 'generate_sql':
- tool_output_lower = tool_output.lower()
- if "failed" in tool_output_lower or "无法生成" in tool_output_lower or "失败" in tool_output_lower:
- next_step = 'answer_with_common_sense'
- else:
- next_step = 'valid_sql'
- elif tool_name == 'valid_sql':
- if "失败" in tool_output:
- next_step = 'analyze_validation_error'
- else:
- next_step = 'run_sql'
- elif tool_name == 'run_sql':
- next_step = 'summarize_final_answer'
-
- logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
- return {"suggested_next_step": next_step}
- def _sync_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
- """同步格式化最终响应节点"""
- logger.info(f"📄 [Sync Node] format_final_response - Thread: {state.get('thread_id', 'unknown')}")
-
- messages = state["messages"]
- last_message = messages[-1]
-
- # 构建最终响应
- final_response = last_message.content
-
- logger.info(f" ✅ 最终响应已准备完成")
- return {"final_answer": final_response}
- def _filter_and_format_history(self, messages: list) -> list:
- """过滤和格式化历史消息"""
- clean_history = []
- for msg in messages[:-1]: # 排除最后一条消息
- if isinstance(msg, HumanMessage):
- clean_history.append({"type": "human", "content": msg.content})
- elif isinstance(msg, AIMessage):
- clean_content = msg.content if not hasattr(msg, 'tool_calls') or not msg.tool_calls else ""
- if clean_content.strip():
- clean_history.append({"type": "ai", "content": clean_content})
-
- return clean_history
- def _get_database_scope_prompt(self) -> str:
- """获取数据库范围判断提示词"""
- return """你是一个专门处理高速公路收费数据查询的AI助手。在回答用户问题时,请首先判断这个问题是否可以通过查询数据库来回答。
- 数据库包含以下类型的数据:
- - 服务区信息(名称、位置、档口数量等)
- - 收费站数据
- - 车流量统计
- - 业务数据分析
- 如果用户的问题与这些数据相关,请使用工具生成SQL查询。
- 如果问题与数据库内容无关(如常识性问题、天气、新闻等),请直接用你的知识回答,不要尝试生成SQL。"""
- def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
- """
- 同步聊天方法 - 关键:使用 graph.invoke() 而不是 ainvoke()
- """
- if thread_id is None:
- import uuid
- thread_id = str(uuid.uuid4())
- # 构建输入
- inputs = {
- "messages": [HumanMessage(content=message)],
- "user_id": user_id,
- "thread_id": thread_id,
- "suggested_next_step": None
- }
- # 构建运行配置(不使用checkpointer)
- run_config = {
- "recursion_limit": config.RECURSION_LIMIT,
- }
- try:
- logger.info(f"🚀 开始同步处理用户消息: {message[:50]}...")
-
- # 关键:使用同步的 invoke() 方法
- final_state = self.agent_executor.invoke(inputs, run_config)
-
- logger.info(f"🔍 Final state keys: {list(final_state.keys())}")
-
- # 提取答案
- if final_state["messages"]:
- answer = final_state["messages"][-1].content
- else:
- answer = "抱歉,无法处理您的请求。"
-
- # 提取SQL数据(如果有)
- sql_data = self._extract_latest_sql_data(final_state["messages"])
-
- logger.info(f"✅ 同步处理完成 - Final Answer: '{answer[:100]}...'")
-
- # 构建返回结果
- result = {
- "success": True,
- "answer": answer,
- "thread_id": thread_id
- }
-
- # 只有当存在SQL数据时才添加到返回结果中
- if sql_data:
- try:
- # 尝试解析SQL数据
- sql_parsed = json.loads(sql_data)
-
- # 检查数据格式:run_sql工具返回的是数组格式 [{"col1":"val1"}]
- if isinstance(sql_parsed, list):
- # 数组格式:直接作为records使用
- result["api_data"] = {
- "response": answer,
- "records": sql_parsed,
- "react_agent_meta": {
- "thread_id": thread_id,
- "agent_version": "sync_react_v1"
- }
- }
- elif isinstance(sql_parsed, dict):
- # 字典格式:按原逻辑处理
- result["api_data"] = {
- "response": answer,
- "sql": sql_parsed.get("sql", ""),
- "records": sql_parsed.get("records", []),
- "react_agent_meta": {
- "thread_id": thread_id,
- "agent_version": "sync_react_v1"
- }
- }
- else:
- logger.warning(f"SQL数据格式未知: {type(sql_parsed)}")
- raise ValueError("Unknown SQL data format")
-
- except (json.JSONDecodeError, AttributeError, ValueError) as e:
- logger.warning(f"SQL数据格式处理失败: {str(e)}, 跳过API数据构建")
- else:
- result["api_data"] = {
- "response": answer,
- "react_agent_meta": {
- "thread_id": thread_id,
- "agent_version": "sync_react_v1"
- }
- }
-
- return result
-
- except Exception as e:
- logger.error(f"❌ 同步处理失败: {str(e)}", exc_info=True)
- return {
- "success": False,
- "error": f"同步处理失败: {str(e)}",
- "thread_id": thread_id,
- "retry_suggested": True
- }
- def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
- """从消息历史中提取最近的run_sql执行结果(同步版本)"""
- logger.info("🔍 提取最新的SQL执行结果...")
-
- # 查找最后一个HumanMessage之后的SQL执行结果
- last_human_index = -1
- for i in range(len(messages) - 1, -1, -1):
- if isinstance(messages[i], HumanMessage):
- last_human_index = i
- break
-
- if last_human_index == -1:
- logger.info(" 未找到用户消息,跳过SQL数据提取")
- return None
-
- # 只在当前对话轮次中查找SQL结果
- current_conversation = messages[last_human_index:]
- logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")
-
- for msg in reversed(current_conversation):
- if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
- logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
-
- try:
- # 尝试解析JSON以验证格式
- parsed_data = json.loads(msg.content)
- # 重新序列化,确保中文字符正常显示
- formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
- logger.info(f" 已转换Unicode转义序列为中文字符")
- return formatted_content
- except json.JSONDecodeError:
- # 如果不是有效JSON,直接返回原内容
- logger.warning(f" SQL结果不是有效JSON格式,返回原始内容")
- return msg.content
-
- logger.info(" 当前对话轮次中未找到run_sql执行结果")
- return None
|