让我仔细分析你的代码,回答这两个关键问题:
@classmethod
async def create(cls): # ✅ 异步工厂方法
async def _async_init(self): # ✅ 异步初始化
async def close(self): # ✅ 异步清理
async def chat(self, message, user_id, thread_id): # ✅ 异步聊天
async def get_conversation_history(self, thread_id): # ✅ 异步获取历史
async def get_user_recent_conversations(self, user_id, limit): # ✅ 异步获取对话列表
def _should_continue(self, state: AgentState) -> str: # ❌ 同步
def _agent_node(self, state: AgentState) -> Dict[str, Any]: # ❌ 同步
def _prepare_tool_input_node(self, state: AgentState): # ❌ 同步
def _update_state_after_tool_node(self, state: AgentState): # ❌ 同步
def _format_final_response_node(self, state: AgentState): # ❌ 同步
问题分析:你的StateGraph节点是同步函数,内部通过同步(阻塞)的 invoke() 调用LLM,而整个图却是在异步环境中通过 ainvoke() 运行的:
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
# 这是同步函数,但调用异步LLM
response = self.llm_with_tools.invoke(messages_for_llm) # 🔥 这里有问题
异步Redis客户端(在agent.py中):
from redis.asyncio import Redis
redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
同步Redis客户端(在api.py和enhanced_redis_api.py中):
import redis # 同步版本
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
❌ 你的CustomReactAgent并没有完全异步化!
虽然主要方法是async的,但StateGraph的节点方法都是同步的,这正是问题的根源:
# ❌ 这些是同步方法,但内部调用异步LLM
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
response = self.llm_with_tools.invoke(messages_for_llm) # 🔥问题在这里
def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
def _should_continue(self, state: AgentState) -> str:
核心问题:LangGraph的StateGraph可以支持异步节点,但你没有使用!
❌ 你混用了同步和异步Redis客户端!
异步Redis客户端(在agent.py中):
from redis.asyncio import Redis # ✅ 异步版本
redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
同步Redis客户端(在api.py和enhanced_redis_api.py中):
import redis # ❌ 同步版本
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
答案:强烈建议使用,原因如下:
我在上面的artifact中提供了完整的修复方案,关键修改包括:
# ✅ 修改后:异步节点
async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
response = await self.llm_with_tools.ainvoke(messages_for_llm) # 使用ainvoke
async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
async def _async_should_continue(self, state: AgentState) -> str:
# ✅ 统一使用异步Redis
import redis.asyncio as redis
self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
await self.redis_client.ping()
# ❌ 删除这些函数
def run_async_safely(): # 删除
def ensure_agent_ready_sync(): # 删除
# ✅ 直接使用async/await
@app.route("/api/chat", methods=["POST"])
async def chat_endpoint():
agent_result = await _agent_instance.chat(...) # 直接调用
总结:统一使用 redis.asyncio 作为Redis客户端,并删除 run_async_safely 这一类的函数。按照我提供的修复方案修改后,你就能直接用 python api.py 启动,不会再有任何事件循环问题!
""" CustomReactAgent完全异步化修复方案 解决StateGraph节点同步/异步混用问题,统一Redis客户端使用 """
# One import per line; the original had several statements collapsed onto a
# single line (a SyntaxError) with the following ``try:`` fused onto the end.
import json
import logging
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import redis.asyncio as redis  # the async Redis client is the only one used here
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode

try:
    from langgraph.checkpoint.redis import AsyncRedisSaver
except ImportError:
    # Optional dependency: without it the agent runs without persistence.
    AsyncRedisSaver = None

try:
    # Package-relative imports when loaded as part of a package ...
    from . import config
    from .state import AgentState
    from .sql_tools import sql_tools
except ImportError:
    # ... plain imports when the module is executed from its own directory.
    import config
    from state import AgentState
    from sql_tools import sql_tools
