瀏覽代碼

ask_react_agent_stream api 创建完成,现在准备优化step返回的结果.

wangxq 2 周之前
父節點
當前提交
586bff9728
共有 2 個文件被更改,包括 373 次插入2 次删除
  1. 156 0
      react_agent/agent.py
  2. 217 2
      unified_api.py

+ 156 - 0
react_agent/agent.py

@@ -1124,6 +1124,162 @@ class CustomReactAgent:
             else:
                 logger.error(f"❌ 处理过程中发生严重错误 - Thread: {thread_id}: {e}", exc_info=True)
                 return {"success": False, "error": str(e), "thread_id": thread_id}
+
+    async def chat_stream(self, message: str, user_id: str, thread_id: Optional[str] = None):
+        """
+        流式处理用户聊天请求 - 复用chat()方法的所有逻辑
+        
+        Args:
+            message: 用户消息
+            user_id: 用户ID
+            thread_id: 会话ID,可选,不传则自动生成
+            
+        Yields:
+            Dict: 包含进度信息或最终结果的字典
+                - type: "progress" | "completed" | "error"
+                - node: 节点名称 (仅progress)
+                - data: 节点数据 (仅progress) 
+                - thread_id: 会话ID
+                - result: 最终结果 (仅completed)
+                - error: 错误信息 (仅error)
+        """
+        # 1. 复用现有的初始化逻辑(thread_id生成、配置检查等)
+        if not thread_id:
+            now = pd.Timestamp.now()
+            milliseconds = int(now.microsecond / 1000)
+            thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
+            logger.info(f"🆕 流式处理 - 新建会话,Thread ID: {thread_id}")
+        
+        # 2. 复用现有的配置和错误处理
+        self._recursion_count = 0
+        
+        run_config = {
+            "configurable": {
+                "thread_id": thread_id,
+            },
+            "recursion_limit": config.RECURSION_LIMIT
+        }
+        
+        logger.info(f"🔢 流式处理 - 递归限制设置: {config.RECURSION_LIMIT}")
+        
+        inputs = {
+            "messages": [HumanMessage(content=message)],
+            "user_id": user_id,
+            "thread_id": thread_id,
+            "suggested_next_step": None,
+        }
+
+        try:
+            logger.info(f"🚀 流式处理开始 - 用户消息: {message[:50]}...")
+            
+            # 3. 复用checkpointer检查逻辑
+            if self.checkpointer:
+                try:
+                    # 简单的连接测试 - 不用aget_tuple因为可能没有数据
+                    # 直接测试Redis连接
+                    if hasattr(self.checkpointer, 'conn') and self.checkpointer.conn:
+                        await self.checkpointer.conn.ping()
+                except Exception as checkpoint_error:
+                    if "Event loop is closed" in str(checkpoint_error) or "closed" in str(checkpoint_error).lower():
+                        logger.warning(f"⚠️ 流式处理 - Checkpointer连接异常,尝试重新初始化: {checkpoint_error}")
+                        await self._reinitialize_checkpointer()
+                        # 重新构建graph使用新的checkpointer
+                        self.agent_executor = self._create_graph()
+                    else:
+                        logger.warning(f"⚠️ 流式处理 - Checkpointer测试失败,但继续执行: {checkpoint_error}")
+            
+            # 4. 使用astream流式执行
+            final_state = None
+            async for chunk in self.agent_executor.astream(inputs, run_config, stream_mode="updates"):
+                for node_name, node_data in chunk.items():
+                    logger.debug(f"🔄 流式进度 - 节点: {node_name}")
+                    yield {
+                        "type": "progress",
+                        "node": node_name,
+                        "data": node_data,
+                        "thread_id": thread_id
+                    }
+                    final_state = node_data
+            
+            # 获取完整的最终状态(包含所有字段)
+            if final_state:
+                # 确保 final_state 包含必需的字段
+                if "thread_id" not in final_state:
+                    final_state["thread_id"] = thread_id
+                if "user_id" not in final_state:
+                    final_state["user_id"] = user_id
+            
+            # 5. 复用现有的结果处理逻辑
+            if final_state and "messages" in final_state:
+                # 🔍 调试:打印 final_state 的所有 keys
+                logger.info(f"🔍 流式处理 - Final state keys: {list(final_state.keys())}")
+                
+                answer = final_state["messages"][-1].content
+                
+                # 🎯 提取最近的 run_sql 执行结果(不修改messages)
+                sql_data = await self._async_extract_latest_sql_data(final_state["messages"])
+                
+                logger.info(f"✅ 流式处理完成 - Final Answer: '{answer}'")
+                
+                # 构建返回结果(保持简化格式用于shell.py)
+                result = {
+                    "success": True, 
+                    "answer": answer, 
+                    "thread_id": thread_id
+                }
+                
+                # 只有当存在SQL数据时才添加到返回结果中
+                if sql_data:
+                    result["sql_data"] = sql_data
+                    logger.info("   📊 流式处理 - 已包含SQL原始数据")
+                
+                # 生成API格式数据
+                api_data = await self._async_generate_api_data(final_state)
+                result["api_data"] = api_data
+                logger.info("   🔌 流式处理 - 已生成API格式数据")
+                
+                yield {
+                    "type": "completed",
+                    "result": {"api_data": api_data, "thread_id": thread_id}
+                }
+            else:
+                # 如果没有获取到正确的final_state,返回错误
+                logger.error(f"❌ 流式处理 - 未获取到有效的final_state")
+                yield {
+                    "type": "error",
+                    "error": "处理失败:未获取到有效的执行结果",
+                    "thread_id": thread_id
+                }
+            
+        except Exception as e:
+            logger.error(f"❌ 流式处理异常 - Thread: {thread_id}: {e}", exc_info=True)
+            
+            # 特殊处理Redis相关的Event loop错误
+            if "Event loop is closed" in str(e):
+                # 尝试重新初始化checkpointer
+                try:
+                    await self._reinitialize_checkpointer()
+                    self.agent_executor = self._create_graph()
+                    logger.info("🔄 流式处理 - 已重新初始化checkpointer,请重试请求")
+                    yield {
+                        "type": "error",
+                        "error": "Redis连接问题,请重试",
+                        "thread_id": thread_id,
+                        "retry_suggested": True
+                    }
+                except Exception as reinit_error:
+                    logger.error(f"❌ 流式处理 - 重新初始化失败: {reinit_error}")
+                    yield {
+                        "type": "error", 
+                        "error": "服务暂时不可用,请稍后重试",
+                        "thread_id": thread_id
+                    }
+            else:
+                yield {
+                    "type": "error",
+                    "error": str(e),
+                    "thread_id": thread_id
+                }
     
     async def get_conversation_history(self, thread_id: str, include_tools: bool = False) -> Dict[str, Any]:
         """

+ 217 - 2
unified_api.py

@@ -10,6 +10,7 @@ import logging
 import atexit
 import os
 import sys
+import time
 from datetime import datetime, timedelta, timezone
 import pytz
 from typing import Optional, Dict, Any, TYPE_CHECKING, Union
@@ -88,7 +89,8 @@ redis_conversation_manager = RedisConversationManager()
 
 # ==================== React Agent 全局实例管理 ====================
 
-_react_agent_instance: Optional[Any] = None
+_react_agent_instance: Optional[Any] = None  # 同步工具,用于 ask_react_agent
+_react_agent_stream_instance: Optional[Any] = None  # 异步工具,用于 ask_react_agent_stream
 _redis_client: Optional[redis.Redis] = None
 
 def _format_timestamp_to_china_time(timestamp_str):
@@ -319,6 +321,29 @@ async def ensure_agent_ready() -> bool:
         await get_react_agent()
         return True
 
+async def create_stream_agent_instance():
+    """为每个流式请求创建新的Agent实例(使用异步工具)"""
+    if CustomReactAgent is None:
+        logger.error("❌ CustomReactAgent 未能导入,无法初始化流式Agent")
+        raise ImportError("CustomReactAgent 未能导入")
+        
+    logger.info("🚀 正在为流式请求创建新的 React Agent 实例...")
+    try:
+        # 创建流式专用 Agent 实例
+        stream_agent = await CustomReactAgent.create()
+        
+        # 配置使用异步 SQL 工具
+        from react_agent.async_sql_tools import async_sql_tools
+        stream_agent.tools = async_sql_tools
+        stream_agent.llm_with_tools = stream_agent.llm.bind_tools(async_sql_tools)
+        
+        logger.info("✅ 流式 React Agent 实例创建完成(配置异步工具)")
+        return stream_agent
+        
+    except Exception as e:
+        logger.error(f"❌ 流式 React Agent 实例创建失败: {e}")
+        raise
+
 def get_user_conversations_simple_sync(user_id: str, limit: int = 10):
     """直接从Redis获取用户对话,测试版本"""
     import redis
@@ -684,6 +709,119 @@ async def ask_react_agent():
             "error": "服务暂时不可用,请稍后重试"
         }), 500
 
+@app.route('/api/v0/ask_react_agent_stream', methods=['GET'])
+def ask_react_agent_stream():
+    """React Agent 流式API - 使用异步工具的专用 Agent 实例
+    功能与ask_react_agent完全相同,除了采用流式输出
+    """
+    def generate():
+        try:
+            # 1. 参数获取和验证(从URL参数,因为EventSource只支持GET)
+            question = request.args.get('question')
+            user_id_input = request.args.get('user_id')
+            thread_id_input = request.args.get('thread_id')
+            
+            # 参数验证(复用现有validate_request_data逻辑)
+            if not question:
+                yield format_sse_error("缺少必需参数:question")
+                return
+                
+            # 2. 数据预处理(与ask_react_agent相同)
+            try:
+                validated_data = validate_request_data({
+                    'question': question,
+                    'user_id': user_id_input,
+                    'thread_id': thread_id_input
+                })
+            except ValueError as ve:
+                yield format_sse_error(f"参数验证失败: {str(ve)}")
+                return
+            
+            logger.info(f"📨 收到React Agent流式请求 - User: {validated_data['user_id']}, Question: {validated_data['question'][:50]}...")
+            
+            # 3. 为当前请求创建新的事件循环和Agent实例
+            import asyncio
+            
+            # 创建新的事件循环
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            
+            stream_agent = None
+            try:
+                # 为当前请求创建新的Agent实例
+                stream_agent = loop.run_until_complete(create_stream_agent_instance())
+                
+                if not stream_agent:
+                    yield format_sse_error("流式 React Agent 初始化失败")
+                    return
+            except Exception as e:
+                logger.error(f"流式 Agent 初始化异常: {str(e)}")
+                yield format_sse_error(f"流式 Agent 初始化失败: {str(e)}")
+                return
+            
+            # 4. 在同一个事件循环中执行流式处理
+            try:
+                # 创建异步生成器
+                async def stream_worker():
+                    try:
+                        # 使用当前请求的 Agent 实例(已配置异步工具)
+                        async for chunk in stream_agent.chat_stream(
+                            message=validated_data['question'],
+                            user_id=validated_data['user_id'],
+                            thread_id=validated_data['thread_id']
+                        ):
+                            yield chunk
+                            if chunk.get("type") == "completed":
+                                break
+                    except Exception as e:
+                        logger.error(f"流式处理异常: {str(e)}", exc_info=True)
+                        yield {
+                            "type": "error", 
+                            "error": f"流式处理异常: {str(e)}"
+                        }
+                
+                # 在当前事件循环中运行异步生成器
+                async_gen = stream_worker()
+                
+                # 同步迭代异步生成器
+                while True:
+                    try:
+                        chunk = loop.run_until_complete(async_gen.__anext__())
+                        
+                        if chunk["type"] == "progress":
+                            yield format_sse_react_progress(chunk)
+                        elif chunk["type"] == "completed":
+                            yield format_sse_react_completed(chunk)
+                            break
+                        elif chunk["type"] == "error":
+                            yield format_sse_error(chunk.get("error", "未知错误"))
+                            break
+                            
+                    except StopAsyncIteration:
+                        break
+                    except Exception as e:
+                        logger.error(f"处理流式数据异常: {str(e)}")
+                        yield format_sse_error(f"处理异常: {str(e)}")
+                        break
+                        
+            except Exception as e:
+                logger.error(f"React Agent流式处理异常: {str(e)}")
+                yield format_sse_error(f"流式处理异常: {str(e)}")
+            finally:
+                # 清理:流式处理完成后关闭事件循环
+                try:
+                    loop.close()
+                except Exception as e:
+                    logger.warning(f"关闭事件循环时出错: {e}")
+                    
+        except Exception as e:
+            logger.error(f"React Agent流式API异常: {str(e)}")
+            yield format_sse_error(f"服务异常: {str(e)}")
+    
+    return Response(stream_with_context(generate()), mimetype='text/event-stream')
+
+
+
 @app.route('/api/v0/react/status/<thread_id>', methods=['GET'])
 async def get_react_agent_status(thread_id: str):
     """获取React Agent执行状态,使用LangGraph API"""
@@ -1688,6 +1826,78 @@ def format_sse_completed(chunk: dict) -> str:
     import json
     return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
 
+def format_sse_react_progress(chunk: dict) -> str:
+    """格式化React Agent进度事件为SSE格式"""
+    node = chunk.get("node")
+    thread_id = chunk.get("thread_id")
+    
+    # 节点显示名称映射
+    node_display_map = {
+        "__start__": "开始处理",
+        "trim_messages": "准备消息", 
+        "agent": "AI思考中",
+        "prepare_tool_input": "准备工具",
+        "tools": "执行查询",
+        "update_state_after_tool": "处理结果",
+        "format_final_response": "生成回答",
+        "__end__": "完成"
+    }
+    
+    display_name = node_display_map.get(node, "处理中")
+    
+    data = {
+        "code": 200,
+        "success": True,
+        "message": f"正在执行: {display_name}",
+        "data": {
+            "type": "progress",
+            "node": node,
+            "display_name": display_name,
+            "thread_id": thread_id,
+            "timestamp": datetime.now().isoformat()
+        }
+    }
+    
+    import json
+    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
+
+def format_sse_react_completed(chunk: dict) -> str:
+    """格式化React Agent完成事件为SSE格式"""
+    result = chunk.get("result", {})
+    api_data = result.get("api_data", {})
+    thread_id = result.get("thread_id")
+    
+    # 构建与ask_react_agent相同的响应格式
+    response_data = {
+        "response": api_data.get("response", ""),
+        "conversation_id": thread_id,
+        "user_id": api_data.get("react_agent_meta", {}).get("user_id", ""),
+        "react_agent_meta": api_data.get("react_agent_meta", {
+            "thread_id": thread_id,
+            "agent_version": "custom_react_v1_async"
+        }),
+        "timestamp": datetime.now().isoformat()
+    }
+    
+    # 可选字段
+    if "sql" in api_data:
+        response_data["sql"] = api_data["sql"]
+    if "records" in api_data:
+        response_data["records"] = api_data["records"]
+    
+    data = {
+        "code": 200,
+        "success": True,
+        "message": "处理完成",
+        "data": {
+            "type": "completed",
+            **response_data
+        }
+    }
+    
+    import json
+    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
+
 def format_sse_error(error_message: str) -> str:
     """格式化错误事件为SSE格式"""
     data = {
@@ -1704,6 +1914,11 @@ def format_sse_error(error_message: str) -> str:
     import json
     return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
 
+def format_sse_data(data: dict) -> str:
+    """格式化普通数据事件为SSE格式"""
+    import json
+    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
+
 # ==================== QA反馈系统API ====================
 
 qa_feedback_manager = None
@@ -5836,7 +6051,7 @@ if __name__ == '__main__':
     logger.info("📋 备份列表API: http://localhost:8084/api/v0/data_pipeline/vector/restore/list")
     
     # 并发问题解决方案:已确认WsgiToAsgi导致阻塞,使用原生Flask解决
-    USE_WSGI_TO_ASGI = False  # 使用原生Flask并发,解决状态API阻塞问题
+    USE_WSGI_TO_ASGI = False  # 暂时回到原生Flask模式,解决ASGI兼容性问题
     
     if USE_WSGI_TO_ASGI:
         try: