Jelajahi Sumber

完成了react方式的步骤轮询查询功能,准备对State进行裁剪.

wangxq 3 minggu lalu
induk
melakukan
1672d4ab47
1 mengubah file dengan 409 tambahan dan 25 penghapusan
  1. 409 25
      unified_api.py

+ 409 - 25
unified_api.py

@@ -684,6 +684,373 @@ async def ask_react_agent():
             "error": "服务暂时不可用,请稍后重试"
         }), 500
 
+@app.route('/api/v0/react/status/<thread_id>', methods=['GET'])
+async def get_react_agent_status(thread_id: str):
+    """获取React Agent执行状态,使用LangGraph API"""
+    
+    try:
+        global _react_agent_instance
+        
+        if not _react_agent_instance:
+            from common.result import failed
+            return jsonify(failed(message="Agent实例未初始化", code=500)), 500
+        
+        # 工具状态映射
+        TOOL_STATUS_MAPPING = {
+            "generate_sql": {"name": "生成SQL中", "icon": "🔍"},
+            "valid_sql": {"name": "验证SQL中", "icon": "✅"}, 
+            "run_sql": {"name": "执行查询中", "icon": "⚡"},
+        }
+        
+        # 使用LangGraph API获取checkpoint
+        read_config = {"configurable": {"thread_id": thread_id}}
+        checkpoint_tuple = await _react_agent_instance.checkpointer.aget_tuple(read_config)
+        
+        if not checkpoint_tuple or not checkpoint_tuple.checkpoint:
+            from common.result import failed
+            return jsonify(failed(message="未找到执行线程", code=404)), 404
+        
+        # 获取checkpoint数据
+        checkpoint = checkpoint_tuple.checkpoint
+        channel_values = checkpoint.get("channel_values", {})
+        messages = channel_values.get("messages", [])
+        
+        if not messages:
+            from common.result import success
+            return jsonify(success(data={
+                "status": "running",
+                "name": "初始化中",
+                "icon": "🚀",
+                "timestamp": datetime.now().isoformat()
+            }, message="获取状态成功"))
+        
+        # 分析最后一条消息确定状态
+        last_message = messages[-1]
+        last_msg_type = last_message.get("type", "") if hasattr(last_message, 'get') else getattr(last_message, 'type', "")
+        
+        # 如果last_message是对象,需要转换为字典格式
+        if hasattr(last_message, '__dict__'):
+            last_message_dict = {
+                'type': getattr(last_message, 'type', ''),
+                'content': getattr(last_message, 'content', ''),
+                'tool_calls': getattr(last_message, 'tool_calls', []) if hasattr(last_message, 'tool_calls') else [],
+            }
+            # 如果有additional_kwargs,也包含进来
+            if hasattr(last_message, 'additional_kwargs'):
+                last_message_dict.update(last_message.additional_kwargs)
+        else:
+            last_message_dict = last_message
+        
+        # 判断执行状态
+        if (last_msg_type == "ai" and 
+            not last_message_dict.get("tool_calls", []) and
+            last_message_dict.get("content", "").strip()):
+            
+            from common.result import success
+            return jsonify(success(data={
+                "status": "completed",
+                "name": "完成",
+                "icon": "✅",
+                "timestamp": datetime.now().isoformat()
+            }, message="获取状态成功"))
+            
+        elif (last_msg_type == "ai" and 
+              last_message_dict.get("tool_calls", [])):
+            
+            tool_calls = last_message_dict.get("tool_calls", [])
+            tool_name = tool_calls[0].get("name", "") if tool_calls else ""
+            
+            tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
+                "name": f"调用{tool_name}中" if tool_name else "调用工具中",
+                "icon": "🔧"
+            })
+            
+            from common.result import success
+            return jsonify(success(data={
+                "status": "running",
+                "name": tool_info["name"],
+                "icon": tool_info["icon"],
+                "timestamp": datetime.now().isoformat()
+            }, message="获取状态成功"))
+            
+        elif last_msg_type == "tool":
+            tool_name = last_message_dict.get("name", "")
+            tool_status = last_message_dict.get("status", "")
+            
+            if tool_status == "success":
+                tool_info = TOOL_STATUS_MAPPING.get(tool_name, {"name": "处理中", "icon": "🔄"})
+                from common.result import success
+                return jsonify(success(data={
+                    "status": "running", 
+                    "name": f"{tool_info['name'].replace('中', '')}完成,AI处理中",
+                    "icon": "🤖",
+                    "timestamp": datetime.now().isoformat()
+                }, message="获取状态成功"))
+            else:
+                tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
+                    "name": f"执行{tool_name}中",
+                    "icon": "⚙️"
+                })
+                from common.result import success
+                return jsonify(success(data={
+                    "status": "running",
+                    "name": tool_info["name"], 
+                    "icon": tool_info["icon"],
+                    "timestamp": datetime.now().isoformat()
+                }, message="获取状态成功"))
+                
+        else:
+            from common.result import success
+            return jsonify(success(data={
+                "status": "running",
+                "name": "执行中",
+                "icon": "⚙️", 
+                "timestamp": datetime.now().isoformat()
+            }, message="获取状态成功"))
+        
+    except Exception as e:
+        from common.result import failed
+        logger.error(f"获取React Agent状态失败: {e}")
+        return jsonify(failed(message=f"获取状态失败: {str(e)}", code=500)), 500
+
+@app.route('/api/v0/react/direct/status/<thread_id>', methods=['GET'])
+async def get_react_agent_status_direct(thread_id: str):
+    """直接访问Redis获取React Agent执行状态,绕过Agent实例资源竞争"""
+    
+    try:
+        # 工具状态映射
+        TOOL_STATUS_MAPPING = {
+            "generate_sql": {"name": "生成SQL中", "icon": "🔍"},
+            "valid_sql": {"name": "验证SQL中", "icon": "✅"}, 
+            "run_sql": {"name": "执行查询中", "icon": "⚡"},
+        }
+        
+        # 创建独立的Redis连接,不使用Agent的连接
+        redis_client = redis.from_url("redis://localhost:6379", decode_responses=True)
+        
+        try:
+            # 1. 查找该thread_id的所有checkpoint键
+            pattern = f"checkpoint:{thread_id}:*"
+            keys = await redis_client.keys(pattern)
+            
+            if not keys:
+                from common.result import failed
+                return jsonify(failed(message="未找到执行线程", code=404)), 404
+            
+            # 2. 获取最新的checkpoint键
+            latest_key = sorted(keys)[-1]
+            
+            # 3. 检查Redis key的数据类型
+            key_type = await redis_client.type(latest_key)
+            logger.info(f"🔍 Redis key类型: {key_type}, key: {latest_key}")
+            
+            # 4. 根据数据类型获取checkpoint数据
+            if key_type == "string":
+                # 字符串类型,直接使用GET
+                raw_checkpoint_data = await redis_client.get(latest_key)
+                if raw_checkpoint_data:
+                    checkpoint = json.loads(raw_checkpoint_data)
+                else:
+                    from common.result import failed
+                    return jsonify(failed(message="无法读取checkpoint数据", code=500)), 500
+                    
+            elif key_type == "ReJSON-RL":
+                # RedisJSON类型,使用JSON.GET命令
+                try:
+                    # 使用execute_command执行JSON.GET
+                    checkpoint = await redis_client.execute_command("JSON.GET", latest_key)
+                    if checkpoint:
+                        # JSON.GET返回的是JSON字符串,需要解析
+                        if isinstance(checkpoint, str):
+                            checkpoint = json.loads(checkpoint)
+                        logger.info(f"✅ 成功从RedisJSON获取checkpoint数据")
+                    else:
+                        from common.result import failed
+                        return jsonify(failed(message="无法读取RedisJSON数据", code=500)), 500
+                except Exception as json_error:
+                    logger.error(f"❌ RedisJSON操作失败: {json_error}")
+                    from common.result import failed
+                    return jsonify(failed(message=f"RedisJSON操作失败: {str(json_error)}", code=500)), 500
+                    
+            elif key_type == "hash":
+                # Hash类型,使用HGETALL
+                hash_data = await redis_client.hgetall(latest_key)
+                logger.info(f"🔍 Hash数据字段: {list(hash_data.keys())}")
+                
+                # 尝试不同的字段名获取checkpoint
+                checkpoint_fields = ['checkpoint', 'data', 'value']
+                checkpoint_data = None
+                
+                for field in checkpoint_fields:
+                    if field in hash_data:
+                        checkpoint_data = hash_data[field]
+                        break
+                
+                if not checkpoint_data:
+                    # 如果没有找到标准字段,返回整个hash结构
+                    checkpoint = {"hash_data": hash_data}
+                else:
+                    try:
+                        checkpoint = json.loads(checkpoint_data)
+                    except json.JSONDecodeError:
+                        # 如果不是JSON,可能是其他格式
+                        checkpoint = {"raw_data": checkpoint_data}
+                        
+            elif key_type == "list":
+                # List类型,获取所有元素
+                list_data = await redis_client.lrange(latest_key, 0, -1)
+                logger.info(f"🔍 List数据长度: {len(list_data)}")
+                checkpoint = {"list_data": list_data}
+                
+            else:
+                from common.result import failed
+                return jsonify(failed(message=f"不支持的Redis数据类型: {key_type}", code=500)), 500
+            
+            # 5. 提取messages
+            messages = []
+            
+            # 根据不同的checkpoint结构提取messages
+            if "checkpoint" in checkpoint and "channel_values" in checkpoint["checkpoint"]:
+                # 标准checkpoint结构(与您的数据匹配)
+                messages = checkpoint["checkpoint"]["channel_values"].get("messages", [])
+                logger.info(f"✅ 从标准checkpoint结构提取到 {len(messages)} 条messages")
+            elif "channel_values" in checkpoint:
+                # 直接的channel_values结构
+                messages = checkpoint["channel_values"].get("messages", [])
+                logger.info(f"✅ 从直接channel_values结构提取到 {len(messages)} 条messages")
+            elif "hash_data" in checkpoint:
+                # Hash数据结构,尝试从不同字段提取
+                hash_data = checkpoint["hash_data"]
+                logger.info(f"🔍 Hash字段详情: {list(hash_data.keys())}")
+                
+                # 尝试解析可能包含messages的字段
+                for key, value in hash_data.items():
+                    try:
+                        parsed_data = json.loads(value)
+                        if isinstance(parsed_data, dict):
+                            if "channel_values" in parsed_data and "messages" in parsed_data["channel_values"]:
+                                messages = parsed_data["channel_values"]["messages"]
+                                logger.info(f"✅ 从Hash字段 {key} 提取到 {len(messages)} 条messages")
+                                break
+                            elif "messages" in parsed_data:
+                                messages = parsed_data["messages"]
+                                logger.info(f"✅ 从Hash字段 {key} 直接提取到 {len(messages)} 条messages")
+                                break
+                    except (json.JSONDecodeError, TypeError):
+                        continue
+                        
+            elif "list_data" in checkpoint:
+                # List数据结构
+                logger.info(f"🔍 List数据: {len(checkpoint['list_data'])} 个元素")
+                
+            # 如果无法提取messages或为空,返回初始化状态
+            if not messages:
+                logger.warning(f"⚠️ 无法从checkpoint中提取messages,checkpoint结构: {list(checkpoint.keys())}")
+                status_data = {
+                    "status": "running", 
+                    "name": "初始化中",
+                    "icon": "🚀",
+                    "timestamp": datetime.now().isoformat(),
+                    "debug_info": {
+                        "key_type": key_type,
+                        "checkpoint_keys": list(checkpoint.keys()),
+                        "has_checkpoint": "checkpoint" in checkpoint,
+                        "has_channel_values": "channel_values" in checkpoint.get("checkpoint", {})
+                    }
+                }
+                from common.result import success
+                return jsonify(success(data=status_data, message="获取状态成功")), 200
+            
+            # 6. 分析最后一条消息
+            last_message = messages[-1]
+            last_msg_type = last_message.get("kwargs", {}).get("type", "")
+            
+            # 7. 判断执行状态
+            if (last_msg_type == "ai" and 
+                not last_message.get("kwargs", {}).get("tool_calls", []) and
+                last_message.get("kwargs", {}).get("content", "").strip()):
+                
+                # 完成状态:AIMessage有完整回答且无tool_calls
+                status_data = {
+                    "status": "completed",
+                    "name": "完成",
+                    "icon": "✅",
+                    "timestamp": datetime.now().isoformat()
+                }
+                
+            elif (last_msg_type == "ai" and 
+                  last_message.get("kwargs", {}).get("tool_calls", [])):
+                
+                # AI正在调用工具
+                tool_calls = last_message.get("kwargs", {}).get("tool_calls", [])
+                tool_name = tool_calls[0].get("name", "") if tool_calls else ""
+                
+                tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
+                    "name": f"调用{tool_name}中" if tool_name else "调用工具中",
+                    "icon": "🔧"
+                })
+                
+                status_data = {
+                    "status": "running",
+                    "name": tool_info["name"],
+                    "icon": tool_info["icon"],
+                    "timestamp": datetime.now().isoformat()
+                }
+                
+            elif last_msg_type == "tool":
+                # 工具执行完成,等待AI处理
+                tool_name = last_message.get("kwargs", {}).get("name", "")
+                tool_status = last_message.get("kwargs", {}).get("status", "")
+                
+                if tool_status == "success":
+                    tool_info = TOOL_STATUS_MAPPING.get(tool_name, {"name": "处理中", "icon": "🔄"})
+                    status_data = {
+                        "status": "running", 
+                        "name": f"{tool_info['name'].replace('中', '')}完成,AI处理中",
+                        "icon": "🤖",
+                        "timestamp": datetime.now().isoformat()
+                    }
+                else:
+                    tool_info = TOOL_STATUS_MAPPING.get(tool_name, {
+                        "name": f"执行{tool_name}中",
+                        "icon": "⚙️"
+                    })
+                    status_data = {
+                        "status": "running",
+                        "name": tool_info["name"], 
+                        "icon": tool_info["icon"],
+                        "timestamp": datetime.now().isoformat()
+                    }
+                    
+            elif last_msg_type == "human":
+                # 用户刚提问,AI开始思考
+                status_data = {
+                    "status": "running",
+                    "name": "AI思考中",
+                    "icon": "🤖",
+                    "timestamp": datetime.now().isoformat()
+                }
+                
+            else:
+                # 默认执行中状态
+                status_data = {
+                    "status": "running",
+                    "name": "执行中",
+                    "icon": "⚙️", 
+                    "timestamp": datetime.now().isoformat()
+                }
+            
+            from common.result import success
+            return jsonify(success(data=status_data, message="获取状态成功")), 200
+            
+        finally:
+            await redis_client.aclose()
+            
+    except Exception as e:
+        logger.error(f"获取React Agent状态失败: {e}")
+        from common.result import failed
+        return jsonify(failed(message=f"获取状态失败: {str(e)}", code=500)), 500
+
 # ==================== LangGraph Agent API ====================
 
 # 全局Agent实例(单例模式)
