Let me go through your code carefully and answer the two key questions.

## 1. Is the CustomReactAgent class fully asynchronous?

### ✅ Methods that are already asynchronous:

```python
@classmethod
async def create(cls): ...                                          # ✅ async factory method
async def _async_init(self): ...                                    # ✅ async initialization
async def close(self): ...                                          # ✅ async cleanup
async def chat(self, message, user_id, thread_id): ...              # ✅ async chat
async def get_conversation_history(self, thread_id): ...            # ✅ async history lookup
async def get_user_recent_conversations(self, user_id, limit): ...  # ✅ async conversation list
```

### ⚠️ **Key methods that are still synchronous (this is where the problem lies):**

```python
def _should_continue(self, state: AgentState) -> str: ...              # ❌ sync
def _agent_node(self, state: AgentState) -> Dict[str, Any]: ...        # ❌ sync
def _prepare_tool_input_node(self, state: AgentState): ...             # ❌ sync
def _update_state_after_tool_node(self, state: AgentState): ...        # ❌ sync
def _format_final_response_node(self, state: AgentState): ...          # ❌ sync
```

**Problem analysis**: your StateGraph nodes are synchronous and call the LLM through the blocking `invoke`, even though the graph itself is driven from async code:

```python
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
    # A synchronous node making a blocking LLM call inside an async execution path
    response = self.llm_with_tools.invoke(messages_for_llm)  # 🔥 this is the problem
```

## 2. Which Redis clients are in use?

### Your code mixes two kinds of Redis client:

**Async Redis client (in agent.py):**

```python
from redis.asyncio import Redis  # ✅ async version

redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
```

**Sync Redis client (in api.py and enhanced_redis_api.py):**

```python
import redis  # ❌ sync version

redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
```

### 🎯 **Key problem: conflicts caused by mixing the two**

## Complete fix

The complete fix is provided in the artifact, which contains the revised `agent.py` and `api.py`.

## Answers to your two key questions

### 1. Is CustomReactAgent fully asynchronous?

**❌ No, your CustomReactAgent is not fully asynchronous!**

The main public methods are `async`, but **every StateGraph node method is synchronous** (`_agent_node`, `_prepare_tool_input_node`, `_update_state_after_tool_node`, `_format_final_response_node`, `_should_continue`), and `_agent_node` calls the blocking `invoke`. That is the root of the problem.

**Core issue**: LangGraph's StateGraph supports async nodes, but you are not using them!

### 2. Which Redis clients are in use?

**❌ You are mixing sync and async Redis clients!** agent.py uses `redis.asyncio`, while api.py and enhanced_redis_api.py use the synchronous `redis.Redis`, as shown in the snippets above.

## 🎯 **Is async Redis strictly required in your scenario?**

**Answer: strongly recommended, for these reasons:**

1. **Architectural consistency**: since the Flask routes are async, Redis operations should be async as well
2. **No blocking**: synchronous Redis calls block the event loop and hurt concurrency
3. **One less source of errors**: mixing the two is a major cause of "Event loop is closed"
4. **Chatbot characteristics**: users expect fast responses, and async Redis handles concurrent requests better

## 🔧 **The most important changes**

The artifact contains the complete fix. The key changes are:

### **1. Make every StateGraph node async:**

```python
# ✅ After the change: async nodes
async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
    response = await self.llm_with_tools.ainvoke(messages_for_llm)  # use ainvoke

async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]: ...
async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]: ...
async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]: ...
async def _async_should_continue(self, state: AgentState) -> str: ...
```

### **2. Use a single kind of Redis client:**

```python
# ✅ Use async Redis everywhere
import redis.asyncio as redis

self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
await self.redis_client.ping()
```

### **3. Remove the complicated event-loop management:**

```python
# ❌ Delete these functions
def run_async_safely(): ...         # delete
def ensure_agent_ready_sync(): ...  # delete

# ✅ Use async/await directly
@app.route("/api/chat", methods=["POST"])
async def chat_endpoint():
    agent_result = await _agent_instance.chat(...)  # direct call
```

## 📋 **Summary of recommended changes**

1. **Change immediately**: convert every StateGraph node to its async version
2. **Unify Redis**: use `redis.asyncio` everywhere
3. **Remove wrappers**: delete every `run_async_safely`-style function
4. **Flask routes**: make the core API routes async
5. **Test and verify**: after the change, check whether the event-loop errors are gone

**With these changes you should be able to start the service directly with `python api.py`.** Async Flask views need the `flask[async]` extra, and Flask runs each async view on its own event loop, so step 5 matters: confirm under real requests that the event-loop errors have actually disappeared.

"""
Fully asynchronous CustomReactAgent fix.
Resolves the mixing of sync and async StateGraph nodes and standardizes on the async Redis client.
"""
# ============= Revised agent.py =============
import asyncio
import logging
import json
import pandas as pd
from typing import List, Optional, Dict, Any, Tuple
from contextlib import AsyncExitStack

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
import redis.asyncio as redis  # 🔥 async Redis everywhere

try:
    from langgraph.checkpoint.redis import AsyncRedisSaver
except ImportError:
    AsyncRedisSaver = None

# Import config, state, and tools from the new modules
try:
    from . import config
    from .state import AgentState
    from .sql_tools import sql_tools
except ImportError:
    import config
    from state import AgentState
    from sql_tools import sql_tools

logger = logging.getLogger(__name__)


class CustomReactAgent:
    """
    Fully asynchronous CustomReactAgent.
    Every node method is async and a single async Redis client is used throughout.
    """

    def __init__(self):
        """Private constructor; use the create() class method to build an instance."""
        self.llm = None
        self.llm_with_tools = None
        self.tools = None
        self.agent_executor = None
        self.checkpointer = None
        self._exit_stack = None
        self.redis_client = None  # 🔥 reference to the Redis client

    @classmethod
    async def create(cls):
        """Async factory method: create and initialize a CustomReactAgent instance."""
        instance = cls()
        await instance._async_init()
        return instance

    async def _async_init(self):
        """Initialize all components asynchronously."""
        logger.info("🚀 Initializing CustomReactAgent...")

        # 1. Initialize the async Redis client
        self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
        try:
            await self.redis_client.ping()
            logger.info(f"   ✅ Redis connected: {config.REDIS_URL}")
        except Exception as e:
            logger.error(f"   ❌ Redis connection failed: {e}")
            raise

        # 2. Initialize the LLM
        self.llm = ChatOpenAI(
            api_key=config.QWEN_API_KEY,
            base_url=config.QWEN_BASE_URL,
            model=config.QWEN_MODEL,
            temperature=0.1,
            timeout=config.NETWORK_TIMEOUT,
            max_retries=config.MAX_RETRIES,
            extra_body={
                "enable_thinking": False,
                "misc": {
                    "ensure_ascii": False
                }
            }
        )
        logger.info(f"   LLM initialized, model: {config.QWEN_MODEL}")

        # 3. Bind tools
        self.tools = sql_tools
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        logger.info(f"   Bound {len(self.tools)} tools.")

        # 4. Initialize the Redis checkpointer
        if config.REDIS_ENABLED and AsyncRedisSaver is not None:
            try:
                self._exit_stack = AsyncExitStack()
                checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
                self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
                await self.checkpointer.asetup()
                logger.info(f"   AsyncRedisSaver persistence enabled: {config.REDIS_URL}")
            except Exception as e:
                logger.error(f"   ❌ RedisSaver initialization failed: {e}", exc_info=True)
                if self._exit_stack:
                    await self._exit_stack.aclose()
                self.checkpointer = None
        else:
            logger.warning("   Redis persistence is disabled.")

        # 5. Build the StateGraph
        self.agent_executor = self._create_graph()
        logger.info("   StateGraph built and compiled.")
        logger.info("✅ CustomReactAgent initialization complete.")

    async def close(self):
        """Release resources and close all connections."""
        if self._exit_stack:
            await self._exit_stack.aclose()
            self._exit_stack = None
            self.checkpointer = None
            logger.info("✅ RedisSaver resources released via AsyncExitStack.")
        if self.redis_client:
            await self.redis_client.aclose()
            logger.info("✅ Redis client closed.")

    def _create_graph(self):
        """Define and compile the final StateGraph structure."""
        builder = StateGraph(AgentState)

        # 🔥 Key change: every node is async
        builder.add_node("agent", self._async_agent_node)
        builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_node("update_state_after_tool", self._async_update_state_after_tool_node)
        builder.add_node("format_final_response", self._async_format_final_response_node)

        # Wire up the edges
        builder.set_entry_point("agent")
        builder.add_conditional_edges(
            "agent",
            self._async_should_continue,  # 🔥 async condition
            {
                "continue": "prepare_tool_input",
                "end": "format_final_response"
            }
        )
        builder.add_edge("prepare_tool_input", "tools")
        builder.add_edge("tools", "update_state_after_tool")
        builder.add_edge("update_state_after_tool", "agent")
        builder.add_edge("format_final_response", END)

        return builder.compile(checkpointer=self.checkpointer)

    async def _async_should_continue(self, state: AgentState) -> str:
        """🔥 Async version: decide whether to keep calling tools or finish."""
        last_message = state["messages"][-1]
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "continue"
        return "end"

    async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
        """🔥 Async version: agent node that calls the LLM asynchronously."""
        logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']}")

        messages_for_llm = list(state["messages"])
        if state.get("suggested_next_step"):
            instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
            messages_for_llm.append(SystemMessage(content=instruction))

        # 🔥 Key change: async LLM call with retries
        max_retries = config.MAX_RETRIES
        for attempt in range(max_retries):
            try:
                response = await self.llm_with_tools.ainvoke(messages_for_llm)
                logger.info("   ✅ Async LLM call succeeded")
                return {"messages": [response]}
            except Exception as e:
                error_msg = str(e)
                logger.warning(f"   ⚠️ Async LLM call failed (attempt {attempt + 1}/{max_retries}): {error_msg}")

                # These substrings are matched against the error text, so they stay as-is
                if any(keyword in error_msg for keyword in [
                    "Connection error", "APIConnectionError", "ConnectError",
                    "timeout", "远程主机强迫关闭", "网络连接"
                ]):
                    if attempt < max_retries - 1:
                        # Exponential backoff: RETRY_BASE_DELAY ** attempt seconds
                        wait_time = config.RETRY_BASE_DELAY ** attempt
                        logger.info(f"   🔄 Network error, retrying in {wait_time}s...")
                        await asyncio.sleep(wait_time)  # 🔥 non-blocking sleep
                        continue
                    else:
                        logger.error("   ❌ Network still failing, returning a degraded answer")
                        sql_data = await self._async_extract_latest_sql_data(state["messages"])
                        if sql_data:
                            fallback_content = "抱歉,由于网络连接问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
                        else:
                            fallback_content = "抱歉,由于网络连接问题,无法完成此次请求。请稍后重试或检查网络连接。"
                        fallback_response = AIMessage(content=fallback_content)
                        return {"messages": [fallback_response]}
                else:
                    logger.error(f"   ❌ LLM call failed with a non-network error: {error_msg}")
                    raise

    async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
        """🔥 Async version: assemble tool inputs."""
        logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")

        last_message = state["messages"][-1]
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return {"messages": [last_message]}

        new_tool_calls = []
        for tool_call in last_message.tool_calls:
            if tool_call["name"] == "generate_sql":
                logger.info("   Detected a generate_sql call, injecting history messages.")
                modified_args = tool_call["args"].copy()

                clean_history = []
                messages_except_current = state["messages"][:-1]
                for msg in messages_except_current:
                    if isinstance(msg, HumanMessage):
                        clean_history.append({
                            "type": "human",
                            "content": msg.content
                        })
                    elif isinstance(msg, AIMessage):
                        # Only formatted final answers are passed on as AI history
                        if msg.content and "[Formatted Output]" in msg.content:
                            clean_content = msg.content.replace("[Formatted Output]\n", "")
                            clean_history.append({
                                "type": "ai",
                                "content": clean_content
                            })
                modified_args["history_messages"] = clean_history
                logger.info(f"   Injected {len(clean_history)} filtered history messages")
                new_tool_calls.append({
                    "name": tool_call["name"],
                    "args": modified_args,
                    "id": tool_call["id"],
                })
            else:
                new_tool_calls.append(tool_call)

        last_message.tool_calls = new_tool_calls
        return {"messages": [last_message]}

    async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
        """🔥 Async version: update suggested_next_step after a tool has run."""
        logger.info(f"📝 [Async Node] update_state_after_tool - Thread: {state['thread_id']}")

        last_tool_message = state['messages'][-1]
        tool_name = last_tool_message.name
        tool_output = last_tool_message.content
        next_step = None

        # The Chinese substrings below are matched against tool output, so they stay as-is
        if tool_name == 'generate_sql':
            if "失败" in tool_output or "无法生成" in tool_output:
                next_step = 'answer_with_common_sense'
            else:
                next_step = 'valid_sql'
        elif tool_name == 'valid_sql':
            if "失败" in tool_output:
                next_step = 'analyze_validation_error'
            else:
                next_step = 'run_sql'
        elif tool_name == 'run_sql':
            next_step = 'summarize_final_answer'

        logger.info(f"   Tool '{tool_name}' executed. Suggested next step: {next_step}")
        return {"suggested_next_step": next_step}

    async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
        """🔥 Async version: format the final output."""
        logger.info(f"🎨 [Async Node] format_final_response - Thread: {state['thread_id']}")

        last_message = state['messages'][-1]
        last_message.content = f"[Formatted Output]\n{last_message.content}"

        # Build the API-shaped payload
        api_data = await self._async_generate_api_data(state)

        return {
            "messages": [last_message],
            "api_data": api_data
        }

    async def _async_generate_api_data(self, state: AgentState) -> Dict[str, Any]:
        """🔥 Async version: build the API-shaped data structure."""
        logger.info("📊 Building API-format data asynchronously...")

        last_message = state['messages'][-1]
        response_content = last_message.content
        if response_content.startswith("[Formatted Output]\n"):
            response_content = response_content.replace("[Formatted Output]\n", "")

        api_data = {
            "response": response_content
        }

        sql_info = await self._async_extract_sql_and_data(state['messages'])
        if sql_info['sql']:
            api_data["sql"] = sql_info['sql']
        if sql_info['records']:
            api_data["records"] = sql_info['records']

        api_data["react_agent_meta"] = await self._async_collect_agent_metadata(state)

        logger.info(f"   API data ready, fields: {list(api_data.keys())}")
        return api_data

    async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        """🔥 Async version: extract the SQL and result records from the message history."""
        result = {"sql": None, "records": None}

        # Find the most recent human message; everything after it is the current turn
        last_human_index = -1
        for i in range(len(messages) - 1, -1, -1):
            if isinstance(messages[i], HumanMessage):
                last_human_index = i
                break
        if last_human_index == -1:
            return result

        current_conversation = messages[last_human_index:]
        sql_query = None
        sql_data = None

        for msg in current_conversation:
            if isinstance(msg, ToolMessage):
                if msg.name == 'generate_sql':
                    content = msg.content
                    if content and not any(keyword in content for keyword in ["失败", "无法生成", "Database query failed"]):
                        sql_query = content.strip()
                elif msg.name == 'run_sql':
                    try:
                        parsed_data = json.loads(msg.content)
                        if isinstance(parsed_data, list) and len(parsed_data) > 0:
                            columns = list(parsed_data[0].keys())
                            sql_data = {
                                "columns": columns,
                                "rows": parsed_data,
                                "total_row_count": len(parsed_data),
                                "is_limited": False
                            }
                    except Exception as e:
                        logger.warning(f"   Failed to parse SQL result: {e}")

        if sql_query:
            result["sql"] = sql_query
        if sql_data:
            result["records"] = sql_data
        return result

    async def _async_collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
        """🔥 Async version: collect agent metadata."""
        messages = state['messages']
        tools_used = []
        sql_execution_count = 0
        context_injected = False
        conversation_rounds = sum(1 for msg in messages if isinstance(msg, HumanMessage))

        for msg in messages:
            if isinstance(msg, ToolMessage):
                if msg.name not in tools_used:
                    tools_used.append(msg.name)
                if msg.name == 'run_sql':
                    sql_execution_count += 1
            elif isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
                for tool_call in msg.tool_calls:
                    tool_name = tool_call.get('name')
                    if tool_name and tool_name not in tools_used:
                        tools_used.append(tool_name)
                    if (tool_name == 'generate_sql' and
                            tool_call.get('args', {}).get('history_messages')):
                        context_injected = True

        execution_path = ["agent"]
        if tools_used:
            execution_path.extend(["prepare_tool_input", "tools"])
        execution_path.append("format_final_response")

        return {
            "thread_id": state['thread_id'],
            "conversation_rounds": conversation_rounds,
            "tools_used": tools_used,
            "execution_path": execution_path,
            "total_messages": len(messages),
            "sql_execution_count": sql_execution_count,
            "context_injected": context_injected,
            "agent_version": "custom_react_v1_async"
        }

    async def _async_extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
        """🔥 Async version: extract the most recent SQL execution result."""
        logger.info("🔍 Extracting the latest SQL execution result...")

        last_human_index = -1
        for i in range(len(messages) - 1, -1, -1):
            if isinstance(messages[i], HumanMessage):
                last_human_index = i
                break
        if last_human_index == -1:
            logger.info("   No user message found, skipping SQL data extraction")
            return None

        current_conversation = messages[last_human_index:]
        logger.info(f"   Current turn contains {len(current_conversation)} messages")

        for msg in reversed(current_conversation):
            if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
                logger.info(f"   Found run_sql result for the current turn: {msg.content[:100]}...")
                try:
                    # Re-serialize so Unicode escapes become readable characters
                    parsed_data = json.loads(msg.content)
                    formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
                    logger.info("   Converted Unicode escape sequences to readable characters")
                    return formatted_content
                except json.JSONDecodeError:
                    logger.warning("   SQL result is not valid JSON, returning the raw content")
                    return msg.content

        logger.info("   No run_sql result found in the current turn")
        return None

    async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
        """🔥 Fully asynchronous chat handler."""
        if not thread_id:
            now = pd.Timestamp.now()
            milliseconds = int(now.microsecond / 1000)
            thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
            logger.info(f"🆕 New conversation, Thread ID: {thread_id}")

        # Note: this local name shadows the imported config module inside this method
        config = {
            "configurable": {
                "thread_id": thread_id,
            }
        }
        inputs = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
            "thread_id": thread_id,
            "suggested_next_step": None,
        }

        try:
            # 🔥 Async graph invocation
            final_state = await self.agent_executor.ainvoke(inputs, config)
            answer = final_state["messages"][-1].content
            sql_data = await self._async_extract_latest_sql_data(final_state["messages"])
            logger.info(f"✅ Async processing finished - Final Answer: '{answer}'")

            result = {
                "success": True,
                "answer": answer,
                "thread_id": thread_id
            }
            if sql_data:
                result["sql_data"] = sql_data
                logger.info("   📊 Raw SQL data included")
            if "api_data" in final_state:
                result["api_data"] = final_state["api_data"]
                logger.info("   🔌 API-format data included")
            return result
        except Exception as e:
            logger.error(f"❌ Fatal error during async processing - Thread: {thread_id}: {e}", exc_info=True)
            return {"success": False, "error": str(e), "thread_id": thread_id}

    async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
        """🔥 Fully asynchronous conversation history lookup."""
        if not self.checkpointer:
            return []

        config = {"configurable": {"thread_id": thread_id}}
        try:
            conversation_state = await self.checkpointer.aget(config)
        except RuntimeError as e:
            if "Event loop is closed" in str(e):
                logger.warning(f"⚠️ Event loop is closed, returning an empty result: {thread_id}")
                return []
            else:
                raise

        if not conversation_state:
            return []

        history = []
        messages = conversation_state.get('channel_values', {}).get('messages', [])
        for msg in messages:
            if isinstance(msg, HumanMessage):
                role = "human"
            elif isinstance(msg, ToolMessage):
                role = "tool"
            else:
                role = "ai"
            history.append({
                "type": role,
                "content": msg.content,
                "tool_calls": getattr(msg, 'tool_calls', None)
            })
        return history

    async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """🔥 Fully asynchronous listing of a user's recent conversations."""
        if not self.checkpointer:
            return []

        try:
            # 🔥 Use the shared async Redis client
            pattern = f"checkpoint:{user_id}:*"
            logger.info(f"🔍 Async scan pattern: {pattern}")

            user_threads = {}
            cursor = 0
            while True:
                cursor, keys = await self.redis_client.scan(
                    cursor=cursor,
                    match=pattern,
                    count=1000
                )
                for key in keys:
                    try:
                        key_str = key.decode() if isinstance(key, bytes) else key
                        parts = key_str.split(':')
                        if len(parts) >= 4:
                            # Key layout: checkpoint:{user_id}:{timestamp}:...
                            thread_id = f"{parts[1]}:{parts[2]}"
                            timestamp = parts[2]
                            if thread_id not in user_threads:
                                user_threads[thread_id] = {
                                    "thread_id": thread_id,
                                    "timestamp": timestamp,
                                    "latest_key": key_str
                                }
                            else:
                                if len(parts) > 4 and parts[4] > user_threads[thread_id]["latest_key"].split(':')[4]:
                                    user_threads[thread_id]["latest_key"] = key_str
                    except Exception as e:
                        logger.warning(f"Failed to parse key {key}: {e}")
                        continue
                if cursor == 0:
                    break

            # Sort by timestamp, newest first
            sorted_threads = sorted(
                user_threads.values(),
                key=lambda x: x["timestamp"],
                reverse=True
            )[:limit]

            # Load the details of each thread
            conversations = []
            for thread_info in sorted_threads:
                try:
                    thread_id = thread_info["thread_id"]
                    thread_config = {"configurable": {"thread_id": thread_id}}
                    try:
                        state = await self.checkpointer.aget(thread_config)
                    except RuntimeError as e:
                        if "Event loop is closed" in str(e):
                            logger.warning(f"⚠️ Event loop is closed, skipping thread: {thread_id}")
                            continue
                        else:
                            raise

                    if state and state.get('channel_values', {}).get('messages'):
                        messages = state['channel_values']['messages']
                        preview = self._generate_conversation_preview(messages)
                        conversations.append({
                            "thread_id": thread_id,
                            "user_id": user_id,
                            "timestamp": thread_info["timestamp"],
                            "message_count": len(messages),
                            "last_message": messages[-1].content if messages else None,
                            "last_updated": state.get('created_at'),
                            "conversation_preview": preview,
                            "formatted_time": self._format_timestamp(thread_info["timestamp"])
                        })
                except Exception as e:
                    logger.error(f"Failed to load details for thread {thread_info['thread_id']}: {e}")
                    continue

            logger.info(f"✅ Found {len(conversations)} conversations for user {user_id}")
            return conversations
        except Exception as e:
            logger.error(f"❌ Failed to list conversations for user {user_id}: {e}")
            return []

    def _generate_conversation_preview(self, messages: List[BaseMessage]) -> str:
        """Build a conversation preview (kept synchronous: pure computation)."""
        if not messages:
            return "空对话"
        for msg in messages:
            if isinstance(msg, HumanMessage):
                content = str(msg.content)
                return content[:50] + "..." if len(content) > 50 else content
        return "系统消息"

    def _format_timestamp(self, timestamp: str) -> str:
        """Format a timestamp into a readable form (kept synchronous: pure computation)."""
        try:
            if len(timestamp) >= 14:
                year = timestamp[:4]
                month = timestamp[4:6]
                day = timestamp[6:8]
                hour = timestamp[8:10]
                minute = timestamp[10:12]
                second = timestamp[12:14]
                return f"{year}-{month}-{day} {hour}:{minute}:{second}"
        except Exception:
            pass
        return timestamp


# ============= Key parts of the revised api.py =============
"""
Revised api.py: uses the async Redis client throughout and removes the complicated event-loop management.
"""
import asyncio
import logging
import os
from datetime import datetime
from typing import Optional, Dict, Any

from flask import Flask, request, jsonify
import redis.asyncio as redis  # 🔥 async Redis everywhere

try:
    from .agent import CustomReactAgent
except ImportError:
    from agent import CustomReactAgent

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global agent instance and shared Redis client
_agent_instance: Optional[CustomReactAgent] = None
_redis_client: Optional[redis.Redis] = None


def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate the request payload (unchanged)."""
    errors = []

    question = data.get('question', '')
    if not question or not question.strip():
        errors.append('问题不能为空')
    elif len(question) > 2000:
        errors.append('问题长度不能超过2000字符')

    user_id = data.get('user_id', 'guest')
    if user_id and len(user_id) > 50:
        errors.append('用户ID长度不能超过50字符')

    if errors:
        raise ValueError('; '.join(errors))

    return {
        'question': question.strip(),
        'user_id': user_id or 'guest',
        'thread_id': data.get('thread_id')
    }


async def initialize_agent():
    """🔥 Initialize the agent asynchronously."""
    global _agent_instance, _redis_client
    if _agent_instance is None:
        logger.info("🚀 Initializing Custom React Agent asynchronously...")
        try:
            os.environ['REDIS_URL'] = 'redis://localhost:6379'

            # Initialize the shared Redis client
            _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
            await _redis_client.ping()

            _agent_instance = await CustomReactAgent.create()
            logger.info("✅ Agent initialized asynchronously")
        except Exception as e:
            logger.error(f"❌ Async agent initialization failed: {e}")
            raise


async def ensure_agent_ready():
    """🔥 Make sure the agent instance is usable, re-creating it if needed."""
    global _agent_instance
    if _agent_instance is None:
        await initialize_agent()

    try:
        # Cheap probe call; the result itself is not used
        test_result = await _agent_instance.get_user_recent_conversations("__test__", 1)
        return True
    except Exception as e:
        logger.warning(f"⚠️ Agent instance unavailable: {e}")
        _agent_instance = None
        await initialize_agent()
        return True


async def cleanup_agent():
    """🔥 Release agent resources asynchronously."""
    global _agent_instance, _redis_client
    if _agent_instance:
        await _agent_instance.close()
        logger.info("✅ Agent resources released")
        _agent_instance = None
    if _redis_client:
        await _redis_client.aclose()
        logger.info("✅ Redis client closed")
        _redis_client = None


# Create the Flask application
app = Flask(__name__)

# 🔥 All sync wrapper functions are removed: run_async_safely, ensure_agent_ready_sync


@app.route("/")
def root():
    """Health-check endpoint (kept synchronous)."""
    return jsonify({"message": "Custom React Agent API 服务正在运行"})


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint (kept synchronous)."""
    try:
        health_status = {
            "status": "healthy",
            "agent_initialized": _agent_instance is not None,
            "timestamp": datetime.now().isoformat()
        }
        return jsonify(health_status), 200
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return jsonify({"status": "unhealthy", "error": str(e)}), 500


@app.route("/api/chat", methods=["POST"])
async def chat_endpoint():
    """🔥 Async question-answering endpoint."""
    global _agent_instance

    # Make sure the agent is initialized
    if not await ensure_agent_ready():
        return jsonify({
            "code": 503,
            "message": "服务未就绪",
            "success": False,
            "error": "Agent 初始化失败"
        }), 503

    try:
        data = request.get_json()
        if not data:
            return jsonify({
                "code": 400,
                "message": "请求参数错误",
                "success": False,
                "error": "请求体不能为空"
            }), 400

        validated_data = validate_request_data(data)
        logger.info(f"📨 Request received - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")

        # 🔥 Call the async method directly; no event-loop wrapper needed
        agent_result = await _agent_instance.chat(
            message=validated_data['question'],
            user_id=validated_data['user_id'],
            thread_id=validated_data['thread_id']
        )

        if not agent_result.get("success", False):
            error_msg = agent_result.get("error",
                                         "Agent处理失败")
            logger.error(f"❌ Agent processing failed: {error_msg}")
            return jsonify({
                "code": 500,
                "message": "处理失败",
                "success": False,
                "error": error_msg,
                "data": {
                    "react_agent_meta": {
                        "thread_id": agent_result.get("thread_id"),
                        "agent_version": "custom_react_v1_async",
                        "execution_path": ["error"]
                    },
                    "timestamp": datetime.now().isoformat()
                }
            }), 500

        api_data = agent_result.get("api_data", {})
        response_data = {
            **api_data,
            "timestamp": datetime.now().isoformat()
        }

        logger.info(f"✅ Async request handled - Thread: {api_data.get('react_agent_meta', {}).get('thread_id')}")
        return jsonify({
            "code": 200,
            "message": "操作成功",
            "success": True,
            "data": response_data
        })

    except ValueError as e:
        logger.warning(f"⚠️ Parameter validation failed: {e}")
        return jsonify({
            "code": 400,
            "message": "请求参数错误",
            "success": False,
            "error": str(e)
        }), 400
    except Exception as e:
        logger.error(f"❌ Unexpected error: {e}", exc_info=True)
        return jsonify({
            "code": 500,
            "message": "服务器内部错误",
            "success": False,
            "error": "系统异常,请稍后重试"
        }), 500


@app.route('/api/v0/react/users/<user_id>/conversations', methods=['GET'])
async def get_user_conversations(user_id: str):
    """🔥 Asynchronously list a user's conversations."""
    global _agent_instance
    try:
        limit = request.args.get('limit', 10, type=int)
        limit = max(1, min(limit, 50))  # clamp to the range 1..50
        logger.info(f"📋 Listing conversations for user {user_id}, limit {limit}")

        if not await ensure_agent_ready():
            return jsonify({
                "success": False,
                "error": "Agent 未就绪",
                "timestamp": datetime.now().isoformat()
            }), 503

        # 🔥 Call the async method directly
        conversations = await _agent_instance.get_user_recent_conversations(user_id, limit)

        return jsonify({
            "success": True,
            "data": {
                "user_id": user_id,
                "conversations": conversations,
                "total_count": len(conversations),
                "limit": limit
            },
            "timestamp": datetime.now().isoformat()
        }), 200
    except Exception as e:
        logger.error(f"❌ Failed to list conversations for user {user_id}: {e}")
        return jsonify({
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }), 500


@app.route('/api/v0/react/users/<user_id>/conversations/<thread_id>', methods=['GET'])
async def get_user_conversation_detail(user_id: str, thread_id: str):
    """🔥 Asynchronously fetch the full history of one conversation."""
    global _agent_instance
    try:
        # Thread IDs are prefixed with their owner's user ID
        if not thread_id.startswith(f"{user_id}:"):
            return jsonify({
                "success": False,
                "error": f"Thread ID {thread_id} 不属于用户 {user_id}",
                "timestamp": datetime.now().isoformat()
            }), 400

        logger.info(f"📖 Fetching conversation {thread_id} for user {user_id}")

        if not await ensure_agent_ready():
            return jsonify({
                "success": False,
                "error": "Agent 未就绪",
                "timestamp": datetime.now().isoformat()
            }), 503

        # 🔥 Call the async method directly
        history = await _agent_instance.get_conversation_history(thread_id)
        logger.info(f"✅ Conversation history fetched, message count: {len(history)}")

        if not history:
            return jsonify({
                "success": False,
                "error": f"未找到对话 {thread_id}",
                "timestamp": datetime.now().isoformat()
            }), 404

        return jsonify({
            "success": True,
            "data": {
                "user_id": user_id,
                "thread_id": thread_id,
                "message_count": len(history),
                "messages": history
            },
            "timestamp": datetime.now().isoformat()
        }), 200
    except Exception as e:
        import traceback
        logger.error(f"❌ Failed to fetch conversation {thread_id}: {e}")
        logger.error(f"❌ Traceback: {traceback.format_exc()}")
        return jsonify({
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }), 500


# 🔥 Async Redis API (only if direct Redis access is still needed)
async def get_user_conversations_async(user_id: str, limit: int = 10):
    """🔥 Fully asynchronous Redis query."""
    global _redis_client
    try:
        if not _redis_client:
            _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
            await _redis_client.ping()

        pattern = f"checkpoint:{user_id}:*"
        logger.info(f"🔍 Async scan pattern: {pattern}")

        # Collect all matching keys with cursor-based SCAN
        keys = []
        cursor = 0
        while True:
            cursor, batch = await _redis_client.scan(cursor=cursor, match=pattern, count=1000)
            keys.extend(batch)
            if cursor == 0:
                break
        logger.info(f"📋 Found {len(keys)} keys")

        # Parsing logic is the same as before, but on top of async Redis calls
        thread_data = {}
        for key in keys:
            try:
                parts = key.split(':')
                if len(parts) >= 4:
                    thread_id = f"{parts[1]}:{parts[2]}"
                    timestamp = parts[2]
                    if thread_id not in thread_data:
                        thread_data[thread_id] = {
                            "thread_id": thread_id,
                            "timestamp": timestamp,
                            "keys": []
                        }
                    thread_data[thread_id]["keys"].append(key)
            except Exception as e:
                logger.warning(f"Failed to parse key {key}: {e}")
                continue

        sorted_threads = sorted(
            thread_data.values(),
            key=lambda x: x["timestamp"],
            reverse=True
        )[:limit]

        conversations = []
        for thread_info in sorted_threads:
            try:
                thread_id = thread_info["thread_id"]
                latest_key = max(thread_info["keys"])

                # 🔥 Fetch through the async Redis client
                key_type = await _redis_client.type(latest_key)
                data = None
                if key_type == 'string':
                    data = await _redis_client.get(latest_key)
                elif key_type == 'ReJSON-RL':
                    try:
                        data = await _redis_client.execute_command('JSON.GET', latest_key)
                    except Exception as json_error:
                        logger.error(f"❌ Async JSON.GET failed: {json_error}")
                        continue

                if data:
                    try:
                        import json
                        checkpoint_data = json.loads(data)

                        # Walk down to checkpoint -> channel_values -> messages
                        messages = []
                        if 'checkpoint' in checkpoint_data:
                            checkpoint = checkpoint_data['checkpoint']
                            if isinstance(checkpoint, dict) and 'channel_values' in checkpoint:
                                channel_values = checkpoint['channel_values']
                                if isinstance(channel_values, dict) and 'messages' in channel_values:
                                    messages = channel_values['messages']

                        # Preview = first human message, truncated to 50 characters
                        preview = "空对话"
                        if messages:
                            for msg in messages:
                                if isinstance(msg, dict):
                                    if (msg.get('lc') == 1 and
                                            msg.get('type') == 'constructor' and
                                            'id' in msg and
                                            isinstance(msg['id'], list) and
                                            len(msg['id']) >= 4 and
                                            msg['id'][3] == 'HumanMessage' and
                                            'kwargs' in msg):
                                        kwargs = msg['kwargs']
                                        if kwargs.get('type') == 'human' and 'content' in kwargs:
                                            content = str(kwargs['content'])
                                            preview = content[:50] + "..." if len(content) > 50 else content
                                            break

                        conversations.append({
                            "thread_id": thread_id,
                            "user_id": user_id,
                            "timestamp": thread_info["timestamp"],
                            "message_count": len(messages),
                            "conversation_preview": preview
                        })
                    except json.JSONDecodeError:
                        logger.error("❌ Async JSON parsing failed")
                        continue
            except Exception as e:
                logger.error(f"Failed to process thread {thread_info['thread_id']}: {e}")
                continue

        logger.info(f"✅ Returning {len(conversations)} conversations")
        return conversations
    except Exception as e:
        logger.error(f"❌ Async Redis query failed: {e}")
        return []


# 🔥 Async startup and cleanup
async def startup():
    """Async initialization at application startup."""
    logger.info("🚀 Starting the async Flask application...")
    try:
        await initialize_agent()
        logger.info("✅ Agent pre-initialized")
    except Exception as e:
        logger.error(f"❌ Agent initialization failed at startup: {e}")


async def shutdown():
    """Async cleanup at application shutdown."""
    logger.info("🔄 Shutting down the async Flask application...")
    try:
        await cleanup_agent()
        logger.info("✅ Resources cleaned up")
    except Exception as e:
        logger.error(f"❌ Cleanup failed at shutdown: {e}")


if __name__ == "__main__":
    # 🔥 Simplified startup: Flask 2.0+ supports async views when installed with the flask[async] extra
    logger.info("🚀 Starting with Flask's built-in async support...")

    # Signal handling
    import signal
    import sys

    def signal_handler(signum, frame):
        logger.info("🛑 Shutdown signal received, cleaning up...")
        # Only a message is printed here; startup() and shutdown() above are not
        # invoked by this entry point, and the agent is created lazily on the first request
        print("正在关闭服务...")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Start the Flask application
    app.run(host="0.0.0.0", port=8000, debug=False)
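
The retry logic in `_async_agent_node` waits `RETRY_BASE_DELAY ** attempt` seconds between attempts using `asyncio.sleep`. A minimal standalone sketch of that pattern follows, so it can be tried without an LLM. The names `retry_async`, `FlakyCall`, and `demo`, and the injectable `sleep` parameter, are illustrative assumptions and not part of the artifact.

```python
import asyncio


async def retry_async(make_call, max_retries=3, base_delay=2.0,
                      is_retryable=lambda exc: True, sleep=asyncio.sleep):
    """Await make_call() and retry retryable errors with exponential backoff."""
    for attempt in range(max_retries):
        try:
            return await make_call()
        except Exception as exc:
            last_attempt = attempt == max_retries - 1
            if last_attempt or not is_retryable(exc):
                raise
            # Same schedule as the agent node: base_delay ** attempt seconds
            await sleep(base_delay ** attempt)


class FlakyCall:
    """Hypothetical stand-in for an LLM call that fails a fixed number of times."""

    def __init__(self, failures):
        self.failures = failures
        self.calls = 0

    async def __call__(self):
        self.calls += 1
        if self.calls <= self.failures:
            raise ConnectionError("Connection error")
        return "ok"


async def demo():
    delays = []

    async def fake_sleep(seconds):
        # Record the requested delay instead of actually waiting
        delays.append(seconds)

    flaky = FlakyCall(failures=2)
    result = await retry_async(flaky, max_retries=3, base_delay=2.0, sleep=fake_sleep)
    return result, flaky.calls, delays
```

Passing the sleep function in as a parameter is only a convenience for testing the backoff schedule; in the agent itself `asyncio.sleep` is called directly.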
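
The "no blocking" argument for async Redis can be seen in a small experiment: a blocking call inside a coroutine stalls every other task on the loop, while an awaited call lets them interleave. This is a rough sketch using `time.sleep` as a stand-in for any synchronous client call. The names `worker` and `run_pair` are hypothetical.

```python
import asyncio
import time


async def worker(name, delay, events, blocking):
    events.append(f"{name} start")
    if blocking:
        time.sleep(delay)           # blocks the whole event loop, like a sync client call
    else:
        await asyncio.sleep(delay)  # yields control so other tasks can run
    events.append(f"{name} end")


async def run_pair(blocking):
    """Run two workers concurrently and return the order in which events occurred."""
    events = []
    await asyncio.gather(
        worker("a", 0.01, events, blocking),
        worker("b", 0.03, events, blocking),
    )
    return events
```

With `blocking=False` worker `b` starts before `a` finishes. With `blocking=True` the two run strictly one after the other, which is what a synchronous Redis call does to concurrent requests sharing a loop.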
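
Both conversation-listing functions rely on the same key convention: `checkpoint:{user_id}:{timestamp}:...`, with the thread ID being `{user_id}:{timestamp}`. The parsing can be pulled into small pure functions, sketched below. The function names, the explicit `checkpoint` prefix check, and the sample key are assumptions for illustration, not code from the artifact.

```python
from typing import Optional, Tuple


def parse_checkpoint_key(key: str) -> Optional[Tuple[str, str]]:
    """Return (thread_id, timestamp) for a checkpoint key, or None if it does not match."""
    parts = key.split(":")
    # The artifact only checks the part count; the prefix check is an extra guard
    if len(parts) < 4 or parts[0] != "checkpoint":
        return None
    return f"{parts[1]}:{parts[2]}", parts[2]


def format_timestamp(timestamp: str) -> str:
    """Turn 'YYYYmmddHHMMSS...' into 'YYYY-mm-dd HH:MM:SS'; return the input unchanged otherwise."""
    if len(timestamp) >= 14 and timestamp[:14].isdigit():
        return (f"{timestamp[:4]}-{timestamp[4:6]}-{timestamp[6:8]} "
                f"{timestamp[8:10]}:{timestamp[10:12]}:{timestamp[12:14]}")
    return timestamp


# Hypothetical key in the layout the scan pattern assumes
sample = parse_checkpoint_key("checkpoint:alice:20250102030405123::abc")
```

Keeping this logic free of Redis calls makes it easy to unit test, and lets both the agent method and the standalone API helper share one implementation.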