Two key questions (两个关键问题.md)

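Let me go through your code carefully and answer the two key questions:

1. Is the CustomReactAgent class fully async?

✅ Methods that are already async:

```python
@classmethod
async def create(cls): ...                                          # ✅ async factory method
async def _async_init(self): ...                                    # ✅ async initialization
async def close(self): ...                                          # ✅ async cleanup
async def chat(self, message, user_id, thread_id): ...              # ✅ async chat
async def get_conversation_history(self, thread_id): ...            # ✅ async history lookup
async def get_user_recent_conversations(self, user_id, limit): ...  # ✅ async conversation list
```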

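⚠️ Key methods that are still synchronous (this is where the problem is):

```python
def _should_continue(self, state: AgentState) -> str: ...          # ❌ sync
def _agent_node(self, state: AgentState) -> Dict[str, Any]: ...    # ❌ sync
def _prepare_tool_input_node(self, state: AgentState): ...         # ❌ sync
def _update_state_after_tool_node(self, state: AgentState): ...    # ❌ sync
def _format_final_response_node(self, state: AgentState): ...      # ❌ sync
```

Analysis: your StateGraph nodes are synchronous, and the agent node calls the LLM through the blocking invoke(), even though the surrounding chat() method is async:

```python
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
    # Sync node: invoke() is the blocking LLM call
    response = self.llm_with_tools.invoke(messages_for_llm)  # 🔥 the problem is here
```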

2. How is the Redis client used?

Your code mixes two kinds of Redis client:

Async Redis client (in agent.py):

```python
from redis.asyncio import Redis
redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
```

Sync Redis client (in api.py and enhanced_redis_api.py):

```python
import redis  # sync version
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
```

🎯 The key problem: conflicts caused by mixing the two

Answers to your two key questions:

1. Is the CustomReactAgent class fully async?

❌ No, your CustomReactAgent is not fully async!

The main methods are async, but all of the StateGraph node methods are synchronous, and that is the root of the problem:

```python
# ❌ These are sync methods; the agent node calls the LLM via the blocking invoke()
def _agent_node(self, state: AgentState) -> Dict[str, Any]:
    response = self.llm_with_tools.invoke(messages_for_llm)  # 🔥 the problem is here

def _prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]: ...
def _update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]: ...
def _format_final_response_node(self, state: AgentState) -> Dict[str, Any]: ...
def _should_continue(self, state: AgentState) -> str: ...
```

Core issue: LangGraph's StateGraph supports async node functions, but your graph doesn't use them!
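The stdlib-only sketch below makes the difference between sync and async nodes concrete. It shows one common way an async graph runner can dispatch nodes: coroutine functions are awaited on the event loop, and plain functions are pushed to a worker thread with `run_in_executor`. `run_node`, `sync_node` and `async_node` are hypothetical names. This is a simplified illustration of the idea, not LangGraph's implementation. As far as I know, LangGraph offloads sync nodes in a broadly similar way when a graph runs via `ainvoke`, but check its source for the details.

```python
import asyncio
import inspect
import threading
from typing import Any, Callable, Dict

State = Dict[str, Any]


async def run_node(node: Callable[[State], Any], state: State) -> State:
    """Hypothetical dispatcher: await async nodes, offload sync nodes to a thread."""
    if inspect.iscoroutinefunction(node):
        # Async node: runs on the event loop itself, no extra thread needed.
        return await node(state)
    # Sync node: run it in the default thread pool so it cannot block the loop,
    # at the cost of tying up a worker thread for the whole call.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, node, state)


def sync_node(state: State) -> State:
    # Stands in for a node calling a blocking API such as llm.invoke(...)
    return {"ran_on": threading.current_thread().name}


async def async_node(state: State) -> State:
    # Stands in for a node awaiting a non-blocking API such as llm.ainvoke(...)
    await asyncio.sleep(0)
    return {"ran_on": threading.current_thread().name}


async def main() -> Dict[str, str]:
    return {
        "sync": (await run_node(sync_node, {}))["ran_on"],
        "async": (await run_node(async_node, {}))["ran_on"],
    }


if __name__ == "__main__":
    print(asyncio.run(main()))
```

In practice, a sync node calling `invoke()` does not freeze the loop, but every in-flight call holds a thread. An async node awaiting `ainvoke()` needs no extra thread.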

2. How is the Redis client used?

❌ You are mixing sync and async Redis clients!

Async Redis client (in agent.py):

```python
from redis.asyncio import Redis  # ✅ async version
redis_client = Redis.from_url(config.REDIS_URL, decode_responses=True)
```

Sync Redis client (in api.py and enhanced_redis_api.py):

```python
import redis  # ❌ sync version
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
```

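🎯 Do you really need async Redis in your scenario?

Answer: it is strongly recommended, for these reasons:

  1. Architectural consistency: if the Flask routes are async, the Redis calls inside them should be async too.
  2. No blocking: a sync Redis call made inside an async route blocks the event loop and hurts concurrency.
  3. Fewer error sources: "Event loop is closed" is raised when an async client (a redis.asyncio connection, AsyncRedisSaver) is used after the event loop it was created on has closed. A sync client cannot raise that error by itself. However, the run_async_safely-style wrappers needed to bridge the two styles are typically what keep creating and closing loops. One async client, created and used on a single long-lived loop, removes this failure mode.
  4. Chatbot workload: users expect fast replies, and async Redis gives better concurrency.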
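Point 2 can be checked with the standard library alone. In this sketch, `time.sleep` stands in for a blocking sync Redis call and `asyncio.sleep` stands in for an awaited `redis.asyncio` call. Both are stand-ins and no real Redis I/O happens. Three "requests" run concurrently on one loop.

```python
import asyncio
import time


async def handler_blocking(delay: float) -> None:
    # A sync client call (e.g. redis.Redis.get) blocks the whole event loop.
    time.sleep(delay)


async def handler_async(delay: float) -> None:
    # An awaited async call yields to the loop while waiting on I/O.
    await asyncio.sleep(delay)


async def serve(handler, n_requests: int = 3, delay: float = 0.1) -> float:
    """Run n_requests handlers concurrently and return the wall-clock time."""
    start = time.perf_counter()
    await asyncio.gather(*(handler(delay) for _ in range(n_requests)))
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"blocking: {asyncio.run(serve(handler_blocking)):.2f}s")  # roughly 3 x 0.1s
    print(f"async:    {asyncio.run(serve(handler_async)):.2f}s")     # roughly 0.1s
```

The blocking handlers run one after another even though they were "gathered". The awaited ones overlap.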

🔧 The most important changes

The complete fix is in the artifact reproduced below; the key changes are:

1. Make every StateGraph node async:

```python
# ✅ After the change: async nodes
async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
    response = await self.llm_with_tools.ainvoke(messages_for_llm)  # use ainvoke
    ...

async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]: ...
async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]: ...
async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]: ...
async def _async_should_continue(self, state: AgentState) -> str: ...
```
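The sketch below is a runnable, stdlib-only version of the async node pattern, loosely modelled on `_async_agent_node` in the artifact. It awaits the LLM, backs off with `asyncio.sleep` on network-looking errors, re-raises anything else, and returns a canned reply when retries run out. `FlakyLLM`, `agent_node` and the fallback text are hypothetical. The delay here is `base_delay * 2 ** attempt`, whereas the artifact computes `RETRY_BASE_DELAY ** attempt`.

```python
import asyncio
from typing import Any, Dict, List


class FlakyLLM:
    """Hypothetical stand-in for llm_with_tools: fails `failures` times, then answers."""

    def __init__(self, failures: int):
        self.failures = failures
        self.calls = 0

    async def ainvoke(self, messages: List[str]) -> str:
        self.calls += 1
        if self.calls <= self.failures:
            raise ConnectionError("Connection error")
        return f"answer to {messages[-1]!r}"


NETWORK_KEYWORDS = ("Connection error", "APIConnectionError", "ConnectError", "timeout")


async def agent_node(llm: Any, state: Dict[str, Any], max_retries: int = 3,
                     base_delay: float = 0.01) -> Dict[str, Any]:
    """Async node: await the LLM and back off with asyncio.sleep on network errors."""
    for attempt in range(max_retries):
        try:
            response = await llm.ainvoke(state["messages"])
            return {"messages": [response]}
        except Exception as e:
            if not any(k in str(e) for k in NETWORK_KEYWORDS):
                raise  # non-network errors propagate unchanged
            if attempt < max_retries - 1:
                # asyncio.sleep yields to the loop; time.sleep here would block every request.
                await asyncio.sleep(base_delay * 2 ** attempt)
    # All retries failed: degrade gracefully instead of raising.
    return {"messages": ["fallback: network unavailable"]}


if __name__ == "__main__":
    print(asyncio.run(agent_node(FlakyLLM(failures=2), {"messages": ["hi"]})))
```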

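2. Use a single async Redis client:

```python
# ✅ Use the async Redis client everywhere (inside an async method such as _async_init)
import redis.asyncio as redis

self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
await self.redis_client.ping()
```

3. Remove the hand-rolled event-loop management:

```python
# ❌ Delete these helpers
def run_async_safely(): ...         # delete
def ensure_agent_ready_sync(): ...  # delete

# ✅ Use async/await directly (async views need Flask's async extra: pip install "flask[async]")
@app.route("/api/chat", methods=["POST"])
async def chat_endpoint():
    agent_result = await _agent_instance.chat(...)  # call it directly
```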

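📋 Summary of recommended changes

  1. Change first: make every StateGraph node async.
  2. Unify Redis: use redis.asyncio everywhere.
  3. Remove wrappers: delete every run_async_safely-style helper.
  4. Flask routes: make the core API routes async.
  5. Verify: after the changes, test whether the "Event loop is closed" error is gone.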

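After applying this fix you can start the service directly with python api.py. One caveat remains: Flask runs every async view in a freshly created event loop in a worker thread. A global agent whose Redis connections and AsyncRedisSaver were created during an earlier request can therefore still hit "Event loop is closed". If that happens, you have two options. You can serve the app with an ASGI framework such as Quart, where one loop handles all requests. Alternatively, you can keep a single long-lived event loop in a background thread and submit coroutines to it from sync views.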
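The following stdlib sketch illustrates the event-loop caveat. `LoopBoundClient` is hypothetical: like a pooled async connection, it remembers the event loop it first ran on. Calling it from two separate `asyncio.run()` calls fails with "Event loop is closed"; this is roughly what happens when each request gets its own loop. Submitting every call to one long-lived loop in a background thread via `asyncio.run_coroutine_threadsafe` works. The sketch shows the failure mode only and is not a drop-in replacement for your API layer.

```python
import asyncio
import threading


class LoopBoundClient:
    """Hypothetical async client that binds to the loop it first runs on."""

    def __init__(self) -> None:
        self._loop = None

    async def ping(self) -> str:
        if self._loop is None:
            self._loop = asyncio.get_running_loop()
        fut = self._loop.create_future()
        # call_soon raises RuntimeError("Event loop is closed") on a closed loop.
        self._loop.call_soon(fut.set_result, "PONG")
        return await fut


def per_request_loops(client: LoopBoundClient) -> str:
    """Each 'request' gets a fresh loop, like an asyncio.run() wrapper."""
    asyncio.run(client.ping())        # first request: binds to loop #1
    try:
        asyncio.run(client.ping())    # second request: loop #1 is already closed
        return "ok"
    except RuntimeError as e:
        return str(e)


def shared_background_loop(client: LoopBoundClient, n_requests: int = 3) -> list:
    """One long-lived loop in a daemon thread; sync code submits coroutines to it."""
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()
    try:
        return [
            asyncio.run_coroutine_threadsafe(client.ping(), loop).result(timeout=5)
            for _ in range(n_requests)
        ]
    finally:
        loop.call_soon_threadsafe(loop.stop)
        thread.join()
        loop.close()


if __name__ == "__main__":
    print(per_request_loops(LoopBoundClient()))       # Event loop is closed
    print(shared_background_loop(LoopBoundClient()))  # ['PONG', 'PONG', 'PONG']
```

Real redis.asyncio connections can fail in several ways when reused across loops. The takeaway is the same: create async clients on the loop that will keep using them.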

""" CustomReactAgent完全异步化修复方案 解决StateGraph节点同步/异步混用问题,统一Redis客户端使用 """

============= 修复后的 agent.py =============

import logging import json import pandas as pd from typing import List, Optional, Dict, Any, Tuple from contextlib import AsyncExitStack

from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage from langgraph.graph import StateGraph, END from langgraph.prebuilt import ToolNode import redis.asyncio as redis # 🔥 统一使用异步Redis try:

from langgraph.checkpoint.redis import AsyncRedisSaver

except ImportError:

AsyncRedisSaver = None

从新模块导入配置、状态和工具

try:

from . import config
from .state import AgentState
from .sql_tools import sql_tools

except ImportError:

import config
from state import AgentState
from sql_tools import sql_tools

logger = logging.getLogger(name)

class CustomReactAgent:

"""
完全异步化的 CustomReactAgent
所有节点方法都是异步的,统一使用异步Redis客户端
"""
def __init__(self):
    """私有构造函数,请使用 create() 类方法来创建实例。"""
    self.llm = None
    self.tools = None
    self.agent_executor = None
    self.checkpointer = None
    self._exit_stack = None
    self.redis_client = None  # 🔥 添加Redis客户端引用

@classmethod
async def create(cls):
    """异步工厂方法,创建并初始化 CustomReactAgent 实例。"""
    instance = cls()
    await instance._async_init()
    return instance

async def _async_init(self):
    """异步初始化所有组件。"""
    logger.info("🚀 开始初始化 CustomReactAgent...")

    # 1. 初始化异步Redis客户端
    self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
    try:
        await self.redis_client.ping()
        logger.info(f"   ✅ Redis连接成功: {config.REDIS_URL}")
    except Exception as e:
        logger.error(f"   ❌ Redis连接失败: {e}")
        raise

    # 2. 初始化 LLM
    self.llm = ChatOpenAI(
        api_key=config.QWEN_API_KEY,
        base_url=config.QWEN_BASE_URL,
        model=config.QWEN_MODEL,
        temperature=0.1,
        timeout=config.NETWORK_TIMEOUT,
        max_retries=config.MAX_RETRIES,
        extra_body={
            "enable_thinking": False,
            "misc": {
                "ensure_ascii": False
            }
        }
    )
    logger.info(f"   LLM 已初始化,模型: {config.QWEN_MODEL}")

    # 3. 绑定工具
    self.tools = sql_tools
    self.llm_with_tools = self.llm.bind_tools(self.tools)
    logger.info(f"   已绑定 {len(self.tools)} 个工具。")

    # 4. 初始化 Redis Checkpointer
    if config.REDIS_ENABLED and AsyncRedisSaver is not None:
        try:
            self._exit_stack = AsyncExitStack()
            checkpointer_manager = AsyncRedisSaver.from_conn_string(config.REDIS_URL)
            self.checkpointer = await self._exit_stack.enter_async_context(checkpointer_manager)
            await self.checkpointer.asetup()
            logger.info(f"   AsyncRedisSaver 持久化已启用: {config.REDIS_URL}")
        except Exception as e:
            logger.error(f"   ❌ RedisSaver 初始化失败: {e}", exc_info=True)
            if self._exit_stack:
                await self._exit_stack.aclose()
            self.checkpointer = None
    else:
        logger.warning("   Redis 持久化功能已禁用。")

    # 5. 构建 StateGraph
    self.agent_executor = self._create_graph()
    logger.info("   StateGraph 已构建并编译。")
    logger.info("✅ CustomReactAgent 初始化完成。")

async def close(self):
    """清理资源,关闭所有连接。"""
    if self._exit_stack:
        await self._exit_stack.aclose()
        self._exit_stack = None
        self.checkpointer = None
        logger.info("✅ RedisSaver 资源已通过 AsyncExitStack 释放。")

    if self.redis_client:
        await self.redis_client.aclose()
        logger.info("✅ Redis客户端已关闭。")

def _create_graph(self):
    """定义并编译最终的、正确的 StateGraph 结构。"""
    builder = StateGraph(AgentState)

    # 🔥 关键修改:所有节点都是异步的
    builder.add_node("agent", self._async_agent_node)
    builder.add_node("prepare_tool_input", self._async_prepare_tool_input_node)
    builder.add_node("tools", ToolNode(self.tools))
    builder.add_node("update_state_after_tool", self._async_update_state_after_tool_node)
    builder.add_node("format_final_response", self._async_format_final_response_node)

    # 建立正确的边连接
    builder.set_entry_point("agent")
    builder.add_conditional_edges(
        "agent",
        self._async_should_continue,  # 🔥 异步条件判断
        {
            "continue": "prepare_tool_input",
            "end": "format_final_response"
        }
    )
    builder.add_edge("prepare_tool_input", "tools")
    builder.add_edge("tools", "update_state_after_tool")
    builder.add_edge("update_state_after_tool", "agent")
    builder.add_edge("format_final_response", END)

    return builder.compile(checkpointer=self.checkpointer)

async def _async_should_continue(self, state: AgentState) -> str:
    """🔥 异步版本:判断是继续调用工具还是结束。"""
    last_message = state["messages"][-1]
    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        return "continue"
    return "end"

async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
    """🔥 异步版本:Agent 节点,使用异步LLM调用。"""
    logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']}")

    messages_for_llm = list(state["messages"])
    if state.get("suggested_next_step"):
        instruction = f"提示:建议下一步使用工具 '{state['suggested_next_step']}'。"
        messages_for_llm.append(SystemMessage(content=instruction))

    # 🔥 关键修改:使用异步LLM调用
    import time
    max_retries = config.MAX_RETRIES
    for attempt in range(max_retries):
        try:
            # 使用异步调用
            response = await self.llm_with_tools.ainvoke(messages_for_llm)
            logger.info(f"   ✅ 异步LLM调用成功")
            return {"messages": [response]}

        except Exception as e:
            error_msg = str(e)
            logger.warning(f"   ⚠️ 异步LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {error_msg}")

            if any(keyword in error_msg for keyword in [
                "Connection error", "APIConnectionError", "ConnectError", 
                "timeout", "远程主机强迫关闭", "网络连接"
            ]):
                if attempt < max_retries - 1:
                    wait_time = config.RETRY_BASE_DELAY ** attempt
                    logger.info(f"   🔄 网络错误,{wait_time}秒后重试...")
                    await asyncio.sleep(wait_time)  # 🔥 使用async sleep
                    continue
                else:
                    logger.error(f"   ❌ 网络连接持续失败,返回降级回答")
                    sql_data = await self._async_extract_latest_sql_data(state["messages"])
                    if sql_data:
                        fallback_content = "抱歉,由于网络连接问题,无法生成完整的文字总结。不过查询已成功执行,结果如下:\n\n" + sql_data
                    else:
                        fallback_content = "抱歉,由于网络连接问题,无法完成此次请求。请稍后重试或检查网络连接。"

                    fallback_response = AIMessage(content=fallback_content)
                    return {"messages": [fallback_response]}
            else:
                logger.error(f"   ❌ LLM调用出现非网络错误: {error_msg}")
                raise e

async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
    """🔥 异步版本:信息组装节点。"""
    logger.info(f"🛠️ [Async Node] prepare_tool_input - Thread: {state['thread_id']}")

    last_message = state["messages"][-1]
    if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
        return {"messages": [last_message]}

    new_tool_calls = []
    for tool_call in last_message.tool_calls:
        if tool_call["name"] == "generate_sql":
            logger.info("   检测到 generate_sql 调用,注入历史消息。")
            modified_args = tool_call["args"].copy()

            clean_history = []
            messages_except_current = state["messages"][:-1]

            for msg in messages_except_current:
                if isinstance(msg, HumanMessage):
                    clean_history.append({
                        "type": "human",
                        "content": msg.content
                    })
                elif isinstance(msg, AIMessage):
                    if msg.content and "[Formatted Output]" in msg.content:
                        clean_content = msg.content.replace("[Formatted Output]\n", "")
                        clean_history.append({
                            "type": "ai",
                            "content": clean_content
                        })

            modified_args["history_messages"] = clean_history
            logger.info(f"   注入了 {len(clean_history)} 条过滤后的历史消息")

            new_tool_calls.append({
                "name": tool_call["name"],
                "args": modified_args,
                "id": tool_call["id"],
            })
        else:
            new_tool_calls.append(tool_call)

    last_message.tool_calls = new_tool_calls
    return {"messages": [last_message]}

async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
    """🔥 异步版本:在工具执行后,更新 suggested_next_step。"""
    logger.info(f"📝 [Async Node] update_state_after_tool - Thread: {state['thread_id']}")

    last_tool_message = state['messages'][-1]
    tool_name = last_tool_message.name
    tool_output = last_tool_message.content
    next_step = None

    if tool_name == 'generate_sql':
        if "失败" in tool_output or "无法生成" in tool_output:
            next_step = 'answer_with_common_sense'
        else:
            next_step = 'valid_sql'
    elif tool_name == 'valid_sql':
        if "失败" in tool_output:
            next_step = 'analyze_validation_error'
        else:
            next_step = 'run_sql'
    elif tool_name == 'run_sql':
        next_step = 'summarize_final_answer'

    logger.info(f"   Tool '{tool_name}' executed. Suggested next step: {next_step}")
    return {"suggested_next_step": next_step}

async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
    """🔥 异步版本:最终输出格式化节点。"""
    logger.info(f"🎨 [Async Node] format_final_response - Thread: {state['thread_id']}")

    last_message = state['messages'][-1]
    last_message.content = f"[Formatted Output]\n{last_message.content}"

    # 生成API格式的数据
    api_data = await self._async_generate_api_data(state)

    return {
        "messages": [last_message],
        "api_data": api_data
    }

async def _async_generate_api_data(self, state: AgentState) -> Dict[str, Any]:
    """🔥 异步版本:生成API格式的数据结构"""
    logger.info("📊 异步生成API格式数据...")

    last_message = state['messages'][-1]
    response_content = last_message.content

    if response_content.startswith("[Formatted Output]\n"):
        response_content = response_content.replace("[Formatted Output]\n", "")

    api_data = {
        "response": response_content
    }

    sql_info = await self._async_extract_sql_and_data(state['messages'])
    if sql_info['sql']:
        api_data["sql"] = sql_info['sql']
    if sql_info['records']:
        api_data["records"] = sql_info['records']

    api_data["react_agent_meta"] = await self._async_collect_agent_metadata(state)

    logger.info(f"   API数据生成完成,包含字段: {list(api_data.keys())}")
    return api_data

async def _async_extract_sql_and_data(self, messages: List[BaseMessage]) -> Dict[str, Any]:
    """🔥 异步版本:从消息历史中提取SQL和数据记录"""
    result = {"sql": None, "records": None}

    last_human_index = -1
    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], HumanMessage):
            last_human_index = i
            break

    if last_human_index == -1:
        return result

    current_conversation = messages[last_human_index:]
    sql_query = None
    sql_data = None

    for msg in current_conversation:
        if isinstance(msg, ToolMessage):
            if msg.name == 'generate_sql':
                content = msg.content
                if content and not any(keyword in content for keyword in ["失败", "无法生成", "Database query failed"]):
                    sql_query = content.strip()
            elif msg.name == 'run_sql':
                try:
                    import json
                    parsed_data = json.loads(msg.content)
                    if isinstance(parsed_data, list) and len(parsed_data) > 0:
                        columns = list(parsed_data[0].keys()) if parsed_data else []
                        sql_data = {
                            "columns": columns,
                            "rows": parsed_data,
                            "total_row_count": len(parsed_data),
                            "is_limited": False
                        }
                except (json.JSONDecodeError, Exception) as e:
                    logger.warning(f"   解析SQL结果失败: {e}")

    if sql_query:
        result["sql"] = sql_query
    if sql_data:
        result["records"] = sql_data

    return result

async def _async_collect_agent_metadata(self, state: AgentState) -> Dict[str, Any]:
    """🔥 异步版本:收集Agent元数据"""
    messages = state['messages']

    tools_used = []
    sql_execution_count = 0
    context_injected = False
    conversation_rounds = sum(1 for msg in messages if isinstance(msg, HumanMessage))

    for msg in messages:
        if isinstance(msg, ToolMessage):
            if msg.name not in tools_used:
                tools_used.append(msg.name)
            if msg.name == 'run_sql':
                sql_execution_count += 1
        elif isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls:
            for tool_call in msg.tool_calls:
                tool_name = tool_call.get('name')
                if tool_name and tool_name not in tools_used:
                    tools_used.append(tool_name)

                if (tool_name == 'generate_sql' and 
                    tool_call.get('args', {}).get('history_messages')):
                    context_injected = True

    execution_path = ["agent"]
    if tools_used:
        execution_path.extend(["prepare_tool_input", "tools"])
    execution_path.append("format_final_response")

    return {
        "thread_id": state['thread_id'],
        "conversation_rounds": conversation_rounds,
        "tools_used": tools_used,
        "execution_path": execution_path,
        "total_messages": len(messages),
        "sql_execution_count": sql_execution_count,
        "context_injected": context_injected,
        "agent_version": "custom_react_v1_async"
    }

async def _async_extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
    """🔥 异步版本:提取最新的SQL执行结果"""
    logger.info("🔍 异步提取最新的SQL执行结果...")

    last_human_index = -1
    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], HumanMessage):
            last_human_index = i
            break

    if last_human_index == -1:
        logger.info("   未找到用户消息,跳过SQL数据提取")
        return None

    current_conversation = messages[last_human_index:]
    logger.info(f"   当前对话轮次包含 {len(current_conversation)} 条消息")

    for msg in reversed(current_conversation):
        if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
            logger.info(f"   找到当前对话轮次的run_sql结果: {msg.content[:100]}...")

            try:
                parsed_data = json.loads(msg.content)
                formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
                logger.info(f"   已转换Unicode转义序列为中文字符")
                return formatted_content
            except json.JSONDecodeError:
                logger.warning(f"   SQL结果不是有效JSON格式,返回原始内容")
                return msg.content

    logger.info("   当前对话轮次中未找到run_sql执行结果")
    return None

async def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
    """🔥 完全异步的聊天处理方法"""
    if not thread_id:
        now = pd.Timestamp.now()
        milliseconds = int(now.microsecond / 1000)
        thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
        logger.info(f"🆕 新建会话,Thread ID: {thread_id}")

    config = {
        "configurable": {
            "thread_id": thread_id,
        }
    }

    inputs = {
        "messages": [HumanMessage(content=message)],
        "user_id": user_id,
        "thread_id": thread_id,
        "suggested_next_step": None,
    }

    try:
        # 🔥 使用异步调用
        final_state = await self.agent_executor.ainvoke(inputs, config)
        answer = final_state["messages"][-1].content

        sql_data = await self._async_extract_latest_sql_data(final_state["messages"])

        logger.info(f"✅ 异步处理完成 - Final Answer: '{answer}'")

        result = {
            "success": True, 
            "answer": answer, 
            "thread_id": thread_id
        }

        if sql_data:
            result["sql_data"] = sql_data
            logger.info("   📊 已包含SQL原始数据")

        if "api_data" in final_state:
            result["api_data"] = final_state["api_data"]
            logger.info("   🔌 已包含API格式数据")

        return result

    except Exception as e:
        logger.error(f"❌ 异步处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
        return {"success": False, "error": str(e), "thread_id": thread_id}

async def get_conversation_history(self, thread_id: str) -> List[Dict[str, Any]]:
    """🔥 完全异步的对话历史获取"""
    if not self.checkpointer:
        return []

    config = {"configurable": {"thread_id": thread_id}}
    try:
        conversation_state = await self.checkpointer.aget(config)
    except RuntimeError as e:
        if "Event loop is closed" in str(e):
            logger.warning(f"⚠️ Event loop已关闭,返回空结果: {thread_id}")
            return []
        else:
            raise

    if not conversation_state:
        return []

    history = []
    messages = conversation_state.get('channel_values', {}).get('messages', [])
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "human"
        elif isinstance(msg, ToolMessage):
            role = "tool"
        else:
            role = "ai"

        history.append({
            "type": role,
            "content": msg.content,
            "tool_calls": getattr(msg, 'tool_calls', None)
        })
    return history 

async def get_user_recent_conversations(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
    """🔥 完全异步的用户对话列表获取"""
    if not self.checkpointer:
        return []

    try:
        # 🔥 使用统一的异步Redis客户端
        pattern = f"checkpoint:{user_id}:*"
        logger.info(f"🔍 异步扫描模式: {pattern}")

        user_threads = {}
        cursor = 0

        while True:
            cursor, keys = await self.redis_client.scan(
                cursor=cursor,
                match=pattern,
                count=1000
            )

            for key in keys:
                try:
                    key_str = key.decode() if isinstance(key, bytes) else key
                    parts = key_str.split(':')

                    if len(parts) >= 4:
                        thread_id = f"{parts[1]}:{parts[2]}"
                        timestamp = parts[2]

                        if thread_id not in user_threads:
                            user_threads[thread_id] = {
                                "thread_id": thread_id,
                                "timestamp": timestamp,
                                "latest_key": key_str
                            }
                        else:
                            if len(parts) > 4 and parts[4] > user_threads[thread_id]["latest_key"].split(':')[4]:
                                user_threads[thread_id]["latest_key"] = key_str

                except Exception as e:
                    logger.warning(f"解析key {key} 失败: {e}")
                    continue

            if cursor == 0:
                break

        # 按时间戳排序
        sorted_threads = sorted(
            user_threads.values(),
            key=lambda x: x["timestamp"],
            reverse=True
        )[:limit]

        # 获取每个thread的详细信息
        conversations = []
        for thread_info in sorted_threads:
            try:
                thread_id = thread_info["thread_id"]
                thread_config = {"configurable": {"thread_id": thread_id}}

                try:
                    state = await self.checkpointer.aget(thread_config)
                except RuntimeError as e:
                    if "Event loop is closed" in str(e):
                        logger.warning(f"⚠️ Event loop已关闭,跳过thread: {thread_id}")
                        continue
                    else:
                        raise

                if state and state.get('channel_values', {}).get('messages'):
                    messages = state['channel_values']['messages']
                    preview = self._generate_conversation_preview(messages)

                    conversations.append({
                        "thread_id": thread_id,
                        "user_id": user_id,
                        "timestamp": thread_info["timestamp"],
                        "message_count": len(messages),
                        "last_message": messages[-1].content if messages else None,
                        "last_updated": state.get('created_at'),
                        "conversation_preview": preview,
                        "formatted_time": self._format_timestamp(thread_info["timestamp"])
                    })

            except Exception as e:
                logger.error(f"获取thread {thread_info['thread_id']} 详情失败: {e}")
                continue

        logger.info(f"✅ 异步找到用户 {user_id} 的 {len(conversations)} 个对话")
        return conversations

    except Exception as e:
        logger.error(f"❌ 异步获取用户 {user_id} 对话列表失败: {e}")
        return []

def _generate_conversation_preview(self, messages: List[BaseMessage]) -> str:
    """生成对话预览(保持同步,因为是纯计算)"""
    if not messages:
        return "空对话"

    for msg in messages:
        if isinstance(msg, HumanMessage):
            content = str(msg.content)
            return content[:50] + "..." if len(content) > 50 else content

    return "系统消息"

def _format_timestamp(self, timestamp: str) -> str:
    """格式化时间戳为可读格式(保持同步,因为是纯计算)"""
    try:
        if len(timestamp) >= 14:
            year = timestamp[:4]
            month = timestamp[4:6]
            day = timestamp[6:8]
            hour = timestamp[8:10]
            minute = timestamp[10:12]
            second = timestamp[12:14]
            return f"{year}-{month}-{day} {hour}:{minute}:{second}"
    except Exception:
        pass
    return timestamp
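`chat()` builds new thread IDs as `{user_id}:{YYYYmmddHHMMSS}{milliseconds}`, and `_format_timestamp` turns the digit suffix back into a readable time. The stdlib sketch below reproduces both with `datetime` instead of pandas. `make_thread_id` and `format_timestamp` are illustrative names, not part of the original code. The sketch can be handy for unit-testing the ownership check `thread_id.startswith(f"{user_id}:")` in the API layer.

```python
from datetime import datetime
from typing import Optional


def make_thread_id(user_id: str, now: Optional[datetime] = None) -> str:
    """Same layout as CustomReactAgent.chat(): user_id, ':', 14-digit time, 3-digit ms."""
    now = now or datetime.now()
    milliseconds = now.microsecond // 1000
    return f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"


def format_timestamp(timestamp: str) -> str:
    """Same rule as _format_timestamp: format the first 14 digits, else return input as-is."""
    if len(timestamp) >= 14:
        return (f"{timestamp[:4]}-{timestamp[4:6]}-{timestamp[6:8]} "
                f"{timestamp[8:10]}:{timestamp[10:12]}:{timestamp[12:14]}")
    return timestamp


if __name__ == "__main__":
    tid = make_thread_id("alice", datetime(2025, 1, 2, 3, 4, 5, 678000))
    print(tid)                                     # alice:20250102030405678
    print(format_timestamp(tid.split(":", 1)[1]))  # 2025-01-02 03:04:05
```

Because the key parsing in `get_user_recent_conversations` splits on `:`, a `user_id` that itself contains a colon would break the thread grouping.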

============= 修复后的 api.py 关键部分 =============

""" 修复后的 api.py - 统一使用异步Redis客户端,移除复杂的事件循环管理 """

import asyncio import logging import os from datetime import datetime from typing import Optional, Dict, Any

from flask import Flask, request, jsonify import redis.asyncio as redis # 🔥 统一使用异步Redis

try:

from .agent import CustomReactAgent

except ImportError:

from agent import CustomReactAgent

logging.basicConfig(level=logging.INFO) logger = logging.getLogger(name)

全局Agent实例

_agent_instance: Optional[CustomReactAgent] = None _redis_client: Optional[redis.Redis] = None

def validate_request_data(data: Dict[str, Any]) -> Dict[str, Any]:

"""验证请求数据(保持不变)"""
errors = []

question = data.get('question', '')
if not question or not question.strip():
    errors.append('问题不能为空')
elif len(question) > 2000:
    errors.append('问题长度不能超过2000字符')

user_id = data.get('user_id', 'guest')
if user_id and len(user_id) > 50:
    errors.append('用户ID长度不能超过50字符')

if errors:
    raise ValueError('; '.join(errors))

return {
    'question': question.strip(),
    'user_id': user_id or 'guest',
    'thread_id': data.get('thread_id')
}

async def initialize_agent():

"""🔥 异步初始化Agent"""
global _agent_instance, _redis_client

if _agent_instance is None:
    logger.info("🚀 正在异步初始化 Custom React Agent...")
    try:
        os.environ['REDIS_URL'] = 'redis://localhost:6379'

        # 初始化共享的Redis客户端
        _redis_client = redis.from_url('redis://localhost:6379', decode_responses=True)
        await _redis_client.ping()

        _agent_instance = await CustomReactAgent.create()
        logger.info("✅ Agent 异步初始化完成")
    except Exception as e:
        logger.error(f"❌ Agent 异步初始化失败: {e}")
        raise

async def ensure_agent_ready():

"""🔥 异步确保Agent实例可用"""
global _agent_instance

if _agent_instance is None:
    await initialize_agent()

try:
    test_result = await _agent_instance.get_user_recent_conversations("__test__", 1)
    return True
except Exception as e:
    logger.warning(f"⚠️ Agent实例不可用: {e}")
    _agent_instance = None
    await initialize_agent()
    return True

async def cleanup_agent():

"""🔥 异步清理Agent资源"""
global _agent_instance, _redis_client

if _agent_instance:
    await _agent_instance.close()
    logger.info("✅ Agent 资源已异步清理")
    _agent_instance = None

if _redis_client:
    await _redis_client.aclose()
    logger.info("✅ Redis客户端已异步关闭")
    _redis_client = None

创建Flask应用

app = Flask(name)

🔥 移除所有同步包装函数:run_async_safely, ensure_agent_ready_sync

@app.route("/") def root():

"""健康检查端点(保持同步)"""
return jsonify({"message": "Custom React Agent API 服务正在运行"})

@app.route('/health', methods=['GET']) def health_check():

"""健康检查端点(保持同步)"""
try:
    health_status = {
        "status": "healthy",
        "agent_initialized": _agent_instance is not None,
        "timestamp": datetime.now().isoformat()
    }
    return jsonify(health_status), 200
except Exception as e:
    logger.error(f"健康检查失败: {e}")
    return jsonify({"status": "unhealthy", "error": str(e)}), 500

@app.route("/api/chat", methods=["POST"]) async def chat_endpoint():

"""🔥 异步智能问答接口"""
global _agent_instance

# 确保Agent已初始化
if not await ensure_agent_ready():
    return jsonify({
        "code": 503,
        "message": "服务未就绪",
        "success": False,
        "error": "Agent 初始化失败"
    }), 503

try:
    data = request.get_json()
    if not data:
        return jsonify({
            "code": 400,
            "message": "请求参数错误",
            "success": False,
            "error": "请求体不能为空"
        }), 400

    validated_data = validate_request_data(data)

    logger.info(f"📨 收到请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")

    # 🔥 直接调用异步方法,不需要事件循环包装
    agent_result = await _agent_instance.chat(
        message=validated_data['question'],
        user_id=validated_data['user_id'],
        thread_id=validated_data['thread_id']
    )

    if not agent_result.get("success", False):
        error_msg = agent_result.get("error", "Agent处理失败")
        logger.error(f"❌ Agent处理失败: {error_msg}")

        return jsonify({
            "code": 500,
            "message": "处理失败",
            "success": False,
            "error": error_msg,
            "data": {
                "react_agent_meta": {
                    "thread_id": agent_result.get("thread_id"),
                    "agent_version": "custom_react_v1_async",
                    "execution_path": ["error"]
                },
                "timestamp": datetime.now().isoformat()
            }
        }), 500

    api_data = agent_result.get("api_data", {})
    response_data = {
        **api_data,
        "timestamp": datetime.now().isoformat()
    }

    logger.info(f"✅ 异步请求处理成功 - Thread: {api_data.get('react_agent_meta', {}).get('thread_id')}")

    return jsonify({
        "code": 200,
        "message": "操作成功",
        "success": True,
        "data": response_data
    })

except ValueError as e:
    logger.warning(f"⚠️ 参数验证失败: {e}")
    return jsonify({
        "code": 400,
        "message": "请求参数错误",
        "success": False,
        "error": str(e)
    }), 400

except Exception as e:
    logger.error(f"❌ 未预期的错误: {e}", exc_info=True)
    return jsonify({
        "code": 500,
        "message": "服务器内部错误", 
        "success": False,
        "error": "系统异常,请稍后重试"
    }), 500

@app.route('/api/v0/react/users/<user_id>/conversations', methods=['GET'])
async def get_user_conversations(user_id: str):
    """🔥 Asynchronously fetch the user's conversation list"""
    global _agent_instance

    try:
        limit = request.args.get('limit', 10, type=int)
        limit = max(1, min(limit, 50))  # clamp to the range [1, 50]

        logger.info(f"📋 Async fetch of conversation list for user {user_id}, limit {limit}")

        if not await ensure_agent_ready():
            return jsonify({
                "success": False,
                "error": "Agent not ready",
                "timestamp": datetime.now().isoformat()
            }), 503

        # 🔥 Call the async method directly
        conversations = await _agent_instance.get_user_recent_conversations(user_id, limit)

        return jsonify({
            "success": True,
            "data": {
                "user_id": user_id,
                "conversations": conversations,
                "total_count": len(conversations),
                "limit": limit
            },
            "timestamp": datetime.now().isoformat()
        }), 200

    except Exception as e:
        logger.error(f"❌ Async fetch of conversation list for user {user_id} failed: {e}")
        return jsonify({
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }), 500

@app.route('/api/v0/react/users/<user_id>/conversations/<thread_id>', methods=['GET'])
async def get_user_conversation_detail(user_id: str, thread_id: str):
    """🔥 Asynchronously fetch the full history of one conversation"""
    global _agent_instance

    try:
        # Reject thread IDs that are not prefixed with this user's ID
        if not thread_id.startswith(f"{user_id}:"):
            return jsonify({
                "success": False,
                "error": f"Thread ID {thread_id} does not belong to user {user_id}",
                "timestamp": datetime.now().isoformat()
            }), 400

        logger.info(f"📖 Async fetch of details for conversation {thread_id} of user {user_id}")

        if not await ensure_agent_ready():
            return jsonify({
                "success": False,
                "error": "Agent not ready",
                "timestamp": datetime.now().isoformat()
            }), 503

        # 🔥 Call the async method directly
        history = await _agent_instance.get_conversation_history(thread_id)
        logger.info(f"✅ Async fetch of conversation history succeeded, message count: {len(history)}")

        if not history:
            return jsonify({
                "success": False,
                "error": f"Conversation {thread_id} not found",
                "timestamp": datetime.now().isoformat()
            }), 404

        return jsonify({
            "success": True,
            "data": {
                "user_id": user_id,
                "thread_id": thread_id,
                "message_count": len(history),
                "messages": history
            },
            "timestamp": datetime.now().isoformat()
        }), 200

    except Exception as e:
        # exc_info=True appends the full traceback to the log record
        logger.error(f"❌ Async fetch of conversation {thread_id} details failed: {e}", exc_info=True)
        return jsonify({
            "success": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }), 500
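The ownership check in `get_user_conversation_detail` is a prefix match. The code above does not say whether user IDs can contain `:`. If they can, `startswith(f"{user_id}:")` would also accept another user's thread. For example, user `alice` would pass the check for thread `alice:bob:1700000000`, which belongs to user `alice:bob`. The sketch below does an exact comparison instead. It assumes thread IDs have the form `{user_id}:{timestamp}`, which is the form the Redis helper rebuilds from checkpoint keys, and that the timestamp part never contains `:`. `split_thread_id` and `thread_belongs_to` are hypothetical helpers, not part of the original code.

```python
from typing import Tuple


def split_thread_id(thread_id: str) -> Tuple[str, str]:
    """Split a "{user_id}:{timestamp}" thread ID on its LAST colon.

    Hypothetical helper: assumes the timestamp part never contains ':'.
    """
    owner, sep, timestamp = thread_id.rpartition(":")
    if not sep or not owner or not timestamp:
        raise ValueError(f"malformed thread_id: {thread_id!r}")
    return owner, timestamp


def thread_belongs_to(thread_id: str, user_id: str) -> bool:
    """Compare the owner part exactly instead of using a prefix match."""
    try:
        owner, _ = split_thread_id(thread_id)
    except ValueError:
        return False
    return owner == user_id


if __name__ == "__main__":
    print(thread_belongs_to("alice:1700000000", "alice"))      # True
    print(thread_belongs_to("alice:bob:1700000000", "alice"))  # False
    print("alice:bob:1700000000".startswith("alice:"))         # True (the prefix check)
```

Splitting on the last colon keeps the thread ID format unchanged. In the route, `thread_belongs_to(thread_id, user_id)` would take the place of the `startswith` test.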

# 🔥 Async Redis API (only if direct Redis access is still needed)

import json

from redis.asyncio import Redis  # async client: the sync redis.Redis returns plain values that cannot be awaited

async def get_user_conversations_async(user_id: str, limit: int = 10):
    """🔥 Fully async Redis query function"""
    global _redis_client

    try:
        if not _redis_client:
            _redis_client = Redis.from_url('redis://localhost:6379', decode_responses=True)
            await _redis_client.ping()

        pattern = f"checkpoint:{user_id}:*"
        logger.info(f"🔍 Async scan pattern: {pattern}")

        keys = []
        cursor = 0
        while True:
            cursor, batch = await _redis_client.scan(cursor=cursor, match=pattern, count=1000)
            keys.extend(batch)
            if cursor == 0:
                break

        logger.info(f"📋 Async scan found {len(keys)} keys")

        # Parsing logic (same as before, but with async Redis operations).
        # Key layout: checkpoint:{user_id}:{timestamp}:..., so thread_id = "{user_id}:{timestamp}"
        thread_data = {}
        for key in keys:
            try:
                parts = key.split(':')
                if len(parts) >= 4:
                    thread_id = f"{parts[1]}:{parts[2]}"
                    timestamp = parts[2]

                    if thread_id not in thread_data:
                        thread_data[thread_id] = {
                            "thread_id": thread_id,
                            "timestamp": timestamp,
                            "keys": []
                        }
                    thread_data[thread_id]["keys"].append(key)
            except Exception as e:
                logger.warning(f"Failed to parse key {key}: {e}")
                continue

        sorted_threads = sorted(
            thread_data.values(),
            key=lambda x: x["timestamp"],
            reverse=True
        )[:limit]

        conversations = []
        for thread_info in sorted_threads:
            try:
                thread_id = thread_info["thread_id"]
                latest_key = max(thread_info["keys"])

                # 🔥 Read with the async Redis client
                key_type = await _redis_client.type(latest_key)

                data = None
                if key_type == 'string':
                    data = await _redis_client.get(latest_key)
                elif key_type == 'ReJSON-RL':
                    try:
                        data = await _redis_client.execute_command('JSON.GET', latest_key)
                    except Exception as json_error:
                        logger.error(f"❌ Async JSON.GET failed: {json_error}")
                        continue

                if data:
                    try:
                        checkpoint_data = json.loads(data)

                        messages = []
                        if 'checkpoint' in checkpoint_data:
                            checkpoint = checkpoint_data['checkpoint']
                            if isinstance(checkpoint, dict) and 'channel_values' in checkpoint:
                                channel_values = checkpoint['channel_values']
                                if isinstance(channel_values, dict) and 'messages' in channel_values:
                                    messages = channel_values['messages']

                        # Preview = first HumanMessage in LangChain's serialized form
                        preview = "Empty conversation"
                        if messages:
                            for msg in messages:
                                if isinstance(msg, dict):
                                    if (msg.get('lc') == 1 and
                                        msg.get('type') == 'constructor' and
                                        'id' in msg and
                                        isinstance(msg['id'], list) and
                                        len(msg['id']) >= 4 and
                                        msg['id'][3] == 'HumanMessage' and
                                        'kwargs' in msg):

                                        kwargs = msg['kwargs']
                                        if kwargs.get('type') == 'human' and 'content' in kwargs:
                                            content = str(kwargs['content'])
                                            preview = content[:50] + "..." if len(content) > 50 else content
                                            break

                        conversations.append({
                            "thread_id": thread_id,
                            "user_id": user_id,
                            "timestamp": thread_info["timestamp"],
                            "message_count": len(messages),
                            "conversation_preview": preview
                        })

                    except json.JSONDecodeError:
                        logger.error(f"❌ Async JSON parse failed for key {latest_key}")
                        continue

            except Exception as e:
                logger.error(f"Async processing of thread {thread_info['thread_id']} failed: {e}")
                continue

        logger.info(f"✅ Async query returned {len(conversations)} conversations")
        return conversations

    except Exception as e:
        logger.error(f"❌ Async Redis query failed: {e}")
        return []
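The key grouping and preview extraction in `get_user_conversations_async` are pure data transformations. They can be pulled out and tested without a running Redis server. The sketch below is minimal. It assumes the key layout `checkpoint:{user_id}:{timestamp}:...` and the LangChain-serialized message shape (`lc`, `type`, `id`, `kwargs`) that the function above already relies on. `group_keys_by_thread` and `first_human_preview` are hypothetical names, and the sample keys and checkpoint are made up for illustration.

```python
import json
from typing import Any, Dict, List


def group_keys_by_thread(keys: List[str]) -> Dict[str, Dict[str, Any]]:
    """Group checkpoint keys by thread ID ("{user_id}:{timestamp}")."""
    threads: Dict[str, Dict[str, Any]] = {}
    for key in keys:
        parts = key.split(":")
        if len(parts) < 4:
            continue  # not in the assumed checkpoint:{user_id}:{timestamp}:... layout
        thread_id = f"{parts[1]}:{parts[2]}"
        entry = threads.setdefault(
            thread_id, {"thread_id": thread_id, "timestamp": parts[2], "keys": []}
        )
        entry["keys"].append(key)
    return threads


def first_human_preview(raw: str, max_len: int = 50) -> str:
    """Return the first human message of a serialized checkpoint, truncated."""
    data = json.loads(raw)
    checkpoint = data.get("checkpoint")
    channel_values = checkpoint.get("channel_values") if isinstance(checkpoint, dict) else None
    messages = channel_values.get("messages", []) if isinstance(channel_values, dict) else []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        ids = msg.get("id")
        kwargs = msg.get("kwargs")
        is_human = (
            msg.get("lc") == 1
            and msg.get("type") == "constructor"
            and isinstance(ids, list)
            and len(ids) >= 4
            and ids[3] == "HumanMessage"
            and isinstance(kwargs, dict)
            and kwargs.get("type") == "human"
            and "content" in kwargs
        )
        if is_human:
            content = str(kwargs["content"])
            return content[:max_len] + "..." if len(content) > max_len else content
    return "Empty conversation"


if __name__ == "__main__":
    keys = [
        "checkpoint:u1:1700000000:ns:cp1",
        "checkpoint:u1:1700000000:ns:cp2",
        "checkpoint:u1:1700000500:ns:cp1",
    ]
    print(sorted(group_keys_by_thread(keys)))  # ['u1:1700000000', 'u1:1700000500']

    raw = json.dumps({"checkpoint": {"channel_values": {"messages": [{
        "lc": 1,
        "type": "constructor",
        "id": ["langchain", "schema", "messages", "HumanMessage"],
        "kwargs": {"type": "human", "content": "Hello"},
    }]}}})
    print(first_human_preview(raw))  # Hello
```

With these helpers split out, the async function only has to SCAN, fetch, and call them. This keeps the Redis I/O and the parsing independently testable.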

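# 🔥 Async startup and cleanup

async def startup():
    """Async initialization at application startup"""
    logger.info("🚀 Starting async Flask application...")
    try:
        await initialize_agent()
        logger.info("✅ Agent pre-initialization complete")
    except Exception as e:
        logger.error(f"❌ Agent initialization failed at startup: {e}")

async def shutdown():
    """Async cleanup at application shutdown"""
    logger.info("🔄 Shutting down async Flask application...")
    try:
        await cleanup_agent()
        logger.info("✅ Resource cleanup complete")
    except Exception as e:
        logger.error(f"❌ Cleanup failed during shutdown: {e}")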

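if __name__ == "__main__":
    # 🔥 Simplified startup. Flask 2.0+ can run `async def` views (requires the
    # `flask[async]` extra); Flask starts an event loop in a thread for each
    # request to an async view.
    logger.info("🚀 Starting with Flask's built-in async support...")

    # Signal handling
    import signal
    import sys

    def signal_handler(signum, frame):
        logger.info("🛑 Shutdown signal received, shutting down...")
        # This handler only logs and exits: shutdown() is not awaited here,
        # so the async cleanup (cleanup_agent) does not run on this path.
        print("Shutting down service...")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Start the Flask application
    app.run(host="0.0.0.0", port=8000, debug=False)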