|
@@ -0,0 +1,1246 @@
|
|
|
+让我仔细分析你的代码,回答这两个关键问题:
|
|
|
+
|
|
|
+## 1. CustomReactAgent类异步化分析
|
|
|
+
|
|
|
+### ✅ 已经异步化的方法:
|
|
|
+```python
|
|
|
+@classmethod
|
|
|
+async def create(cls): # ✅ 异步工厂方法
|
|
|
+async def _async_init(self): # ✅ 异步初始化
|
|
|
+async def close(self): # ✅ 异步清理
|
|
|
+async def chat(self, message, user_id, thread_id): # ✅ 异步聊天
|
|
|
+async def get_conversation_history(self, thread_id): # ✅ 异步获取历史
|
|
|
+async def get_user_recent_conversations(self, user_id, limit): # ✅ 异步获取对话列表
|
|
|
+```
|
|
|
+
|
|
|
+### ⚠️ **仍然是同步的关键方法(问题所在):**
|
|
|
+```python
|
|
|
+def _should_continue(self, state: AgentState) -> str: # ❌ 同步
|
|
|
+def _agent_node(self, state: AgentState) -> Dict[str, Any]: # ❌ 同步
|
|
|
+def _prepare_tool_input_node(self, state: AgentState): # ❌ 同步
|
|
|
+def _update_state_after_tool_node(self, state: AgentState): # ❌ 同步
|
|
|
+def _format_final_response_node(self, state: AgentState): # ❌ 同步
|
|
|
+```
|
|
|
+
|
|
|
+**问题分析**:你的StateGraph节点是同步的,但是内部调用了异步LLM:
|
|
|
+```python
|
|
|
+def _agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ # 这是同步函数,但调用异步LLM
|
|
|
+ response = self.llm_with_tools.invoke(messages_for_llm) # 🔥 这里有问题
|
|
|
+```
|
|
|
+
|
|
|
+## 2. Redis客户端使用分析
|
|
|
+
|
|
|
+### 你的代码中混用了两种Redis客户端:
|
|
|
+
|
|
|
+**异步Redis客户端(在agent.py中):**
|
|
|
+```python
|
|
|
+from redis.asyncio import Redis
|
|
|
+redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
+```
|
|
|
+
|
|
|
+**同步Redis客户端(在api.py和enhanced_redis_api.py中):**
|
|
|
+```python
|
|
|
+import redis # 同步版本
|
|
|
+redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
|
|
|
+```
|
|
|
+
|
|
|
+### 🎯 **关键问题:混用导致的冲突**
|
|
|
+
|
|
|
+## 完整的修复方案## 回答你的两个关键问题:
|
|
|
+
|
|
|
+### 1. CustomReactAgent类异步化分析
|
|
|
+
|
|
|
+**❌ 你的CustomReactAgent并没有完全异步化!**
|
|
|
+
|
|
|
+虽然主要方法是async的,但**StateGraph的节点方法都是同步的**,这正是问题的根源:
|
|
|
+
|
|
|
+```python
|
|
|
+# ❌ 这些是同步方法,但内部调用异步LLM
|
|
|
+def _agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ response = self.llm_with_tools.invoke(messages_for_llm) # 🔥问题在这里
|
|
|
+
|
|
|
+def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+def _should_continue(self, state: AgentState) -> str:
|
|
|
+```
|
|
|
+
|
|
|
+**核心问题**:LangGraph的StateGraph可以支持异步节点,但你没有使用!
|
|
|
+
|
|
|
+### 2. Redis客户端使用分析
|
|
|
+
|
|
|
+**❌ 你混用了同步和异步Redis客户端!**
|
|
|
+
|
|
|
+**异步Redis客户端(在agent.py中):**
|
|
|
+```python
|
|
|
+from redis.asyncio import Redis # ✅ 异步版本
|
|
|
+redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
+```
|
|
|
+
|
|
|
+**同步Redis客户端(在api.py和enhanced_redis_api.py中):**
|
|
|
+```python
|
|
|
+import redis # ❌ 同步版本
|
|
|
+redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
|
|
|
+```
|
|
|
+
|
|
|
+## 🎯 **在你的场景中是否一定要用异步Redis?**
|
|
|
+
|
|
|
+**答案:强烈建议使用,原因如下:**
|
|
|
+
|
|
|
+1. **架构一致性**:既然Flask路由是异步的,Redis操作也应该异步
|
|
|
+2. **避免阻塞**:同步Redis操作会阻塞事件循环,影响并发性能
|
|
|
+3. **消除错误源**:混用是导致"Event loop is closed"的重要原因
|
|
|
+4. **对话机器人特性**:用户期望快速响应,异步Redis能提供更好的并发能力
|
|
|
+
|
|
|
+## 🔧 **最关键的修改点**
|
|
|
+
|
|
|
+我在上面的artifact中提供了完整的修复方案,关键修改包括:
|
|
|
+
|
|
|
+### **1. StateGraph节点全部改为异步:**
|
|
|
+```python
|
|
|
+# ✅ 修改后:异步节点
|
|
|
+async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ response = await self.llm_with_tools.ainvoke(messages_for_llm) # 使用ainvoke
|
|
|
+
|
|
|
+async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+async def _async_should_continue(self, state: AgentState) -> str:
|
|
|
+```
|
|
|
+
|
|
|
+### **2. 统一Redis客户端:**
|
|
|
+```python
|
|
|
+# ✅ 统一使用异步Redis
|
|
|
+import redis.asyncio as redis
|
|
|
+self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
+await self.redis_client.ping()
|
|
|
+```
|
|
|
+
|
|
|
+### **3. 移除复杂的事件循环管理:**
|
|
|
+```python
|
|
|
+# ❌ 删除这些函数
|
|
|
+def run_async_safely(): # 删除
|
|
|
+def ensure_agent_ready_sync(): # 删除
|
|
|
+
|
|
|
+# ✅ 直接使用async/await
|
|
|
+@app.route("/api/chat", methods=["POST"])
|
|
|
+async def chat_endpoint():
|
|
|
+ agent_result = await _agent_instance.chat(...) # 直接调用
|
|
|
+```
|
|
|
+
|
|
|
+## 📋 **修改建议总结**
|
|
|
+
|
|
|
+1. **立即修改**:将StateGraph所有节点改为async版本
|
|
|
+2. **统一Redis**:全部使用`redis.asyncio`
|
|
|
+3. **删除包装**:移除所有`run_async_safely`类的函数
|
|
|
+4. **Flask路由**:核心API路由改为async
|
|
|
+5. **测试验证**:修改后测试Event loop错误是否消失
|
|
|
+
|
|
|
+**按照我提供的修复方案修改后,你就能直接用`python api.py`启动,不会再有任何事件循环问题!**
|
|
|
+
|
|
|
+
|
|
|
+"""
|
|
|
+CustomReactAgent完全异步化修复方案
|
|
|
+解决StateGraph节点同步/异步混用问题,统一Redis客户端使用
|
|
|
+"""
|
|
|
+
|
|
|
+# ============= 修复后的 agent.py =============
|
|
|
+import logging
|
|
|
+import json
|
|
|
+import pandas as pd
|
|
|
+from typing import List, Optional, Dict, Any, Tuple
|
|
|
+from contextlib import AsyncExitStack
|
|
|
+
|
|
|
+from langchain_openai import ChatOpenAI
|
|
|
+from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
|
|
|
+from langgraph.graph import StateGraph, END
|
|
|
+from langgraph.prebuilt import ToolNode
|
|
|
+import redis.asyncio as redis # 🔥 统一使用异步Redis
|
|
|
+try:
|
|
|
+ from langgraph.checkpoint.redis import AsyncRedisSaver
|
|
|
+except ImportError:
|
|
|
+ AsyncRedisSaver = None
|
|
|
+
|
|
|
+# 从新模块导入配置、状态和工具
|
|
|
+try:
|
|
|
+ from . import config
|
|
|
+ from .state import AgentState
|
|
|
+ from .sql_tools import sql_tools
|
|
|
+except ImportError:
|
|
|
+ import config
|
|
|
+ from state import AgentState
|
|
|
+ from sql_tools import sql_tools
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+class CustomReactAgent:
|
|
|
+ """
|
|
|
+ 完全异步化的 CustomReactAgent
|
|
|
+ 所有节点方法都是异步的,统一使用异步Redis客户端
|
|
|
+ """
|
|
|
+ def __init__(self):
|
|
|
+ """私有构造函数,请使用 create() 类方法来创建实例。"""
|
|
|
+ self.llm = None
|
|
|
+ self.tools = None
|
|
|
+ self.agent_executor = None
|
|
|
+ self.checkpointer = None
|
|
|
+ self._exit_stack = None
|
|
|
+ self.redis_client = None # 🔥 添加Redis客户端引用
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ async def create(cls):
|
|
|
+ """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
|
|
|
+ instance = cls()
|
|
|
+ await instance._async_init()
|
|
|
+ return instance
|
|
|
+
|
|
|
+ async def _async_init(self):
|
|
|
+ """异步初始化所有组件。"""
|
|
|
+ logger.info("🚀 开始初始化 CustomReactAgent...")
|
|
|
+
|
|
|
+ # 1. 初始化异步Redis客户端
|
|
|
+ self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
|
|
|
+ try:
|
|
|
+ await self.redis_client.ping()
|
|
|
+ logger.info(f" ✅ Redis连接成功: {config.REDIS_URL}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f" ❌ Redis连接失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ # 2. 初始化 LLM
|
|
|
+ self.llm = ChatOpenAI(
|
|
|
+ api_key=config.QWEN_API_KEY,
|
|
|
+ base_url=config.QWEN_BASE_URL,
|
|
|
+ model=config.QWEN_MODEL,
|
|
|
+ temperature=0.1,
|
|
|
+ timeout=config.NETWORK_TIMEOUT,
|
|
|
+ max_retries=config.MAX_RETRIES,
|
|
|
+ extra_body={
|
|
|
+ "enable_thinking": False,
|
|
|
+ "misc": {
|
|
|
+ "ensure_ascii": False
|
|
|
+ }
|
|
|
+ }
|
|
|
+ )
|
|
|
+ logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")
|
|
|
+
|
|
|
+ # 3. 绑定工具
|
|
|
+ self.tools = sql_tools
|
|
|
+ self.llm_with_tools = self.llm.bind_tools(self.tools)
|
|
|
+ logger.info(f" 已绑定 {len(self.tools)} 个工具。")
|
|
|
+
|
|
|
+ # 4. 初始化 Redis Checkpointer
|
|
|
+ if config.REDIS_ENABLED and AsyncRedisSaver is not None:
|
|
|
+ try:
|
|
|
+ self._exit_stack = AsyncExitStack()
|
|
|
+ checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
|
|
|
+ self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
|
|
|
+ await self.checkpointer.asetup()
|
|
|
+ logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
|
|
|
+ if self._exit_stack:
|
|
|
+ await self._exit_stack.aclose()
|
|
|
+ self.checkpointer = None
|
|
|
+ else:
|
|
|
+ logger.warning(" Redis 持久化功能已禁用。")
|
|
|
+
|
|
|
+ # 5. 构建 StateGraph
|
|
|
+ self.agent_executor = self._create_graph()
|
|
|
+ logger.info(" StateGraph 已构建并编译。")
|
|
|
+ logger.info("✅ CustomReactAgent 初始化完成。")
|
|
|
+
|
|
|
+ async def close(self):
|
|
|
+ """清理资源,关闭所有连接。"""
|
|
|
+ if self._exit_stack:
|
|
|
+ await self._exit_stack.aclose()
|
|
|
+ self._exit_stack = None
|
|
|
+ self.checkpointer = None
|
|
|
+ logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
|
|
|
+
|
|
|
+ if self.redis_client:
|
|
|
+ await self.redis_client.aclose()
|
|
|
+ logger.info("✅ Redis客户端已关闭。")
|
|
|
+
|
|
|
+ def _create_graph(self):
|
|
|
+ """定义并编译最终的、正确的 StateGraph 结构。"""
|
|
|
+ builder = StateGraph(AgentState)
|
|
|
+
|
|
|
+ # 🔥 关键修改:所有节点都是异步的
|
|
|
+ builder.add_node("agent", self._async_agent_node)
|
|
|
+ builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
|
|
|
+ builder.add_node("tools", ToolNode(self.tools))
|
|
|
+ builder.add_node("update_state_after_tool", self._async_update_state_after_tool_node)
|
|
|
+ builder.add_node("format_final_response", self._async_format_final_response_node)
|
|
|
+
|
|
|
+ # 建立正确的边连接
|
|
|
+ builder.set_entry_point("agent")
|
|
|
+ builder.add_conditional_edges(
|
|
|
+ "agent",
|
|
|
+ self._async_should_continue, # 🔥 异步条件判断
|
|
|
+ {
|
|
|
+ "continue": "prepare_tool_input",
|
|
|
+ "end": "format_final_response"
|
|
|
+ }
|
|
|
+ )
|
|
|
+ builder.add_edge("prepare_tool_input", "tools")
|
|
|
+ builder.add_edge("tools", "update_state_after_tool")
|
|
|
+ builder.add_edge("update_state_after_tool", "agent")
|
|
|
+ builder.add_edge("format_final_response", END)
|
|
|
+
|
|
|
+ return builder.compile(checkpointer=self.checkpointer)
|
|
|
+
|
|
|
+ async def _async_should_continue(self, state: AgentState) -> str:
|
|
|
+ """🔥 异步版本:判断是继续调用工具还是结束。"""
|
|
|
+ last_message = state["messages"][-1]
|
|
|
+ if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
|
|
+ return "continue"
|
|
|
+ return "end"
|
|
|
+
|
|
|
+ async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:Agent 节点,使用异步LLM调用。"""
|
|
|
+ logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']}")
|
|
|
+
|
|
|
+ messages_for_llm = list(state["messages"])
|
|
|
+ if state.get("suggested_next_step"):
|
|
|
+ instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
|
|
|
+ messages_for_llm.append(SystemMessage(content=instruction))
|
|
|
+
|
|
|
+ # 🔥 关键修改:使用异步LLM调用
|
|
|
+ import time
|
|
|
+ max_retries = config.MAX_RETRIES
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ try:
|
|
|
+ # 使用异步调用
|
|
|
+ response = await self.llm_with_tools.ainvoke(messages_for_llm)
|
|
|
+ logger.info(f" ✅ 异步LLM调用成功")
|
|
|
+ return {"messages": [response]}
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ error_msg = str(e)
|
|
|
+ logger.warning(f" ⚠️ 异步LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {error_msg}")
|
|
|
+
|
|
|
+ if any(keyword in error_msg for keyword in [
|
|
|
+ "Connection error", "APIConnectionError", "ConnectError",
|
|
|
+ "timeout", "远程主机强迫关闭", "网络连接"
|
|
|
+ ]):
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ wait_time = config.RETRY_BASE_DELAY ** attempt
|
|
|
+ logger.info(f" 🔄 网络错误,{wait_time}秒后重试...")
|
|
|
+ await asyncio.sleep(wait_time) # 🔥 使用async sleep
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ logger.error(f" ❌ 网络连接持续失败,返回降级回答")
|
|
|
+ sql_data = await self._async_extract_latest_sql_data(state["messages"])
|
|
|
+ if sql_data:
|
|
|
+ fallback_content = "抱歉,由于网络连接问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
|
|
|
+ else:
|
|
|
+ fallback_content = "抱歉,由于网络连接问题,无法完成此次请求。请稍后重试或检查网络连接。"
|
|
|
+
|
|
|
+ fallback_response = AIMessage(content=fallback_content)
|
|
|
+ return {"messages": [fallback_response]}
|
|
|
+ else:
|
|
|
+ logger.error(f" ❌ LLM调用出现非网络错误: {error_msg}")
|
|
|
+ raise e
|
|
|
+
|
|
|
+ async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:信息组装节点。"""
|
|
|
+ logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")
|
|
|
+
|
|
|
+ last_message = state["messages"][-1]
|
|
|
+ if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
|
+ return {"messages": [last_message]}
|
|
|
+
|
|
|
+ new_tool_calls = []
|
|
|
+ for tool_call in last_message.tool_calls:
|
|
|
+ if tool_call["name"] == "generate_sql":
|
|
|
+ logger.info(" 检测到 generate_sql 调用,注入历史消息。")
|
|
|
+ modified_args = tool_call["args"].copy()
|
|
|
+
|
|
|
+ clean_history = []
|
|
|
+ messages_except_current = state["messages"][:-1]
|
|
|
+
|
|
|
+ for msg in messages_except_current:
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
+ clean_history.append({
|
|
|
+ "type": "human",
|
|
|
+ "content": msg.content
|
|
|
+ })
|
|
|
+ elif isinstance(msg, AIMessage):
|
|
|
+ if msg.content and "[Formatted Output]" in msg.content:
|
|
|
+ clean_content = msg.content.replace("[Formatted Output]\n", "")
|
|
|
+ clean_history.append({
|
|
|
+ "type": "ai",
|
|
|
+ "content": clean_content
|
|
|
+ })
|
|
|
+
|
|
|
+ modified_args["history_messages"] = clean_history
|
|
|
+ logger.info(f" 注入了 {len(clean_history)} 条过滤后的历史消息")
|
|
|
+
|
|
|
+ new_tool_calls.append({
|
|
|
+ "name": tool_call["name"],
|
|
|
+ "args": modified_args,
|
|
|
+ "id": tool_call["id"],
|
|
|
+ })
|
|
|
+ else:
|
|
|
+ new_tool_calls.append(tool_call)
|
|
|
+
|
|
|
+ last_message.tool_calls = new_tool_calls
|
|
|
+ return {"messages": [last_message]}
|
|
|
+
|
|
|
+ async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:在工具执行后,更新 suggested_next_step。"""
|
|
|
+ logger.info(f"📝 [Async Node] update_state_after_tool - Thread: {state['thread_id']}")
|
|
|
+
|
|
|
+ last_tool_message = state['messages'][-1]
|
|
|
+ tool_name = last_tool_message.name
|
|
|
+ tool_output = last_tool_message.content
|
|
|
+ next_step = None
|
|
|
+
|
|
|
+ if tool_name == 'generate_sql':
|
|
|
+ if "失败" in tool_output or "无法生成" in tool_output:
|
|
|
+ next_step = 'answer_with_common_sense'
|
|
|
+ else:
|
|
|
+ next_step = 'valid_sql'
|
|
|
+ elif tool_name == 'valid_sql':
|
|
|
+ if "失败" in tool_output:
|
|
|
+ next_step = 'analyze_validation_error'
|
|
|
+ else:
|
|
|
+ next_step = 'run_sql'
|
|
|
+ elif tool_name == 'run_sql':
|
|
|
+ next_step = 'summarize_final_answer'
|
|
|
+
|
|
|
+ logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
|
|
|
+ return {"suggested_next_step": next_step}
|
|
|
+
|
|
|
+ async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:最终输出格式化节点。"""
|
|
|
+ logger.info(f"🎨 [Async Node] format_final_response - Thread: {state['thread_id']}")
|
|
|
+
|
|
|
+ last_message = state['messages'][-1]
|
|
|
+ last_message.content = f"[Formatted Output]\n{last_message.content}"
|
|
|
+
|
|
|
+ # 生成API格式的数据
|
|
|
+ api_data = await self._async_generate_api_data(state)
|
|
|
+
|
|
|
+ return {
|
|
|
+ "messages": [last_message],
|
|
|
+ "api_data": api_data
|
|
|
+ }
|
|
|
+
|
|
|
+ async def _async_generate_api_data(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:生成API格式的数据结构"""
|
|
|
+ logger.info("📊 异步生成API格式数据...")
|
|
|
+
|
|
|
+ last_message = state['messages'][-1]
|
|
|
+ response_content = last_message.content
|
|
|
+
|
|
|
+ if response_content.startswith("[Formatted Output]\n"):
|
|
|
+ response_content = response_content.replace("[Formatted Output]\n", "")
|
|
|
+
|
|
|
+ api_data = {
|
|
|
+ "response": response_content
|
|
|
+ }
|
|
|
+
|
|
|
+ sql_info = await self._async_extract_sql_and_data(state['messages'])
|
|
|
+ if sql_info['sql']:
|
|
|
+ api_data["sql"] = sql_info['sql']
|
|
|
+ if sql_info['records']:
|
|
|
+ api_data["records"] = sql_info['records']
|
|
|
+
|
|
|
+ api_data["react_agent_meta"] = await self._async_collect_agent_metadata(state)
|
|
|
+
|
|
|
+ logger.info(f" API数据生成完成,包含字段: {list(api_data.keys())}")
|
|
|
+ return api_data
|
|
|
+
|
|
|
+ async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:从消息历史中提取SQL和数据记录"""
|
|
|
+ result = {"sql": None, "records": None}
|
|
|
+
|
|
|
+ last_human_index = -1
|
|
|
+ for i in range(len(messages) - 1, -1, -1):
|
|
|
+ if isinstance(messages[i], HumanMessage):
|
|
|
+ last_human_index = i
|
|
|
+ break
|
|
|
+
|
|
|
+ if last_human_index == -1:
|
|
|
+ return result
|
|
|
+
|
|
|
+ current_conversation = messages[last_human_index:]
|
|
|
+ sql_query = None
|
|
|
+ sql_data = None
|
|
|
+
|
|
|
+ for msg in current_conversation:
|
|
|
+ if isinstance(msg, ToolMessage):
|
|
|
+ if msg.name == 'generate_sql':
|
|
|
+ content = msg.content
|
|
|
+ if content and not any(keyword in content for keyword in ["失败", "无法生成", "Database query failed"]):
|
|
|
+ sql_query = content.strip()
|
|
|
+ elif msg.name == 'run_sql':
|
|
|
+ try:
|
|
|
+ import json
|
|
|
+ parsed_data = json.loads(msg.content)
|
|
|
+ if isinstance(parsed_data, list) and len(parsed_data) > 0:
|
|
|
+ columns = list(parsed_data[0].keys()) if parsed_data else []
|
|
|
+ sql_data = {
|
|
|
+ "columns": columns,
|
|
|
+ "rows": parsed_data,
|
|
|
+ "total_row_count": len(parsed_data),
|
|
|
+ "is_limited": False
|
|
|
+ }
|
|
|
+ except (json.JSONDecodeError, Exception) as e:
|
|
|
+ logger.warning(f" 解析SQL结果失败: {e}")
|
|
|
+
|
|
|
+ if sql_query:
|
|
|
+ result["sql"] = sql_query
|
|
|
+ if sql_data:
|
|
|
+ result["records"] = sql_data
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+ async def _async_collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
|
|
|
+ """🔥 异步版本:收集Agent元数据"""
|
|
|
+ messages = state['messages']
|
|
|
+
|
|
|
+ tools_used = []
|
|
|
+ sql_execution_count = 0
|
|
|
+ context_injected = False
|
|
|
+ conversation_rounds = sum(1 for msg in messages if isinstance(msg, HumanMessage))
|
|
|
+
|
|
|
+ for msg in messages:
|
|
|
+ if isinstance(msg, ToolMessage):
|
|
|
+ if msg.name not in tools_used:
|
|
|
+ tools_used.append(msg.name)
|
|
|
+ if msg.name == 'run_sql':
|
|
|
+ sql_execution_count += 1
|
|
|
+ elif isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
|
|
|
+ for tool_call in msg.tool_calls:
|
|
|
+ tool_name = tool_call.get('name')
|
|
|
+ if tool_name and tool_name not in tools_used:
|
|
|
+ tools_used.append(tool_name)
|
|
|
+
|
|
|
+ if (tool_name == 'generate_sql' and
|
|
|
+ tool_call.get('args', {}).get('history_messages')):
|
|
|
+ context_injected = True
|
|
|
+
|
|
|
+ execution_path = ["agent"]
|
|
|
+ if tools_used:
|
|
|
+ execution_path.extend(["prepare_tool_input", "tools"])
|
|
|
+ execution_path.append("format_final_response")
|
|
|
+
|
|
|
+ return {
|
|
|
+ "thread_id": state['thread_id'],
|
|
|
+ "conversation_rounds": conversation_rounds,
|
|
|
+ "tools_used": tools_used,
|
|
|
+ "execution_path": execution_path,
|
|
|
+ "total_messages": len(messages),
|
|
|
+ "sql_execution_count": sql_execution_count,
|
|
|
+ "context_injected": context_injected,
|
|
|
+ "agent_version": "custom_react_v1_async"
|
|
|
+ }
|
|
|
+
|
|
|
+ async def _async_extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
|
|
|
+ """🔥 异步版本:提取最新的SQL执行结果"""
|
|
|
+ logger.info("🔍 异步提取最新的SQL执行结果...")
|
|
|
+
|
|
|
+ last_human_index = -1
|
|
|
+ for i in range(len(messages) - 1, -1, -1):
|
|
|
+ if isinstance(messages[i], HumanMessage):
|
|
|
+ last_human_index = i
|
|
|
+ break
|
|
|
+
|
|
|
+ if last_human_index == -1:
|
|
|
+ logger.info(" 未找到用户消息,跳过SQL数据提取")
|
|
|
+ return None
|
|
|
+
|
|
|
+ current_conversation = messages[last_human_index:]
|
|
|
+ logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")
|
|
|
+
|
|
|
+ for msg in reversed(current_conversation):
|
|
|
+ if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
|
|
|
+ logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
|
|
|
+
|
|
|
+ try:
|
|
|
+ parsed_data = json.loads(msg.content)
|
|
|
+ formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
|
|
|
+ logger.info(f" 已转换Unicode转义序列为中文字符")
|
|
|
+ return formatted_content
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ logger.warning(f" SQL结果不是有效JSON格式,返回原始内容")
|
|
|
+ return msg.content
|
|
|
+
|
|
|
+ logger.info(" 当前对话轮次中未找到run_sql执行结果")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
|
|
|
+ """🔥 完全异步的聊天处理方法"""
|
|
|
+ if not thread_id:
|
|
|
+ now = pd.Timestamp.now()
|
|
|
+ milliseconds = int(now.microsecond / 1000)
|
|
|
+ thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
|
|
|
+ logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
|
|
|
+
|
|
|
+ config = {
|
|
|
+ "configurable": {
|
|
|
+ "thread_id": thread_id,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ inputs = {
|
|
|
+ "messages": [HumanMessage(content=message)],
|
|
|
+ "user_id": user_id,
|
|
|
+ "thread_id": thread_id,
|
|
|
+ "suggested_next_step": None,
|
|
|
+ }
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 🔥 使用异步调用
|
|
|
+ final_state = await self.agent_executor.ainvoke(inputs, config)
|
|
|
+ answer = final_state["messages"][-1].content
|
|
|
+
|
|
|
+ sql_data = await self._async_extract_latest_sql_data(final_state["messages"])
|
|
|
+
|
|
|
+ logger.info(f"✅ 异步处理完成 - Final Answer: '{answer}'")
|
|
|
+
|
|
|
+ result = {
|
|
|
+ "success": True,
|
|
|
+ "answer": answer,
|
|
|
+ "thread_id": thread_id
|
|
|
+ }
|
|
|
+
|
|
|
+ if sql_data:
|
|
|
+ result["sql_data"] = sql_data
|
|
|
+ logger.info(" 📊 已包含SQL原始数据")
|
|
|
+
|
|
|
+ if "api_data" in final_state:
|
|
|
+ result["api_data"] = final_state["api_data"]
|
|
|
+ logger.info(" 🔌 已包含API格式数据")
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 异步处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
|
|
|
+ return {"success": False, "error": str(e), "thread_id": thread_id}
|
|
|
+
|
|
|
+ async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
|
|
|
+ """🔥 完全异步的对话历史获取"""
|
|
|
+ if not self.checkpointer:
|
|
|
+ return []
|
|
|
+
|
|
|
+ config = {"configurable": {"thread_id": thread_id}}
|
|
|
+ try:
|
|
|
+ conversation_state = await self.checkpointer.aget(config)
|
|
|
+ except RuntimeError as e:
|
|
|
+ if "Event loop is closed" in str(e):
|
|
|
+ logger.warning(f"⚠️ Event loop已关闭,返回空结果: {thread_id}")
|
|
|
+ return []
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+
|
|
|
+ if not conversation_state:
|
|
|
+ return []
|
|
|
+
|
|
|
+ history = []
|
|
|
+ messages = conversation_state.get('channel_values', {}).get('messages', [])
|
|
|
+ for msg in messages:
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
+ role = "human"
|
|
|
+ elif isinstance(msg, ToolMessage):
|
|
|
+ role = "tool"
|
|
|
+ else:
|
|
|
+ role = "ai"
|
|
|
+
|
|
|
+ history.append({
|
|
|
+ "type": role,
|
|
|
+ "content": msg.content,
|
|
|
+ "tool_calls": getattr(msg, 'tool_calls', None)
|
|
|
+ })
|
|
|
+ return history
|
|
|
+
|
|
|
+ async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
|
|
+ """🔥 完全异步的用户对话列表获取"""
|
|
|
+ if not self.checkpointer:
|
|
|
+ return []
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 🔥 使用统一的异步Redis客户端
|
|
|
+ pattern = f"checkpoint:{user_id}:*"
|
|
|
+ logger.info(f"🔍 异步扫描模式: {pattern}")
|
|
|
+
|
|
|
+ user_threads = {}
|
|
|
+ cursor = 0
|
|
|
+
|
|
|
+ while True:
|
|
|
+ cursor, keys = await self.redis_client.scan(
|
|
|
+ cursor=cursor,
|
|
|
+ match=pattern,
|
|
|
+ count=1000
|
|
|
+ )
|
|
|
+
|
|
|
+ for key in keys:
|
|
|
+ try:
|
|
|
+ key_str = key.decode() if isinstance(key, bytes) else key
|
|
|
+ parts = key_str.split(':')
|
|
|
+
|
|
|
+ if len(parts) >= 4:
|
|
|
+ thread_id = f"{parts[1]}:{parts[2]}"
|
|
|
+ timestamp = parts[2]
|
|
|
+
|
|
|
+ if thread_id not in user_threads:
|
|
|
+ user_threads[thread_id] = {
|
|
|
+ "thread_id": thread_id,
|
|
|
+ "timestamp": timestamp,
|
|
|
+ "latest_key": key_str
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ if len(parts) > 4 and parts[4] > user_threads[thread_id]["latest_key"].split(':')[4]:
|
|
|
+ user_threads[thread_id]["latest_key"] = key_str
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"解析key {key} 失败: {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ if cursor == 0:
|
|
|
+ break
|
|
|
+
|
|
|
+ # 按时间戳排序
|
|
|
+ sorted_threads = sorted(
|
|
|
+ user_threads.values(),
|
|
|
+ key=lambda x: x["timestamp"],
|
|
|
+ reverse=True
|
|
|
+ )[:limit]
|
|
|
+
|
|
|
+ # 获取每个thread的详细信息
|
|
|
+ conversations = []
|
|
|
+ for thread_info in sorted_threads:
|
|
|
+ try:
|
|
|
+ thread_id = thread_info["thread_id"]
|
|
|
+ thread_config = {"configurable": {"thread_id": thread_id}}
|
|
|
+
|
|
|
+ try:
|
|
|
+ state = await self.checkpointer.aget(thread_config)
|
|
|
+ except RuntimeError as e:
|
|
|
+ if "Event loop is closed" in str(e):
|
|
|
+ logger.warning(f"⚠️ Event loop已关闭,跳过thread: {thread_id}")
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+
|
|
|
+ if state and state.get('channel_values', {}).get('messages'):
|
|
|
+ messages = state['channel_values']['messages']
|
|
|
+ preview = self._generate_conversation_preview(messages)
|
|
|
+
|
|
|
+ conversations.append({
|
|
|
+ "thread_id": thread_id,
|
|
|
+ "user_id": user_id,
|
|
|
+ "timestamp": thread_info["timestamp"],
|
|
|
+ "message_count": len(messages),
|
|
|
+ "last_message": messages[-1].content if messages else None,
|
|
|
+ "last_updated": state.get('created_at'),
|
|
|
+ "conversation_preview": preview,
|
|
|
+ "formatted_time": self._format_timestamp(thread_info["timestamp"])
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取thread {thread_info['thread_id']} 详情失败: {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ logger.info(f"✅ 异步找到用户 {user_id} 的 {len(conversations)} 个对话")
|
|
|
+ return conversations
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 异步获取用户 {user_id} 对话列表失败: {e}")
|
|
|
+ return []
|
|
|
+
|
|
|
+ def _generate_conversation_preview(self, messages: List[BaseMessage]) -> str:
|
|
|
+ """生成对话预览(保持同步,因为是纯计算)"""
|
|
|
+ if not messages:
|
|
|
+ return "空对话"
|
|
|
+
|
|
|
+ for msg in messages:
|
|
|
+ if isinstance(msg, HumanMessage):
|
|
|
+ content = str(msg.content)
|
|
|
+ return content[:50] + "..." if len(content) > 50 else content
|
|
|
+
|
|
|
+ return "系统消息"
|
|
|
+
|
|
|
+ def _format_timestamp(self, timestamp: str) -> str:
|
|
|
+ """格式化时间戳为可读格式(保持同步,因为是纯计算)"""
|
|
|
+ try:
|
|
|
+ if len(timestamp) >= 14:
|
|
|
+ year = timestamp[:4]
|
|
|
+ month = timestamp[4:6]
|
|
|
+ day = timestamp[6:8]
|
|
|
+ hour = timestamp[8:10]
|
|
|
+ minute = timestamp[10:12]
|
|
|
+ second = timestamp[12:14]
|
|
|
+ return f"{year}-{month}-{day} {hour}:{minute}:{second}"
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ return timestamp
|
|
|
+
|
|
|
+
|
|
|
+# ============= 修复后的 api.py 关键部分 =============
|
|
|
+
|
|
|
+"""
|
|
|
+修复后的 api.py - 统一使用异步Redis客户端,移除复杂的事件循环管理
|
|
|
+"""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import logging
|
|
|
+import os
|
|
|
+from datetime import datetime
|
|
|
+from typing import Optional, Dict, Any
|
|
|
+
|
|
|
+from flask import Flask, request, jsonify
|
|
|
+import redis.asyncio as redis # 🔥 统一使用异步Redis
|
|
|
+
|
|
|
+try:
|
|
|
+ from .agent import CustomReactAgent
|
|
|
+except ImportError:
|
|
|
+ from agent import CustomReactAgent
|
|
|
+
|
|
|
+logging.basicConfig(level=logging.INFO)
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+# 全局Agent实例
|
|
|
+_agent_instance: Optional[CustomReactAgent] = None
|
|
|
+_redis_client: Optional[redis.Redis] = None
|
|
|
+
|
|
|
+def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
+ """验证请求数据(保持不变)"""
|
|
|
+ errors = []
|
|
|
+
|
|
|
+ question = data.get('question', '')
|
|
|
+ if not question or not question.strip():
|
|
|
+ errors.append('问题不能为空')
|
|
|
+ elif len(question) > 2000:
|
|
|
+ errors.append('问题长度不能超过2000字符')
|
|
|
+
|
|
|
+ user_id = data.get('user_id', 'guest')
|
|
|
+ if user_id and len(user_id) > 50:
|
|
|
+ errors.append('用户ID长度不能超过50字符')
|
|
|
+
|
|
|
+ if errors:
|
|
|
+ raise ValueError('; '.join(errors))
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'question': question.strip(),
|
|
|
+ 'user_id': user_id or 'guest',
|
|
|
+ 'thread_id': data.get('thread_id')
|
|
|
+ }
|
|
|
+
|
|
|
+async def initialize_agent():
|
|
|
+ """🔥 异步初始化Agent"""
|
|
|
+ global _agent_instance, _redis_client
|
|
|
+
|
|
|
+ if _agent_instance is None:
|
|
|
+ logger.info("🚀 正在异步初始化 Custom React Agent...")
|
|
|
+ try:
|
|
|
+ os.environ['REDIS_URL'] = 'redis://localhost:6379'
|
|
|
+
|
|
|
+ # 初始化共享的Redis客户端
|
|
|
+ _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
|
|
|
+ await _redis_client.ping()
|
|
|
+
|
|
|
+ _agent_instance = await CustomReactAgent.create()
|
|
|
+ logger.info("✅ Agent 异步初始化完成")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ Agent 异步初始化失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+async def ensure_agent_ready():
|
|
|
+ """🔥 异步确保Agent实例可用"""
|
|
|
+ global _agent_instance
|
|
|
+
|
|
|
+ if _agent_instance is None:
|
|
|
+ await initialize_agent()
|
|
|
+
|
|
|
+ try:
|
|
|
+ test_result = await _agent_instance.get_user_recent_conversations("__test__", 1)
|
|
|
+ return True
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"⚠️ Agent实例不可用: {e}")
|
|
|
+ _agent_instance = None
|
|
|
+ await initialize_agent()
|
|
|
+ return True
|
|
|
+
|
|
|
+async def cleanup_agent():
|
|
|
+ """🔥 异步清理Agent资源"""
|
|
|
+ global _agent_instance, _redis_client
|
|
|
+
|
|
|
+ if _agent_instance:
|
|
|
+ await _agent_instance.close()
|
|
|
+ logger.info("✅ Agent 资源已异步清理")
|
|
|
+ _agent_instance = None
|
|
|
+
|
|
|
+ if _redis_client:
|
|
|
+ await _redis_client.aclose()
|
|
|
+ logger.info("✅ Redis客户端已异步关闭")
|
|
|
+ _redis_client = None
|
|
|
+
|
|
|
+# 创建Flask应用
|
|
|
+app = Flask(__name__)
|
|
|
+
|
|
|
+# 🔥 移除所有同步包装函数:run_async_safely, ensure_agent_ready_sync
|
|
|
+
|
|
|
+@app.route("/")
|
|
|
+def root():
|
|
|
+ """健康检查端点(保持同步)"""
|
|
|
+ return jsonify({"message": "Custom React Agent API 服务正在运行"})
|
|
|
+
|
|
|
+@app.route('/health', methods=['GET'])
|
|
|
+def health_check():
|
|
|
+ """健康检查端点(保持同步)"""
|
|
|
+ try:
|
|
|
+ health_status = {
|
|
|
+ "status": "healthy",
|
|
|
+ "agent_initialized": _agent_instance is not None,
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }
|
|
|
+ return jsonify(health_status), 200
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"健康检查失败: {e}")
|
|
|
+ return jsonify({"status": "unhealthy", "error": str(e)}), 500
|
|
|
+
|
|
|
+@app.route("/api/chat", methods=["POST"])
|
|
|
+async def chat_endpoint():
|
|
|
+ """🔥 异步智能问答接口"""
|
|
|
+ global _agent_instance
|
|
|
+
|
|
|
+ # 确保Agent已初始化
|
|
|
+ if not await ensure_agent_ready():
|
|
|
+ return jsonify({
|
|
|
+ "code": 503,
|
|
|
+ "message": "服务未就绪",
|
|
|
+ "success": False,
|
|
|
+ "error": "Agent 初始化失败"
|
|
|
+ }), 503
|
|
|
+
|
|
|
+ try:
|
|
|
+ data = request.get_json()
|
|
|
+ if not data:
|
|
|
+ return jsonify({
|
|
|
+ "code": 400,
|
|
|
+ "message": "请求参数错误",
|
|
|
+ "success": False,
|
|
|
+ "error": "请求体不能为空"
|
|
|
+ }), 400
|
|
|
+
|
|
|
+ validated_data = validate_request_data(data)
|
|
|
+
|
|
|
+ logger.info(f"📨 收到请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")
|
|
|
+
|
|
|
+ # 🔥 直接调用异步方法,不需要事件循环包装
|
|
|
+ agent_result = await _agent_instance.chat(
|
|
|
+ message=validated_data['question'],
|
|
|
+ user_id=validated_data['user_id'],
|
|
|
+ thread_id=validated_data['thread_id']
|
|
|
+ )
|
|
|
+
|
|
|
+ if not agent_result.get("success", False):
|
|
|
+ error_msg = agent_result.get("error", "Agent处理失败")
|
|
|
+ logger.error(f"❌ Agent处理失败: {error_msg}")
|
|
|
+
|
|
|
+ return jsonify({
|
|
|
+ "code": 500,
|
|
|
+ "message": "处理失败",
|
|
|
+ "success": False,
|
|
|
+ "error": error_msg,
|
|
|
+ "data": {
|
|
|
+ "react_agent_meta": {
|
|
|
+ "thread_id": agent_result.get("thread_id"),
|
|
|
+ "agent_version": "custom_react_v1_async",
|
|
|
+ "execution_path": ["error"]
|
|
|
+ },
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }
|
|
|
+ }), 500
|
|
|
+
|
|
|
+ api_data = agent_result.get("api_data", {})
|
|
|
+ response_data = {
|
|
|
+ **api_data,
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info(f"✅ 异步请求处理成功 - Thread: {api_data.get('react_agent_meta', {}).get('thread_id')}")
|
|
|
+
|
|
|
+ return jsonify({
|
|
|
+ "code": 200,
|
|
|
+ "message": "操作成功",
|
|
|
+ "success": True,
|
|
|
+ "data": response_data
|
|
|
+ })
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ logger.warning(f"⚠️ 参数验证失败: {e}")
|
|
|
+ return jsonify({
|
|
|
+ "code": 400,
|
|
|
+ "message": "请求参数错误",
|
|
|
+ "success": False,
|
|
|
+ "error": str(e)
|
|
|
+ }), 400
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 未预期的错误: {e}", exc_info=True)
|
|
|
+ return jsonify({
|
|
|
+ "code": 500,
|
|
|
+ "message": "服务器内部错误",
|
|
|
+ "success": False,
|
|
|
+ "error": "系统异常,请稍后重试"
|
|
|
+ }), 500
|
|
|
+
|
|
|
+@app.route('/api/v0/react/users/<user_id>/conversations', methods=['GET'])
|
|
|
+async def get_user_conversations(user_id: str):
|
|
|
+ """🔥 异步获取用户的聊天记录列表"""
|
|
|
+ global _agent_instance
|
|
|
+
|
|
|
+ try:
|
|
|
+ limit = request.args.get('limit', 10, type=int)
|
|
|
+ limit = max(1, min(limit, 50))
|
|
|
+
|
|
|
+ logger.info(f"📋 异步获取用户 {user_id} 的对话列表,限制 {limit} 条")
|
|
|
+
|
|
|
+ if not await ensure_agent_ready():
|
|
|
+ return jsonify({
|
|
|
+ "success": False,
|
|
|
+ "error": "Agent 未就绪",
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 503
|
|
|
+
|
|
|
+ # 🔥 直接调用异步方法
|
|
|
+ conversations = await _agent_instance.get_user_recent_conversations(user_id, limit)
|
|
|
+
|
|
|
+ return jsonify({
|
|
|
+ "success": True,
|
|
|
+ "data": {
|
|
|
+ "user_id": user_id,
|
|
|
+ "conversations": conversations,
|
|
|
+ "total_count": len(conversations),
|
|
|
+ "limit": limit
|
|
|
+ },
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 200
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 异步获取用户 {user_id} 对话列表失败: {e}")
|
|
|
+ return jsonify({
|
|
|
+ "success": False,
|
|
|
+ "error": str(e),
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 500
|
|
|
+
|
|
|
+@app.route('/api/v0/react/users/<user_id>/conversations/<thread_id>', methods=['GET'])
|
|
|
+async def get_user_conversation_detail(user_id: str, thread_id: str):
|
|
|
+ """🔥 异步获取特定对话的详细历史"""
|
|
|
+ global _agent_instance
|
|
|
+
|
|
|
+ try:
|
|
|
+ if not thread_id.startswith(f"{user_id}:"):
|
|
|
+ return jsonify({
|
|
|
+ "success": False,
|
|
|
+ "error": f"Thread ID {thread_id} 不属于用户 {user_id}",
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 400
|
|
|
+
|
|
|
+ logger.info(f"📖 异步获取用户 {user_id} 的对话 {thread_id} 详情")
|
|
|
+
|
|
|
+ if not await ensure_agent_ready():
|
|
|
+ return jsonify({
|
|
|
+ "success": False,
|
|
|
+ "error": "Agent 未就绪",
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 503
|
|
|
+
|
|
|
+ # 🔥 直接调用异步方法
|
|
|
+ history = await _agent_instance.get_conversation_history(thread_id)
|
|
|
+ logger.info(f"✅ 异步成功获取对话历史,消息数量: {len(history)}")
|
|
|
+
|
|
|
+ if not history:
|
|
|
+ return jsonify({
|
|
|
+ "success": False,
|
|
|
+ "error": f"未找到对话 {thread_id}",
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 404
|
|
|
+
|
|
|
+ return jsonify({
|
|
|
+ "success": True,
|
|
|
+ "data": {
|
|
|
+ "user_id": user_id,
|
|
|
+ "thread_id": thread_id,
|
|
|
+ "message_count": len(history),
|
|
|
+ "messages": history
|
|
|
+ },
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 200
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ import traceback
|
|
|
+ logger.error(f"❌ 异步获取对话 {thread_id} 详情失败: {e}")
|
|
|
+ logger.error(f"❌ 详细错误信息: {traceback.format_exc()}")
|
|
|
+ return jsonify({
|
|
|
+ "success": False,
|
|
|
+ "error": str(e),
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ }), 500
|
|
|
+
|
|
|
+# 🔥 异步Redis API(如果还需要直接Redis访问)
|
|
|
+async def get_user_conversations_async(user_id: str, limit: int = 10):
|
|
|
+ """🔥 完全异步的Redis查询函数"""
|
|
|
+ global _redis_client
|
|
|
+
|
|
|
+ try:
|
|
|
+ if not _redis_client:
|
|
|
+ _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
|
|
|
+ await _redis_client.ping()
|
|
|
+
|
|
|
+ pattern = f"checkpoint:{user_id}:*"
|
|
|
+ logger.info(f"🔍 异步扫描模式: {pattern}")
|
|
|
+
|
|
|
+ keys = []
|
|
|
+ cursor = 0
|
|
|
+ while True:
|
|
|
+ cursor, batch = await _redis_client.scan(cursor=cursor, match=pattern, count=1000)
|
|
|
+ keys.extend(batch)
|
|
|
+ if cursor == 0:
|
|
|
+ break
|
|
|
+
|
|
|
+ logger.info(f"📋 异步找到 {len(keys)} 个keys")
|
|
|
+
|
|
|
+ # 解析和处理逻辑(与原来相同,但使用异步Redis操作)
|
|
|
+ thread_data = {}
|
|
|
+ for key in keys:
|
|
|
+ try:
|
|
|
+ parts = key.split(':')
|
|
|
+ if len(parts) >= 4:
|
|
|
+ thread_id = f"{parts[1]}:{parts[2]}"
|
|
|
+ timestamp = parts[2]
|
|
|
+
|
|
|
+ if thread_id not in thread_data:
|
|
|
+ thread_data[thread_id] = {
|
|
|
+ "thread_id": thread_id,
|
|
|
+ "timestamp": timestamp,
|
|
|
+ "keys": []
|
|
|
+ }
|
|
|
+ thread_data[thread_id]["keys"].append(key)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"解析key失败 {key}: {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ sorted_threads = sorted(
|
|
|
+ thread_data.values(),
|
|
|
+ key=lambda x: x["timestamp"],
|
|
|
+ reverse=True
|
|
|
+ )[:limit]
|
|
|
+
|
|
|
+ conversations = []
|
|
|
+ for thread_info in sorted_threads:
|
|
|
+ try:
|
|
|
+ thread_id = thread_info["thread_id"]
|
|
|
+ latest_key = max(thread_info["keys"])
|
|
|
+
|
|
|
+ # 🔥 使用异步Redis获取
|
|
|
+ key_type = await _redis_client.type(latest_key)
|
|
|
+
|
|
|
+ data = None
|
|
|
+ if key_type == 'string':
|
|
|
+ data = await _redis_client.get(latest_key)
|
|
|
+ elif key_type == 'ReJSON-RL':
|
|
|
+ try:
|
|
|
+ data = await _redis_client.execute_command('JSON.GET', latest_key)
|
|
|
+ except Exception as json_error:
|
|
|
+ logger.error(f"❌ 异步JSON.GET 失败: {json_error}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ if data:
|
|
|
+ try:
|
|
|
+ import json
|
|
|
+ checkpoint_data = json.loads(data)
|
|
|
+
|
|
|
+ messages = []
|
|
|
+ if 'checkpoint' in checkpoint_data:
|
|
|
+ checkpoint = checkpoint_data['checkpoint']
|
|
|
+ if isinstance(checkpoint, dict) and 'channel_values' in checkpoint:
|
|
|
+ channel_values = checkpoint['channel_values']
|
|
|
+ if isinstance(channel_values, dict) and 'messages' in channel_values:
|
|
|
+ messages = channel_values['messages']
|
|
|
+
|
|
|
+ preview = "空对话"
|
|
|
+ if messages:
|
|
|
+ for msg in messages:
|
|
|
+ if isinstance(msg, dict):
|
|
|
+ if (msg.get('lc') == 1 and
|
|
|
+ msg.get('type') == 'constructor' and
|
|
|
+ 'id' in msg and
|
|
|
+ isinstance(msg['id'], list) and
|
|
|
+ len(msg['id']) >= 4 and
|
|
|
+ msg['id'][3] == 'HumanMessage' and
|
|
|
+ 'kwargs' in msg):
|
|
|
+
|
|
|
+ kwargs = msg['kwargs']
|
|
|
+ if kwargs.get('type') == 'human' and 'content' in kwargs:
|
|
|
+ content = str(kwargs['content'])
|
|
|
+ preview = content[:50] + "..." if len(content) > 50 else content
|
|
|
+ break
|
|
|
+
|
|
|
+ conversations.append({
|
|
|
+ "thread_id": thread_id,
|
|
|
+ "user_id": user_id,
|
|
|
+ "timestamp": thread_info["timestamp"],
|
|
|
+ "message_count": len(messages),
|
|
|
+ "conversation_preview": preview
|
|
|
+ })
|
|
|
+
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ logger.error(f"❌ 异步JSON解析失败")
|
|
|
+ continue
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"异步处理thread {thread_info['thread_id']} 失败: {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ logger.info(f"✅ 异步返回 {len(conversations)} 个对话")
|
|
|
+ return conversations
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 异步Redis查询失败: {e}")
|
|
|
+ return []
|
|
|
+
|
|
|
+# 🔥 异步启动和清理
|
|
|
+async def startup():
|
|
|
+ """应用启动时的异步初始化"""
|
|
|
+ logger.info("🚀 启动异步Flask应用...")
|
|
|
+ try:
|
|
|
+ await initialize_agent()
|
|
|
+ logger.info("✅ Agent 预初始化完成")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 启动时Agent初始化失败: {e}")
|
|
|
+
|
|
|
+async def shutdown():
|
|
|
+ """应用关闭时的异步清理"""
|
|
|
+ logger.info("🔄 关闭异步Flask应用...")
|
|
|
+ try:
|
|
|
+ await cleanup_agent()
|
|
|
+ logger.info("✅ 资源清理完成")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 关闭时清理失败: {e}")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 🔥 简化的启动方式 - Flask 3.x 原生支持异步
|
|
|
+ logger.info("🚀 使用Flask内置异步支持启动...")
|
|
|
+
|
|
|
+ # 信号处理
|
|
|
+ import signal
|
|
|
+
|
|
|
+ def signal_handler(signum, frame):
|
|
|
+ logger.info("🛑 收到关闭信号,开始清理...")
|
|
|
+ # 在信号处理中,我们只能打印消息,实际清理在程序正常退出时进行
|
|
|
+ print("正在关闭服务...")
|
|
|
+ exit(0)
|
|
|
+
|
|
|
+ signal.signal(signal.SIGINT, signal_handler)
|
|
|
+ signal.signal(signal.SIGTERM, signal_handler)
|
|
|
+
|
|
|
+ # 启动Flask应用
|
|
|
+ app.run(host="0.0.0.0", port=8000, debug=False)
|