Ver Fonte

简化 unified_api.py 的 main 函数;如需通过 uvicorn 启动,请使用 asgi_app.py。

wangxq há 2 semanas atrás
pai
commit b39d23bd36
3 ficheiros alterados com 12 adições e 429 exclusões
  1. asgi_app.py (+5 -3)
  2. react_agent/sync_agent.py (+0 -382)
  3. unified_api.py (+7 -44)

+ 5 - 3
asgi_app.py

@@ -13,6 +13,8 @@ from unified_api import app
 asgi_app = WsgiToAsgi(app)
 
 # 启动方式示例:
-# uvicorn asgi_app:asgi_app --host 0.0.0.0 --port 8084
-# 或者带重载:
-# uvicorn asgi_app:asgi_app --host 0.0.0.0 --port 8084 --reload 
+# 开发环境(单进程 + 重载):
+# uvicorn asgi_app:asgi_app --host 127.0.0.1 --port 8084 --reload
+
+# 生产环境(多进程 + 性能优化):
+# uvicorn asgi_app:asgi_app --host 0.0.0.0 --port 8084 --workers 4 --limit-concurrency 100 --limit-max-requests 1000 --access-log

+ 0 - 382
react_agent/sync_agent.py

@@ -1,382 +0,0 @@
-"""
-同步版本的React Agent - 解决Vector搜索异步冲突问题
-基于原有CustomReactAgent,但使用完全同步的实现
-"""
-import json
-import sys
-import os
-from pathlib import Path
-from typing import List, Optional, Dict, Any
-import redis
-
-# 添加项目根目录到sys.path
-try:
-    project_root = Path(__file__).parent.parent
-    if str(project_root) not in sys.path:
-        sys.path.insert(0, str(project_root))
-except Exception as e:
-    pass
-
-from core.logging import get_react_agent_logger
-from langchain_openai import ChatOpenAI
-from langchain_core.messages import HumanMessage, ToolMessage, BaseMessage, SystemMessage, AIMessage
-from langgraph.graph import StateGraph, END
-from langgraph.prebuilt import ToolNode
-
-# 导入同步版本的依赖
-try:
-    from . import config
-    from .state import AgentState
-    from .sql_tools import sql_tools
-except ImportError:
-    import config
-    from state import AgentState
-    from sql_tools import sql_tools
-
-logger = get_react_agent_logger("SyncCustomReactAgent")
-
-class SyncCustomReactAgent:
-    """
-    同步版本的React Agent
-    专门解决Vector搜索的异步事件循环冲突问题
-    """
-    
-    def __init__(self):
-        """私有构造函数,请使用 create() 类方法来创建实例。"""
-        self.llm = None
-        self.tools = None
-        self.agent_executor = None
-        self.checkpointer = None
-        self.redis_client = None
-
-    @classmethod
-    def create(cls):
-        """同步工厂方法,创建并初始化 SyncCustomReactAgent 实例。"""
-        instance = cls()
-        instance._sync_init()
-        return instance
-
-    def _sync_init(self):
-        """同步初始化所有组件。"""
-        logger.info("🚀 开始初始化 SyncCustomReactAgent...")
-
-        # 1. 初始化同步Redis客户端(如果需要)
-        try:
-            self.redis_client = redis.from_url(config.REDIS_URL, decode_responses=True)
-            self.redis_client.ping()
-            logger.info(f"   ✅ Redis连接成功: {config.REDIS_URL}")
-        except Exception as e:
-            logger.warning(f"   ⚠️ Redis连接失败,将不使用checkpointer: {e}")
-            self.redis_client = None
-
-        # 2. 初始化 LLM(同步版本)
-        self.llm = ChatOpenAI(
-            api_key=config.QWEN_API_KEY,
-            base_url=config.QWEN_BASE_URL,
-            model=config.QWEN_MODEL,
-            temperature=0.1,
-            timeout=config.NETWORK_TIMEOUT,
-            max_retries=0,
-            streaming=False,  # 关键:禁用流式处理
-            extra_body={
-                "enable_thinking": False,  # 明确设置为False:非流式调用必须设为false
-                "misc": {
-                    "ensure_ascii": False
-                }
-            }
-        )
-        logger.info(f"   ✅ 同步LLM已初始化,模型: {config.QWEN_MODEL}")
-
-        # 3. 绑定工具
-        self.tools = sql_tools
-        self.llm_with_tools = self.llm.bind_tools(self.tools)
-        logger.info(f"   ✅ 已绑定 {len(self.tools)} 个工具")
-
-        # 4. 创建StateGraph(不使用checkpointer避免异步依赖)
-        self.agent_executor = self._create_sync_graph()
-        logger.info("   ✅ 同步StateGraph已创建")
-
-        logger.info("✅ SyncCustomReactAgent 初始化完成")
-
-    def _create_sync_graph(self):
-        """创建同步的StateGraph"""
-        graph = StateGraph(AgentState)
-        
-        # 添加同步节点
-        graph.add_node("agent", self._sync_agent_node)
-        graph.add_node("tools", ToolNode(self.tools))
-        graph.add_node("prepare_tool_input", self._sync_prepare_tool_input_node)
-        graph.add_node("update_state_after_tool", self._sync_update_state_after_tool_node)
-        graph.add_node("format_final_response", self._sync_format_final_response_node)
-
-        # 设置入口点
-        graph.set_entry_point("agent")
-
-        # 添加条件边
-        graph.add_conditional_edges(
-            "agent",
-            self._sync_should_continue,
-            {
-                "tools": "prepare_tool_input",
-                "end": "format_final_response"
-            }
-        )
-
-        # 添加普通边
-        graph.add_edge("prepare_tool_input", "tools")
-        graph.add_edge("tools", "update_state_after_tool")
-        graph.add_edge("update_state_after_tool", "agent")
-        graph.add_edge("format_final_response", END)
-
-        # 关键:使用同步编译,不传入checkpointer
-        return graph.compile()
-
-    def _sync_agent_node(self, state: AgentState) -> Dict[str, Any]:
-        """同步Agent节点"""
-        logger.info(f"🧠 [Sync Node] agent - Thread: {state.get('thread_id', 'unknown')}")
-        
-        messages_for_llm = state["messages"].copy()
-        
-        # 添加数据库范围提示词
-        if isinstance(state["messages"][-1], HumanMessage):
-            db_scope_prompt = self._get_database_scope_prompt()
-            if db_scope_prompt:
-                messages_for_llm.insert(0, SystemMessage(content=db_scope_prompt))
-                logger.info("   ✅ 已添加数据库范围判断提示词")
-
-        # 同步LLM调用
-        response = self.llm_with_tools.invoke(messages_for_llm)
-        
-        return {"messages": [response]}
-
-    def _sync_should_continue(self, state: AgentState):
-        """同步条件判断"""
-        messages = state["messages"]
-        last_message = messages[-1]
-        
-        if not last_message.tool_calls:
-            return "end"
-        else:
-            return "tools"
-
-    def _sync_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
-        """同步准备工具输入节点"""
-        logger.info(f"🔧 [Sync Node] prepare_tool_input - Thread: {state.get('thread_id', 'unknown')}")
-        
-        last_message = state["messages"][-1]
-        
-        if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
-            for tool_call in last_message.tool_calls:
-                if tool_call.get('name') == 'generate_sql':
-                    # 注入历史消息
-                    history_messages = self._filter_and_format_history(state["messages"])
-                    if 'args' not in tool_call:
-                        tool_call['args'] = {}
-                    tool_call['args']['history_messages'] = history_messages
-                    logger.info(f"   ✅ 为generate_sql注入了 {len(history_messages)} 条历史消息")
-
-        return {"messages": [last_message]}
-
-    def _sync_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
-        """同步更新工具执行后的状态"""
-        logger.info(f"📝 [Sync Node] update_state_after_tool - Thread: {state.get('thread_id', 'unknown')}")
-        
-        last_message = state["messages"][-1]
-        tool_name = last_message.name
-        tool_output = last_message.content
-        next_step = None
-
-        if tool_name == 'generate_sql':
-            tool_output_lower = tool_output.lower()
-            if "failed" in tool_output_lower or "无法生成" in tool_output_lower or "失败" in tool_output_lower:
-                next_step = 'answer_with_common_sense'
-            else:
-                next_step = 'valid_sql'
-        elif tool_name == 'valid_sql':
-            if "失败" in tool_output:
-                next_step = 'analyze_validation_error'
-            else:
-                next_step = 'run_sql'
-        elif tool_name == 'run_sql':
-            next_step = 'summarize_final_answer'
-            
-        logger.info(f"   Tool '{tool_name}' executed. Suggested next step: {next_step}")
-        return {"suggested_next_step": next_step}
-
-    def _sync_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
-        """同步格式化最终响应节点"""
-        logger.info(f"📄 [Sync Node] format_final_response - Thread: {state.get('thread_id', 'unknown')}")
-        
-        messages = state["messages"]
-        last_message = messages[-1]
-        
-        # 构建最终响应
-        final_response = last_message.content
-        
-        logger.info(f"   ✅ 最终响应已准备完成")
-        return {"final_answer": final_response}
-
-    def _filter_and_format_history(self, messages: list) -> list:
-        """过滤和格式化历史消息"""
-        clean_history = []
-        for msg in messages[:-1]:  # 排除最后一条消息
-            if isinstance(msg, HumanMessage):
-                clean_history.append({"type": "human", "content": msg.content})
-            elif isinstance(msg, AIMessage):
-                clean_content = msg.content if not hasattr(msg, 'tool_calls') or not msg.tool_calls else ""
-                if clean_content.strip():
-                    clean_history.append({"type": "ai", "content": clean_content})
-        
-        return clean_history
-
-    def _get_database_scope_prompt(self) -> str:
-        """获取数据库范围判断提示词"""
-        return """你是一个专门处理高速公路收费数据查询的AI助手。在回答用户问题时,请首先判断这个问题是否可以通过查询数据库来回答。
-
-数据库包含以下类型的数据:
-- 服务区信息(名称、位置、档口数量等)
-- 收费站数据
-- 车流量统计
-- 业务数据分析
-
-如果用户的问题与这些数据相关,请使用工具生成SQL查询。
-如果问题与数据库内容无关(如常识性问题、天气、新闻等),请直接用你的知识回答,不要尝试生成SQL。"""
-
-    def chat(self, message: str, user_id: str, thread_id: Optional[str] = None) -> Dict[str, Any]:
-        """
-        同步聊天方法 - 关键:使用 graph.invoke() 而不是 ainvoke()
-        """
-        if thread_id is None:
-            import uuid
-            thread_id = str(uuid.uuid4())
-
-        # 构建输入
-        inputs = {
-            "messages": [HumanMessage(content=message)],
-            "user_id": user_id,
-            "thread_id": thread_id,
-            "suggested_next_step": None
-        }
-
-        # 构建运行配置(不使用checkpointer)
-        run_config = {
-            "recursion_limit": config.RECURSION_LIMIT,
-        }
-
-        try:
-            logger.info(f"🚀 开始同步处理用户消息: {message[:50]}...")
-            
-            # 关键:使用同步的 invoke() 方法
-            final_state = self.agent_executor.invoke(inputs, run_config)
-            
-            logger.info(f"🔍 Final state keys: {list(final_state.keys())}")
-            
-            # 提取答案
-            if final_state["messages"]:
-                answer = final_state["messages"][-1].content
-            else:
-                answer = "抱歉,无法处理您的请求。"
-            
-            # 提取SQL数据(如果有)
-            sql_data = self._extract_latest_sql_data(final_state["messages"])
-            
-            logger.info(f"✅ 同步处理完成 - Final Answer: '{answer[:100]}...'")
-            
-            # 构建返回结果
-            result = {
-                "success": True, 
-                "answer": answer, 
-                "thread_id": thread_id
-            }
-            
-            # 只有当存在SQL数据时才添加到返回结果中
-            if sql_data:
-                try:
-                    # 尝试解析SQL数据
-                    sql_parsed = json.loads(sql_data)
-                    
-                    # 检查数据格式:run_sql工具返回的是数组格式 [{"col1":"val1"}]
-                    if isinstance(sql_parsed, list):
-                        # 数组格式:直接作为records使用
-                        result["api_data"] = {
-                            "response": answer,
-                            "records": sql_parsed,
-                            "react_agent_meta": {
-                                "thread_id": thread_id,
-                                "agent_version": "sync_react_v1"
-                            }
-                        }
-                    elif isinstance(sql_parsed, dict):
-                        # 字典格式:按原逻辑处理
-                        result["api_data"] = {
-                            "response": answer,
-                            "sql": sql_parsed.get("sql", ""),
-                            "records": sql_parsed.get("records", []),
-                            "react_agent_meta": {
-                                "thread_id": thread_id,
-                                "agent_version": "sync_react_v1"
-                            }
-                        }
-                    else:
-                        logger.warning(f"SQL数据格式未知: {type(sql_parsed)}")
-                        raise ValueError("Unknown SQL data format")
-                        
-                except (json.JSONDecodeError, AttributeError, ValueError) as e:
-                    logger.warning(f"SQL数据格式处理失败: {str(e)}, 跳过API数据构建")
-            else:
-                result["api_data"] = {
-                    "response": answer,
-                    "react_agent_meta": {
-                        "thread_id": thread_id,
-                        "agent_version": "sync_react_v1"
-                    }
-                }
-            
-            return result
-            
-        except Exception as e:
-            logger.error(f"❌ 同步处理失败: {str(e)}", exc_info=True)
-            return {
-                "success": False,
-                "error": f"同步处理失败: {str(e)}",
-                "thread_id": thread_id,
-                "retry_suggested": True
-            }
-
-    def _extract_latest_sql_data(self, messages: List[BaseMessage]) -> Optional[str]:
-        """从消息历史中提取最近的run_sql执行结果(同步版本)"""
-        logger.info("🔍 提取最新的SQL执行结果...")
-        
-        # 查找最后一个HumanMessage之后的SQL执行结果
-        last_human_index = -1
-        for i in range(len(messages) - 1, -1, -1):
-            if isinstance(messages[i], HumanMessage):
-                last_human_index = i
-                break
-        
-        if last_human_index == -1:
-            logger.info("   未找到用户消息,跳过SQL数据提取")
-            return None
-        
-        # 只在当前对话轮次中查找SQL结果
-        current_conversation = messages[last_human_index:]
-        logger.info(f"   当前对话轮次包含 {len(current_conversation)} 条消息")
-        
-        for msg in reversed(current_conversation):
-            if isinstance(msg, ToolMessage) and msg.name == 'run_sql':
-                logger.info(f"   找到当前对话轮次的run_sql结果: {msg.content[:100]}...")
-                
-                try:
-                    # 尝试解析JSON以验证格式
-                    parsed_data = json.loads(msg.content)
-                    # 重新序列化,确保中文字符正常显示
-                    formatted_content = json.dumps(parsed_data, ensure_ascii=False, separators=(',', ':'))
-                    logger.info(f"   已转换Unicode转义序列为中文字符")
-                    return formatted_content
-                except json.JSONDecodeError:
-                    # 如果不是有效JSON,直接返回原内容
-                    logger.warning(f"   SQL结果不是有效JSON格式,返回原始内容")
-                    return msg.content
-        
-        logger.info("   当前对话轮次中未找到run_sql执行结果")
-        return None