# Module-level logger. ``__name__`` (not the undefined bare ``name``) so that
# records are attributed to this module.
logger = logging.getLogger(__name__)
class CustomReactAgent:
    """Asynchronous ReAct-style agent driven by a LangGraph ``StateGraph``.

    Every graph node and the routing predicate are coroutines, and a single
    ``redis.asyncio`` client is used for direct Redis access. Build instances
    with :meth:`create`, because initialisation awaits network I/O.
    """

    # Marker prepended to the final AI answer by the formatting node.
    _FORMATTED_PREFIX = "[Formatted Output]\n"

    def __init__(self):
        """Declare all attributes as ``None``; use :meth:`create` instead."""
        self.llm = None
        self.llm_with_tools = None
        self.tools = None
        self.agent_executor = None
        self.checkpointer = None
        self._exit_stack = None
        self.redis_client = None

    @classmethod
    async def create(cls):
        """Create an instance and await its asynchronous initialisation."""
        instance = cls()
        await instance._async_init()
        return instance

    async def _async_init(self):
        """Initialise Redis, the LLM, the tools, the checkpointer and the graph.

        Raises:
            Exception: whatever the Redis ``ping`` raised; the client is
                closed before the error is propagated.
        """
        logger.info("🚀 开始初始化 CustomReactAgent...")

        # 1. Async Redis client.
        # NOTE(review): the URL is logged verbatim; if it can embed a
        # password, consider masking it.
        self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
        try:
            await self.redis_client.ping()
            logger.info(f" ✅ Redis连接成功: {config.REDIS_URL}")
        except Exception as e:
            logger.error(f" ❌ Redis连接失败: {e}")
            # Do not leak the half-initialised client.
            client, self.redis_client = self.redis_client, None
            try:
                await client.aclose()
            except Exception:
                logger.debug("closing Redis client after failed ping raised", exc_info=True)
            raise

        # 2. LLM.
        self.llm = ChatOpenAI(
            api_key=config.QWEN_API_KEY,
            base_url=config.QWEN_BASE_URL,
            model=config.QWEN_MODEL,
            temperature=0.1,
            timeout=config.NETWORK_TIMEOUT,
            max_retries=config.MAX_RETRIES,
            extra_body={
                "enable_thinking": False,
                "misc": {
                    "ensure_ascii": False
                }
            }
        )
        logger.info(f" LLM 已初始化,模型: {config.QWEN_MODEL}")

        # 3. Tools.
        self.tools = sql_tools
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        logger.info(f" 已绑定 {len(self.tools)} 个工具。")

        # 4. Redis checkpointer (optional; the agent still works without it).
        if config.REDIS_ENABLED and AsyncRedisSaver is not None:
            try:
                self._exit_stack = AsyncExitStack()
                checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
                self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
                await self.checkpointer.asetup()
                logger.info(f" AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
            except Exception as e:
                logger.error(f" ❌ RedisSaver 初始化失败: {e}", exc_info=True)
                if self._exit_stack:
                    await self._exit_stack.aclose()
                    # Reset so close() does not touch the closed stack again.
                    self._exit_stack = None
                self.checkpointer = None
        else:
            logger.warning(" Redis 持久化功能已禁用。")

        # 5. Graph.
        self.agent_executor = self._create_graph()
        logger.info(" StateGraph 已构建并编译。")
        logger.info("✅ CustomReactAgent 初始化完成。")

    async def close(self):
        """Release the checkpointer context and close the Redis client.

        The Redis client is closed even if releasing the checkpointer raises.
        """
        try:
            if self._exit_stack:
                await self._exit_stack.aclose()
                self._exit_stack = None
                self.checkpointer = None
                logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")
        finally:
            if self.redis_client:
                await self.redis_client.aclose()
                self.redis_client = None
                logger.info("✅ Redis客户端已关闭。")

    def _create_graph(self):
        """Build and compile the graph.

        Flow: agent -> (prepare_tool_input -> tools ->
        update_state_after_tool -> agent)* -> format_final_response -> END.
        """
        builder = StateGraph(AgentState)

        # All custom nodes are coroutines.
        builder.add_node("agent", self._async_agent_node)
        builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_node("update_state_after_tool", self._async_update_state_after_tool_node)
        builder.add_node("format_final_response", self._async_format_final_response_node)

        builder.set_entry_point("agent")
        builder.add_conditional_edges(
            "agent",
            self._async_should_continue,
            {
                "continue": "prepare_tool_input",
                "end": "format_final_response"
            }
        )
        builder.add_edge("prepare_tool_input", "tools")
        builder.add_edge("tools", "update_state_after_tool")
        builder.add_edge("update_state_after_tool", "agent")
        builder.add_edge("format_final_response", END)

        return builder.compile(checkpointer=self.checkpointer)

    @staticmethod
    def _current_turn(messages):
        """Return the messages from the latest ``HumanMessage`` on, or ``None``."""
        for index in range(len(messages) - 1, -1, -1):
            if isinstance(messages[index], HumanMessage):
                return messages[index:]
        return None

    @staticmethod
    def _escape_glob(text):
        """Backslash-escape Redis glob metacharacters in *text*."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in text)

    async def _async_should_continue(self, state: AgentState) -> str:
        """Return ``"continue"`` if the last message has tool calls, else ``"end"``."""
        last_message = state["messages"][-1]
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "continue"
        return "end"

    async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
        """Call the tool-bound LLM, retrying on network-looking errors.

        Returns:
            ``{"messages": [response]}`` on success, or a fallback
            ``AIMessage`` once retries are exhausted on a network error.

        Raises:
            Exception: any LLM error whose text matches no network keyword.
        """
        # Local import: the module header does not import asyncio, and the
        # original body imported ``time`` while calling ``asyncio.sleep``.
        import asyncio

        logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']}")

        messages_for_llm = list(state["messages"])
        if state.get("suggested_next_step"):
            instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
            messages_for_llm.append(SystemMessage(content=instruction))

        network_error_markers = (
            "Connection error", "APIConnectionError", "ConnectError",
            "timeout", "远程主机强迫关闭", "网络连接",
        )
        # At least one attempt, so the node never falls through returning None.
        max_retries = max(1, config.MAX_RETRIES)

        for attempt in range(max_retries):
            try:
                response = await self.llm_with_tools.ainvoke(messages_for_llm)
            except Exception as e:
                error_msg = str(e)
                logger.warning(f" ⚠️ 异步LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {error_msg}")

                if not any(marker in error_msg for marker in network_error_markers):
                    logger.error(f" ❌ LLM调用出现非网络错误: {error_msg}")
                    raise

                if attempt < max_retries - 1:
                    wait_time = config.RETRY_BASE_DELAY ** attempt
                    logger.info(f" 🔄 网络错误,{wait_time}秒后重试...")
                    await asyncio.sleep(wait_time)
                    continue

                # Out of retries: degrade gracefully instead of failing the turn.
                logger.error(" ❌ 网络连接持续失败,返回降级回答")
                sql_data = await self._async_extract_latest_sql_data(state["messages"])
                if sql_data:
                    fallback_content = "抱歉,由于网络连接问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
                else:
                    fallback_content = "抱歉,由于网络连接问题,无法完成此次请求。请稍后重试或检查网络连接。"
                return {"messages": [AIMessage(content=fallback_content)]}
            else:
                logger.info(" ✅ 异步LLM调用成功")
                return {"messages": [response]}

    async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
        """Inject filtered chat history into every ``generate_sql`` tool call.

        The history holds all human messages plus AI messages carrying the
        ``[Formatted Output]`` marker (marker removed). Other tool calls pass
        through unchanged. The last message is mutated in place and returned.
        """
        logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")

        last_message = state["messages"][-1]
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return {"messages": [last_message]}

        new_tool_calls = []
        for tool_call in last_message.tool_calls:
            if tool_call["name"] != "generate_sql":
                new_tool_calls.append(tool_call)
                continue

            logger.info(" 检测到 generate_sql 调用,注入历史消息。")
            modified_args = tool_call["args"].copy()

            clean_history = []
            for msg in state["messages"][:-1]:
                if isinstance(msg, HumanMessage):
                    clean_history.append({
                        "type": "human",
                        "content": msg.content
                    })
                elif isinstance(msg, AIMessage):
                    if msg.content and "[Formatted Output]" in msg.content:
                        clean_content = msg.content.replace("[Formatted Output]\n", "")
                        clean_history.append({
                            "type": "ai",
                            "content": clean_content
                        })

            modified_args["history_messages"] = clean_history
            logger.info(f" 注入了 {len(clean_history)} 条过滤后的历史消息")
            new_tool_calls.append({
                "name": tool_call["name"],
                "args": modified_args,
                "id": tool_call["id"],
            })

        last_message.tool_calls = new_tool_calls
        return {"messages": [last_message]}

    async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
        """Derive ``suggested_next_step`` from the tool that just ran."""
        logger.info(f"📝 [Async Node] update_state_after_tool - Thread: {state['thread_id']}")

        last_tool_message = state['messages'][-1]
        tool_name = last_tool_message.name
        tool_output = last_tool_message.content
        next_step = None

        if tool_name == 'generate_sql':
            if "失败" in tool_output or "无法生成" in tool_output:
                next_step = 'answer_with_common_sense'
            else:
                next_step = 'valid_sql'
        elif tool_name == 'valid_sql':
            if "失败" in tool_output:
                next_step = 'analyze_validation_error'
            else:
                next_step = 'run_sql'
        elif tool_name == 'run_sql':
            next_step = 'summarize_final_answer'

        logger.info(f" Tool '{tool_name}' executed. Suggested next step: {next_step}")
        return {"suggested_next_step": next_step}

    async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
        """Prefix the final answer with the marker and attach ``api_data``."""
        logger.info(f"🎨 [Async Node] format_final_response - Thread: {state['thread_id']}")

        last_message = state['messages'][-1]
        last_message.content = f"[Formatted Output]\n{last_message.content}"

        api_data = await self._async_generate_api_data(state)
        return {
            "messages": [last_message],
            "api_data": api_data
        }

    async def _async_generate_api_data(self, state: AgentState) -> Dict[str, Any]:
        """Build the API payload.

        Always contains ``response`` and ``react_agent_meta``; ``sql`` and
        ``records`` are added when found in the current turn.
        """
        logger.info("📊 异步生成API格式数据...")

        response_content = state['messages'][-1].content
        # Strip only the leading marker; ``replace`` would also remove the
        # same text from the middle of an answer.
        if response_content.startswith(self._FORMATTED_PREFIX):
            response_content = response_content[len(self._FORMATTED_PREFIX):]

        api_data = {
            "response": response_content
        }

        sql_info = await self._async_extract_sql_and_data(state['messages'])
        if sql_info['sql']:
            api_data["sql"] = sql_info['sql']
        if sql_info['records']:
            api_data["records"] = sql_info['records']

        api_data["react_agent_meta"] = await self._async_collect_agent_metadata(state)

        logger.info(f" API数据生成完成,包含字段: {list(api_data.keys())}")
        return api_data

    async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        """Extract the generated SQL and ``run_sql`` rows from the current turn.

        Returns:
            ``{"sql": str | None, "records": dict | None}``. When several
            matching tool messages exist, the last one wins.
        """
        result = {"sql": None, "records": None}

        current_conversation = self._current_turn(messages)
        if current_conversation is None:
            return result

        failure_markers = ("失败", "无法生成", "Database query failed")
        sql_query = None
        sql_data = None

        for msg in current_conversation:
            if not isinstance(msg, ToolMessage):
                continue
            if msg.name == 'generate_sql':
                content = msg.content
                if content and not any(marker in content for marker in failure_markers):
                    sql_query = content.strip()
            elif msg.name == 'run_sql':
                try:
                    parsed_data = json.loads(msg.content)
                    if isinstance(parsed_data, list) and len(parsed_data) > 0:
                        sql_data = {
                            "columns": list(parsed_data[0].keys()),
                            "rows": parsed_data,
                            "total_row_count": len(parsed_data),
                            "is_limited": False
                        }
                except Exception as e:
                    # Best effort: a malformed tool result must not break the reply.
                    logger.warning(f" 解析SQL结果失败: {e}")

        if sql_query:
            result["sql"] = sql_query
        if sql_data:
            result["records"] = sql_data
        return result

    async def _async_collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
        """Summarise the run: tools used, SQL executions, history injection."""
        messages = state['messages']

        tools_used = []
        sql_execution_count = 0
        context_injected = False
        conversation_rounds = sum(1 for msg in messages if isinstance(msg, HumanMessage))

        for msg in messages:
            if isinstance(msg, ToolMessage):
                if msg.name not in tools_used:
                    tools_used.append(msg.name)
                if msg.name == 'run_sql':
                    sql_execution_count += 1
            elif isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
                for tool_call in msg.tool_calls:
                    tool_name = tool_call.get('name')
                    if tool_name and tool_name not in tools_used:
                        tools_used.append(tool_name)
                    if (tool_name == 'generate_sql' and
                            tool_call.get('args', {}).get('history_messages')):
                        context_injected = True

        # Coarse path description, not an exact trace of node visits.
        execution_path = ["agent"]
        if tools_used:
            execution_path.extend(["prepare_tool_input", "tools"])
        execution_path.append("format_final_response")

        return {
            "thread_id": state['thread_id'],
            "conversation_rounds": conversation_rounds,
            "tools_used": tools_used,
            "execution_path": execution_path,
            "total_messages": len(messages),
            "sql_execution_count": sql_execution_count,
            "context_injected": context_injected,
            "agent_version": "custom_react_v1_async"
        }

    async def _async_extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
        """Return the newest ``run_sql`` output of the current turn, or ``None``.

        Valid JSON is re-serialised compactly with ``ensure_ascii=False``;
        anything else is returned as-is.
        """
        logger.info("🔍 异步提取最新的SQL执行结果...")

        current_conversation = self._current_turn(messages)
        if current_conversation is None:
            logger.info(" 未找到用户消息,跳过SQL数据提取")
            return None

        logger.info(f" 当前对话轮次包含 {len(current_conversation)} 条消息")

        for msg in reversed(current_conversation):
            if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
                logger.info(f" 找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
                try:
                    parsed_data = json.loads(msg.content)
                    formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
                    logger.info(" 已转换Unicode转义序列为中文字符")
                    return formatted_content
                except json.JSONDecodeError:
                    logger.warning(" SQL结果不是有效JSON格式,返回原始内容")
                    return msg.content

        logger.info(" 当前对话轮次中未找到run_sql执行结果")
        return None

    async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
        """Run one chat turn through the graph.

        Args:
            message: the user's question.
            user_id: owner of the conversation.
            thread_id: existing thread to continue; when falsy, a new id
                ``"<user_id>:<YYYYmmddHHMMSS><ms>"`` is generated.

        Returns:
            On success ``{"success": True, "answer", "thread_id"}`` plus
            optional ``sql_data`` / ``api_data``; on failure
            ``{"success": False, "error", "thread_id"}``. Never raises.
        """
        if not thread_id:
            now = pd.Timestamp.now()
            milliseconds = now.microsecond // 1000
            thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
            logger.info(f"🆕 新建会话,Thread ID: {thread_id}")

        # Named ``run_config`` so it does not shadow the imported ``config``.
        run_config = {
            "configurable": {
                "thread_id": thread_id,
            }
        }
        inputs = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
            "thread_id": thread_id,
            "suggested_next_step": None,
        }

        try:
            final_state = await self.agent_executor.ainvoke(inputs, run_config)
            answer = final_state["messages"][-1].content
            sql_data = await self._async_extract_latest_sql_data(final_state["messages"])

            logger.info(f"✅ 异步处理完成 - Final Answer: '{answer}'")

            result = {
                "success": True,
                "answer": answer,
                "thread_id": thread_id
            }
            if sql_data:
                result["sql_data"] = sql_data
                logger.info(" 📊 已包含SQL原始数据")
            if "api_data" in final_state:
                result["api_data"] = final_state["api_data"]
                logger.info(" 🔌 已包含API格式数据")
            return result

        except Exception as e:
            logger.error(f"❌ 异步处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
            return {"success": False, "error": str(e), "thread_id": thread_id}

    async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
        """Return a thread's stored messages as ``{"type", "content", "tool_calls"}`` dicts.

        Returns an empty list when there is no checkpointer, no stored state,
        or the event loop is already closed.
        """
        if not self.checkpointer:
            return []

        thread_config = {"configurable": {"thread_id": thread_id}}
        try:
            conversation_state = await self.checkpointer.aget(thread_config)
        except RuntimeError as e:
            if "Event loop is closed" in str(e):
                logger.warning(f"⚠️ Event loop已关闭,返回空结果: {thread_id}")
                return []
            raise

        if not conversation_state:
            return []

        history = []
        messages = conversation_state.get('channel_values', {}).get('messages', [])
        for msg in messages:
            if isinstance(msg, HumanMessage):
                role = "human"
            elif isinstance(msg, ToolMessage):
                role = "tool"
            else:
                role = "ai"
            history.append({
                "type": role,
                "content": msg.content,
                "tool_calls": getattr(msg, 'tool_calls', None)
            })
        return history

    async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """List a user's most recent conversations, newest first.

        Scans Redis for ``checkpoint:<user_id>:*`` keys, groups them by
        thread and loads each thread's state through the checkpointer.
        NOTE(review): assumes keys look like
        ``checkpoint:<user_id>:<timestamp>:...``; confirm against the
        installed checkpointer version.

        Returns:
            At most *limit* conversation summaries; an empty list on any
            error or when persistence is disabled.
        """
        if not self.checkpointer:
            return []

        try:
            # Escape glob metacharacters so a user id such as "*" cannot
            # match other users' keys.
            pattern = f"checkpoint:{self._escape_glob(user_id)}:*"
            logger.info(f"🔍 异步扫描模式: {pattern}")

            user_threads = {}
            cursor = 0
            while True:
                cursor, keys = await self.redis_client.scan(
                    cursor=cursor,
                    match=pattern,
                    count=1000
                )
                for key in keys:
                    try:
                        key_str = key.decode() if isinstance(key, bytes) else key
                        parts = key_str.split(':')
                        if len(parts) < 4:
                            continue
                        thread_id = f"{parts[1]}:{parts[2]}"
                        known = user_threads.get(thread_id)
                        if known is None:
                            user_threads[thread_id] = {
                                "thread_id": thread_id,
                                "timestamp": parts[2],
                                "latest_key": key_str
                            }
                        elif len(parts) > 4:
                            # Guard the index: the stored key may have
                            # fewer segments than this one.
                            known_parts = known["latest_key"].split(':')
                            if len(known_parts) <= 4 or parts[4] > known_parts[4]:
                                known["latest_key"] = key_str
                    except Exception as e:
                        logger.warning(f"解析key {key} 失败: {e}")
                        continue
                if cursor == 0:
                    break

            sorted_threads = sorted(
                user_threads.values(),
                key=lambda x: x["timestamp"],
                reverse=True
            )[:limit]

            conversations = []
            for thread_info in sorted_threads:
                try:
                    thread_id = thread_info["thread_id"]
                    thread_config = {"configurable": {"thread_id": thread_id}}
                    try:
                        state = await self.checkpointer.aget(thread_config)
                    except RuntimeError as e:
                        if "Event loop is closed" in str(e):
                            logger.warning(f"⚠️ Event loop已关闭,跳过thread: {thread_id}")
                            continue
                        raise

                    if state and state.get('channel_values', {}).get('messages'):
                        messages = state['channel_values']['messages']
                        preview = self._generate_conversation_preview(messages)
                        conversations.append({
                            "thread_id": thread_id,
                            "user_id": user_id,
                            "timestamp": thread_info["timestamp"],
                            "message_count": len(messages),
                            "last_message": messages[-1].content if messages else None,
                            # NOTE(review): 'created_at' may not exist on the
                            # checkpoint (then None) - verify the field name.
                            "last_updated": state.get('created_at'),
                            "conversation_preview": preview,
                            "formatted_time": self._format_timestamp(thread_info["timestamp"])
                        })
                except Exception as e:
                    logger.error(f"获取thread {thread_info['thread_id']} 详情失败: {e}")
                    continue

            logger.info(f"✅ 异步找到用户 {user_id} 的 {len(conversations)} 个对话")
            return conversations

        except Exception as e:
            logger.error(f"❌ 异步获取用户 {user_id} 对话列表失败: {e}")
            return []

    def _generate_conversation_preview(self, messages: List[BaseMessage]) -> str:
        """Return the first human message, truncated to 50 chars plus ``...``.

        Returns ``"空对话"`` for an empty list and ``"系统消息"`` when there
        is no human message.
        """
        if not messages:
            return "空对话"
        for msg in messages:
            if isinstance(msg, HumanMessage):
                content = str(msg.content)
                return content[:50] + "..." if len(content) > 50 else content
        return "系统消息"

    def _format_timestamp(self, timestamp: str) -> str:
        """Format a ``YYYYmmddHHMMSS...`` string as ``YYYY-mm-dd HH:MM:SS``.

        Inputs shorter than 14 characters, or values that cannot be sliced,
        are returned unchanged.
        """
        try:
            if len(timestamp) >= 14:
                year = timestamp[:4]
                month = timestamp[4:6]
                day = timestamp[6:8]
                hour = timestamp[8:10]
                minute = timestamp[10:12]
                second = timestamp[12:14]
                return f"{year}-{month}-{day} {hour}:{minute}:{second}"
        except Exception:
            # Deliberate best effort: fall through to the raw value.
            pass
        return timestamp
""" 修复后的 api.py - 统一使用异步Redis客户端,移除复杂的事件循环管理 """
# One statement per line; the original had several imports and assignments
# collapsed onto single lines (a SyntaxError).
import asyncio
import logging
import os
from datetime import datetime
from typing import Any, Dict, Optional

import redis.asyncio as redis  # the async Redis client is the only one used here
from flask import Flask, jsonify, request

try:
    from .agent import CustomReactAgent
except ImportError:
    from agent import CustomReactAgent

logging.basicConfig(level=logging.INFO)
# ``__name__`` (not the undefined bare ``name``).
logger = logging.getLogger(__name__)

# Process-wide singletons, populated by initialize_agent().
_agent_instance: Optional[CustomReactAgent] = None
_redis_client: Optional[redis.Redis] = None
def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate and normalise the JSON body of a chat request.

    Args:
        data: decoded request body. Reads ``question`` (required, at most
            2000 characters), ``user_id`` (optional, at most 50 characters,
            defaults to ``"guest"``) and ``thread_id`` (optional).

    Returns:
        ``{"question": <stripped>, "user_id": ..., "thread_id": ...}``.

    Raises:
        ValueError: one or more fields are invalid; messages are joined
            with ``"; "``. Wrongly typed fields are reported here too,
            rather than surfacing as AttributeError/TypeError.
    """
    errors = []

    question = data.get('question', '')
    if not question or (isinstance(question, str) and not question.strip()):
        errors.append('问题不能为空')
    elif not isinstance(question, str):
        errors.append('问题必须是字符串')
    elif len(question) > 2000:
        errors.append('问题长度不能超过2000字符')

    user_id = data.get('user_id', 'guest')
    if user_id and not isinstance(user_id, str):
        errors.append('用户ID必须是字符串')
    elif user_id and len(user_id) > 50:
        errors.append('用户ID长度不能超过50字符')

    if errors:
        raise ValueError('; '.join(errors))

    return {
        'question': question.strip(),
        'user_id': user_id or 'guest',
        'thread_id': data.get('thread_id')
    }
async def initialize_agent():
    """Create the shared Redis client and the agent singleton if missing.

    Does nothing when the agent already exists. On failure the Redis client
    is closed and reset, so a retry starts clean instead of leaking a
    connection.

    Raises:
        Exception: whatever the Redis ping or agent creation raised.
    """
    global _agent_instance, _redis_client

    if _agent_instance is not None:
        return

    logger.info("🚀 正在异步初始化 Custom React Agent...")
    try:
        # NOTE(review): hard-coded URL overrides any REDIS_URL already in the
        # environment; whether `config` reads it after this point is not
        # visible here.
        os.environ['REDIS_URL'] = 'redis://localhost:6379'

        # Reuse an existing client rather than overwriting (and leaking) it.
        if _redis_client is None:
            _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
        await _redis_client.ping()

        _agent_instance = await CustomReactAgent.create()
        logger.info("✅ Agent 异步初始化完成")
    except Exception as e:
        logger.error(f"❌ Agent 异步初始化失败: {e}")
        client, _redis_client = _redis_client, None
        if client is not None:
            try:
                await client.aclose()
            except Exception:
                logger.debug("closing Redis client after failed init raised", exc_info=True)
        raise
async def ensure_agent_ready():
    """Make sure a working agent singleton exists, re-creating it if needed.

    The agent is probed with a cheap conversation-list call; if the probe
    raises, the old agent and Redis client are released and a new agent is
    initialised.

    Returns:
        ``True`` once an agent is available.

    Raises:
        Exception: propagated from ``initialize_agent`` when (re)creation
            fails.
    """
    global _agent_instance

    if _agent_instance is None:
        await initialize_agent()

    try:
        await _agent_instance.get_user_recent_conversations("__test__", 1)
    except Exception as e:
        logger.warning(f"⚠️ Agent实例不可用: {e}")
        # Release the broken instance's resources before replacing it.
        try:
            await cleanup_agent()
        except Exception:
            logger.debug("cleanup of unusable agent raised", exc_info=True)
        _agent_instance = None
        await initialize_agent()
    return True
async def cleanup_agent():
    """Close the agent and the shared Redis client and reset both globals.

    The globals are cleared before awaiting, and the Redis client is closed
    in a ``finally`` block, so a failing ``agent.close()`` neither leaves
    stale references behind nor skips the Redis shutdown.
    """
    global _agent_instance, _redis_client

    agent, _agent_instance = _agent_instance, None
    client, _redis_client = _redis_client, None

    try:
        if agent:
            await agent.close()
            logger.info("✅ Agent 资源已异步清理")
    finally:
        if client:
            await client.aclose()
            logger.info("✅ Redis客户端已异步关闭")
# Flask needs the import name of this module; the bare ``name`` was undefined.
app = Flask(__name__)
@app.route("/")
def root():
    """Liveness endpoint: return a static "service is running" message."""
    return jsonify({"message": "Custom React Agent API 服务正在运行"})
@app.route('/health', methods=['GET'])
def health_check():
    """Report service health and whether the agent singleton exists.

    Returns:
        200 with ``status``, ``agent_initialized`` and ``timestamp``, or
        500 with ``status: "unhealthy"`` if building the response fails.
    """
    try:
        health_status = {
            "status": "healthy",
            "agent_initialized": _agent_instance is not None,
            "timestamp": datetime.now().isoformat()
        }
        return jsonify(health_status), 200
    except Exception as e:
        logger.error(f"健康检查失败: {e}")
        return jsonify({"status": "unhealthy", "error": str(e)}), 500
@app.route("/api/chat", methods=["POST"])
async def chat_endpoint():
    """Handle a chat request: validate the body, run the agent, wrap the result.

    Responses:
        200 - success, ``data`` holds the agent's ``api_data`` plus a timestamp.
        400 - missing, invalid or non-object JSON body, or validation failure.
        500 - the agent reported failure or an unexpected error occurred.
        503 - the agent could not be initialised.
    """
    # ensure_agent_ready() raises rather than returning False, so catch the
    # exception here; otherwise the 503 branch is unreachable.
    try:
        agent_ready = await ensure_agent_ready()
    except Exception:
        logger.debug("ensure_agent_ready raised", exc_info=True)
        agent_ready = False

    if not agent_ready:
        return jsonify({
            "code": 503,
            "message": "服务未就绪",
            "success": False,
            "error": "Agent 初始化失败"
        }), 503

    try:
        # silent=True: malformed JSON or a wrong content type yields None and
        # is reported as 400 instead of falling into the generic 500 handler.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({
                "code": 400,
                "message": "请求参数错误",
                "success": False,
                "error": "请求体不能为空"
            }), 400

        validated_data = validate_request_data(data)
        logger.info(f"📨 收到请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")

        agent_result = await _agent_instance.chat(
            message=validated_data['question'],
            user_id=validated_data['user_id'],
            thread_id=validated_data['thread_id']
        )

        if not agent_result.get("success", False):
            error_msg = agent_result.get("error", "Agent处理失败")
            logger.error(f"❌ Agent处理失败: {error_msg}")
            return jsonify({
                "code": 500,
                "message": "处理失败",
                "success": False,
                "error": error_msg,
                "data": {
                    "react_agent_meta": {
                        "thread_id": agent_result.get("thread_id"),
                        "agent_version": "custom_react_v1_async",
                        "execution_path": ["error"]
                    },
                    "timestamp": datetime.now().isoformat()
                }
            }), 500

        api_data = agent_result.get("api_data", {})
        response_data = {
            **api_data,
            "timestamp": datetime.now().isoformat()
        }

        logger.info(f"✅ 异步请求处理成功 - Thread: {api_data.get('react_agent_meta', {}).get('thread_id')}")
        return jsonify({
            "code": 200,
            "message": "操作成功",
            "success": True,
            "data": response_data
        })

    except ValueError as e:
        logger.warning(f"⚠️ 参数验证失败: {e}")
        return jsonify({
            "code": 400,
            "message": "请求参数错误",
            "success": False,
            "error": str(e)
        }), 400
    except Exception as e:
        logger.error(f"❌ 未预期的错误: {e}", exc_info=True)
        return jsonify({
            "code": 500,
            "message": "服务器内部错误",
            "success": False,
            "error": "系统异常,请稍后重试"
        }), 500