@@ -4727,31 +5094,48 @@ if __name__ == '__main__':
     logger.info("📥 Vector恢复API: http://localhost:8084/api/v0/data_pipeline/vector/restore")
     logger.info("📋 备份列表API: http://localhost:8084/api/v0/data_pipeline/vector/restore/list")
     
-    try:
-        # 尝试使用ASGI模式启动(推荐)
-        import uvicorn
-        from asgiref.wsgi import WsgiToAsgi
-        
-        logger.info("🚀 使用ASGI模式启动异步Flask应用...")
-        logger.info("   这将解决事件循环冲突问题,支持LangGraph异步checkpoint保存")
-        
-        # 将Flask WSGI应用转换为ASGI应用
-        asgi_app = WsgiToAsgi(app)
-        
-        # 使用uvicorn启动ASGI应用
-        uvicorn.run(
-            asgi_app,
-            host="0.0.0.0",
-            port=8084,
-            log_level="info",
-            access_log=True
-        )
-        
-    except ImportError as e:
-        # 如果缺少ASGI依赖,fallback到传统Flask模式
-        logger.warning("⚠️ ASGI依赖缺失,使用传统Flask模式启动")
-        logger.warning("   建议安装: pip install uvicorn asgiref")
-        logger.warning("   传统模式可能存在异步事件循环冲突问题")
+    # 并发问题解决方案:已确认WsgiToAsgi导致阻塞,使用原生Flask解决
+    USE_WSGI_TO_ASGI = False  # 使用原生Flask并发,解决状态API阻塞问题
+    
+    if USE_WSGI_TO_ASGI:
+        try:
+            # 方案A:使用ASGI模式启动(可能有并发限制)
+            import uvicorn
+            from asgiref.wsgi import WsgiToAsgi
+            
+            logger.info("🚀 使用ASGI模式启动异步Flask应用...")
+            logger.info("   这将解决事件循环冲突问题,支持LangGraph异步checkpoint保存")
+            
+            # 将Flask WSGI应用转换为ASGI应用
+            asgi_app = WsgiToAsgi(app)
+            
+            # 使用uvicorn启动ASGI应用,增加并发配置
+            uvicorn.run(
+                asgi_app,
+                host="0.0.0.0",
+                port=8084,
+                log_level="info",
+                access_log=True,
+                workers=1,  # 单进程多协程
+                loop="asyncio",  # 使用asyncio事件循环
+                limit_concurrency=100,  # 增加并发限制
+                limit_max_requests=1000  # 增加请求限制
+            )
+            
+        except ImportError as e:
+            # 如果缺少ASGI依赖,fallback到传统Flask模式
+            logger.warning("⚠️ ASGI依赖缺失,使用传统Flask模式启动")
+            logger.warning("   建议安装: pip install uvicorn asgiref")
+            logger.warning("   传统模式可能存在异步事件循环冲突问题")
+            
+            # 启动标准Flask应用(支持异步路由)
+            app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)
+    
+    else:
+        # 方案B:使用原生Flask并发(可能解决WsgiToAsgi并发问题)
+        logger.info("🚀 使用原生Flask并发模式启动...")
+        logger.info("   绕过WsgiToAsgi,测试是否解决并发阻塞问题")
+        logger.info("   使用Flask内置多线程并发支持")
         
         # 启动标准Flask应用(支持异步路由)
         app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)