+ 7 - 44
unified_api.py

@@ -6092,48 +6092,11 @@ if __name__ == '__main__':
     logger.info("📥 Vector恢复API: http://localhost:8084/api/v0/data_pipeline/vector/restore")
     logger.info("📋 备份列表API: http://localhost:8084/api/v0/data_pipeline/vector/restore/list")
     
-    # 并发问题解决方案:已确认WsgiToAsgi导致阻塞,使用原生Flask解决
-    USE_WSGI_TO_ASGI = False  # 暂时回到原生Flask模式,解决ASGI兼容性问题
+    # 原生Flask单进程模式启动
+    # 如需多进程ASGI模式,请使用:uvicorn asgi_app:asgi_app --workers 4
+    logger.info("🚀 使用原生Flask单进程模式启动...")
+    logger.info("   优点:避免WsgiToAsgi并发阻塞问题")
+    logger.info("   多进程模式请使用:uvicorn asgi_app:asgi_app --workers 4")
     
-    if USE_WSGI_TO_ASGI:
-        try:
-            # 方案A:使用ASGI模式启动(可能有并发限制)
-            import uvicorn
-            from asgiref.wsgi import WsgiToAsgi
-            
-            logger.info("🚀 使用ASGI模式启动异步Flask应用...")
-            logger.info("   这将解决事件循环冲突问题,支持LangGraph异步checkpoint保存")
-            
-            # 将Flask WSGI应用转换为ASGI应用
-            asgi_app = WsgiToAsgi(app)
-            
-            # 使用uvicorn启动ASGI应用,增加并发配置
-            uvicorn.run(
-                asgi_app,
-                host="0.0.0.0",
-                port=8084,
-                log_level="info",
-                access_log=True,
-                workers=1,  # 单进程多协程
-                loop="asyncio",  # 使用asyncio事件循环
-                limit_concurrency=100,  # 增加并发限制
-                limit_max_requests=1000  # 增加请求限制
-            )
-            
-        except ImportError as e:
-            # 如果缺少ASGI依赖,fallback到传统Flask模式
-            logger.warning("⚠️ ASGI依赖缺失,使用传统Flask模式启动")
-            logger.warning("   建议安装: pip install uvicorn asgiref")
-            logger.warning("   传统模式可能存在异步事件循环冲突问题")
-            
-            # 启动标准Flask应用(支持异步路由)
-            app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)
-    
-    else:
-        # 方案B:使用原生Flask并发(可能解决WsgiToAsgi并发问题)
-        logger.info("🚀 使用原生Flask并发模式启动...")
-        logger.info("   绕过WsgiToAsgi,测试是否解决并发阻塞问题")
-        logger.info("   使用Flask内置多线程并发支持")
-        
-        # 启动标准Flask应用(支持异步路由)
-        app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)
+    # 启动标准Flask应用(支持异步路由)
+    app.run(host="0.0.0.0", port=8084, debug=False, threaded=True)