# The ``<user_id>`` URL variable was missing from the rule, so Flask could
# never supply the view's ``user_id`` argument.
@app.route('/api/v0/react/users/<user_id>/conversations', methods=['GET'])
async def get_user_conversations(user_id: str):
    """List a user's recent conversations.

    Query parameters:
        limit: number of conversations, clamped to 1..50 (default 10).

    Responses:
        200 - ``data.conversations`` with ``total_count`` and ``limit``.
        500 - unexpected error.
        503 - the agent could not be initialised.
    """
    try:
        limit = request.args.get('limit', 10, type=int)
        limit = max(1, min(limit, 50))

        logger.info(f"📋 异步获取用户 {user_id} 的对话列表,限制 {limit} 条")

        # ensure_agent_ready() signals failure by raising; map that to 503.
        try:
            agent_ready = await ensure_agent_ready()
        except Exception:
            logger.debug("ensure_agent_ready raised", exc_info=True)
            agent_ready = False

        if not agent_ready:
            return jsonify({
                "success": False,
                "error": "Agent 未就绪",
                "timestamp": datetime.now().isoformat()
            }), 503

        conversations = await _agent_instance.get_user_recent_conversations(user_id, limit)

        return jsonify({
            "success": True,
            "data": {
                "user_id": user_id,
                "conversations": conversations,
                "total_count": len(conversations),
                "limit": limit
            },
            "timestamp": datetime.now().isoformat()
        }), 200

    except Exception as e:
        logger.error(f"❌ 异步获取用户 {user_id} 对话列表失败: {e}")
        return jsonify({
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }), 500
# Both URL variables were missing from the rule; without them Flask cannot
# supply ``user_id`` / ``thread_id`` to the view.
@app.route('/api/v0/react/users/<user_id>/conversations/<thread_id>', methods=['GET'])
async def get_user_conversation_detail(user_id: str, thread_id: str):
    """Return the full message history of one conversation.

    Responses:
        200 - ``data.messages`` with ``message_count``.
        400 - ``thread_id`` does not start with ``"<user_id>:"``.
        404 - no history stored for the thread.
        500 - unexpected error.
        503 - the agent could not be initialised.
    """
    try:
        # Thread ids are "<user_id>:<timestamp>"; reject ids owned by others.
        if not thread_id.startswith(f"{user_id}:"):
            return jsonify({
                "success": False,
                "error": f"Thread ID {thread_id} 不属于用户 {user_id}",
                "timestamp": datetime.now().isoformat()
            }), 400

        logger.info(f"📖 异步获取用户 {user_id} 的对话 {thread_id} 详情")

        # ensure_agent_ready() signals failure by raising; map that to 503.
        try:
            agent_ready = await ensure_agent_ready()
        except Exception:
            logger.debug("ensure_agent_ready raised", exc_info=True)
            agent_ready = False

        if not agent_ready:
            return jsonify({
                "success": False,
                "error": "Agent 未就绪",
                "timestamp": datetime.now().isoformat()
            }), 503

        history = await _agent_instance.get_conversation_history(thread_id)
        logger.info(f"✅ 异步成功获取对话历史,消息数量: {len(history)}")

        if not history:
            return jsonify({
                "success": False,
                "error": f"未找到对话 {thread_id}",
                "timestamp": datetime.now().isoformat()
            }), 404

        return jsonify({
            "success": True,
            "data": {
                "user_id": user_id,
                "thread_id": thread_id,
                "message_count": len(history),
                "messages": history
            },
            "timestamp": datetime.now().isoformat()
        }), 200

    except Exception as e:
        import traceback
        logger.error(f"❌ 异步获取对话 {thread_id} 详情失败: {e}")
        logger.error(f"❌ 详细错误信息: {traceback.format_exc()}")
        return jsonify({
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }), 500
def _checkpoint_messages(checkpoint_data):
    """Return ``checkpoint.channel_values.messages`` from decoded JSON, or ``[]``."""
    if not isinstance(checkpoint_data, dict):
        return []
    checkpoint = checkpoint_data.get('checkpoint')
    if not isinstance(checkpoint, dict):
        return []
    channel_values = checkpoint.get('channel_values')
    if not isinstance(channel_values, dict):
        return []
    return channel_values.get('messages', [])


def _human_preview(messages):
    """Return the first serialised HumanMessage's content (max 50 chars + ``...``).

    Messages are expected as serialised constructor dicts
    (``{"lc": 1, "type": "constructor", "id": [..., "HumanMessage"],
    "kwargs": {...}}``). Returns ``"空对话"`` when none matches.
    """
    for msg in messages or []:
        if not isinstance(msg, dict):
            continue
        msg_id = msg.get('id')
        is_human = (
            msg.get('lc') == 1
            and msg.get('type') == 'constructor'
            and isinstance(msg_id, list)
            and len(msg_id) >= 4
            and msg_id[3] == 'HumanMessage'
            and 'kwargs' in msg
        )
        if not is_human:
            continue
        kwargs = msg['kwargs']
        if kwargs.get('type') == 'human' and 'content' in kwargs:
            content = str(kwargs['content'])
            return content[:50] + "..." if len(content) > 50 else content
    return "空对话"


async def get_user_conversations_async(user_id: str, limit: int = 10):
    """List a user's conversations by reading checkpoint keys from Redis.

    Scans ``checkpoint:<user_id>:*``, groups keys by thread, keeps the
    *limit* newest threads and decodes each thread's lexicographically
    largest key to build a preview.

    Returns:
        A list of ``{"thread_id", "user_id", "timestamp", "message_count",
        "conversation_preview"}`` dicts; an empty list on any Redis error.
    """
    global _redis_client
    # Imported at function scope because the module header has no json import.
    import json

    try:
        if not _redis_client:
            _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
            await _redis_client.ping()

        # Escape glob metacharacters so a user id such as "*" cannot match
        # other users' keys.
        safe_user_id = "".join("\\" + ch if ch in "*?[]\\" else ch for ch in user_id)
        pattern = f"checkpoint:{safe_user_id}:*"
        logger.info(f"🔍 异步扫描模式: {pattern}")

        keys = []
        cursor = 0
        while True:
            cursor, batch = await _redis_client.scan(cursor=cursor, match=pattern, count=1000)
            keys.extend(batch)
            if cursor == 0:
                break

        logger.info(f"📋 异步找到 {len(keys)} 个keys")

        # Group keys by thread id ("<user_id>:<timestamp>").
        thread_data = {}
        for key in keys:
            try:
                parts = key.split(':')
                if len(parts) >= 4:
                    thread_id = f"{parts[1]}:{parts[2]}"
                    entry = thread_data.setdefault(thread_id, {
                        "thread_id": thread_id,
                        "timestamp": parts[2],
                        "keys": []
                    })
                    entry["keys"].append(key)
            except Exception as e:
                logger.warning(f"解析key失败 {key}: {e}")
                continue

        sorted_threads = sorted(
            thread_data.values(),
            key=lambda x: x["timestamp"],
            reverse=True
        )[:limit]

        conversations = []
        for thread_info in sorted_threads:
            try:
                thread_id = thread_info["thread_id"]
                latest_key = max(thread_info["keys"])

                # Checkpoints may be plain strings or RedisJSON documents.
                key_type = await _redis_client.type(latest_key)
                data = None
                if key_type == 'string':
                    data = await _redis_client.get(latest_key)
                elif key_type == 'ReJSON-RL':
                    try:
                        data = await _redis_client.execute_command('JSON.GET', latest_key)
                    except Exception as json_error:
                        logger.error(f"❌ 异步JSON.GET 失败: {json_error}")
                        continue

                if not data:
                    continue

                try:
                    checkpoint_data = json.loads(data)
                except json.JSONDecodeError:
                    logger.error("❌ 异步JSON解析失败")
                    continue

                messages = _checkpoint_messages(checkpoint_data)
                conversations.append({
                    "thread_id": thread_id,
                    "user_id": user_id,
                    "timestamp": thread_info["timestamp"],
                    "message_count": len(messages),
                    "conversation_preview": _human_preview(messages)
                })
            except Exception as e:
                logger.error(f"异步处理thread {thread_info['thread_id']} 失败: {e}")
                continue

        logger.info(f"✅ 异步返回 {len(conversations)} 个对话")
        return conversations

    except Exception as e:
        logger.error(f"❌ 异步Redis查询失败: {e}")
        return []
async def startup():
    """Pre-initialise the agent at application start.

    Initialisation errors are logged and swallowed so the application can
    still start.
    """
    logger.info("🚀 启动异步Flask应用...")
    try:
        await initialize_agent()
    except Exception as exc:
        logger.error(f"❌ 启动时Agent初始化失败: {exc}")
        return
    logger.info("✅ Agent 预初始化完成")
async def shutdown():
    """Release agent resources at application shutdown.

    Cleanup errors are logged and swallowed.
    """
    logger.info("🔄 关闭异步Flask应用...")
    try:
        await cleanup_agent()
    except Exception as exc:
        logger.error(f"❌ 关闭时清理失败: {exc}")
        return
    logger.info("✅ 资源清理完成")
# ``__name__ == "__main__"`` (the bare ``name == "main"`` was a NameError).
if __name__ == "__main__":
    import signal
    import sys

    # NOTE(review): Flask's async views need the ``flask[async]`` extra, and
    # Flask runs each async view in its own event loop - confirm that the
    # shared agent and Redis client tolerate being used across loops.
    logger.info("🚀 使用Flask内置异步支持启动...")

    def signal_handler(signum, frame):
        """Log the shutdown request and exit; no async cleanup runs here."""
        logger.info("🛑 收到关闭信号,开始清理...")
        print("正在关闭服务...")
        # sys.exit instead of the site-provided ``exit`` helper, which is not
        # guaranteed to exist in every runtime.
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    app.run(host="0.0.0.0", port=8000, debug=False)