Procházet zdrojové kódy

增加通过查询参数动态修改访问模式的功能,在返回结果中将summary合并到response.

wangxq před 1 týdnem
rodič
revize
b8118592dc
3 změnil soubory, kde provedl 97 přidání a 41 odebrání
  1. 52 31
      agent/citu_agent.py
  2. 13 7
      agent/classifier.py
  3. 32 3
      citu_app.py

+ 52 - 31
agent/citu_agent.py

@@ -31,17 +31,23 @@ class CituLangGraphAgent:
         # 注意:现在使用直接工具调用模式,不再需要预创建Agent执行器
         print("[CITU_AGENT] 使用直接工具调用模式")
         
-        self.workflow = self._create_workflow()
+        # 不在构造时创建workflow,改为动态创建以支持路由模式参数
+        # self.workflow = self._create_workflow()
         print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
     
-    def _create_workflow(self) -> StateGraph:
+    def _create_workflow(self, routing_mode: str = None) -> StateGraph:
         """根据路由模式创建不同的工作流"""
-        try:
-            from app_config import QUESTION_ROUTING_MODE
-            print(f"[CITU_AGENT] 创建工作流,路由模式: {QUESTION_ROUTING_MODE}")
-        except ImportError:
-            QUESTION_ROUTING_MODE = "hybrid"
-            print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
+        # 确定使用的路由模式
+        if routing_mode:
+            QUESTION_ROUTING_MODE = routing_mode
+            print(f"[CITU_AGENT] 创建工作流,使用传入的路由模式: {QUESTION_ROUTING_MODE}")
+        else:
+            try:
+                from app_config import QUESTION_ROUTING_MODE
+                print(f"[CITU_AGENT] 创建工作流,使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
+            except ImportError:
+                QUESTION_ROUTING_MODE = "hybrid"
+                print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
         
         workflow = StateGraph(AgentState)
         
@@ -96,14 +102,15 @@ class CituLangGraphAgent:
     def _init_direct_database_node(self, state: AgentState) -> AgentState:
         """初始化直接数据库模式的状态"""
         try:
-            from app_config import QUESTION_ROUTING_MODE
+            # 从state中获取路由模式,而不是从配置文件读取
+            routing_mode = state.get("routing_mode", "database_direct")
             
             # 设置直接数据库模式的分类状态
             state["question_type"] = "DATABASE"
             state["classification_confidence"] = 1.0
             state["classification_reason"] = "配置为直接数据库查询模式"
             state["classification_method"] = "direct_database"
-            state["routing_mode"] = QUESTION_ROUTING_MODE
+            state["routing_mode"] = routing_mode
             state["current_step"] = "direct_database_init"
             state["execution_path"].append("init_direct_database")
             
@@ -121,14 +128,15 @@ class CituLangGraphAgent:
     def _init_direct_chat_node(self, state: AgentState) -> AgentState:
         """初始化直接聊天模式的状态"""
         try:
-            from app_config import QUESTION_ROUTING_MODE
+            # 从state中获取路由模式,而不是从配置文件读取
+            routing_mode = state.get("routing_mode", "chat_direct")
             
             # 设置直接聊天模式的分类状态
             state["question_type"] = "CHAT"
             state["classification_confidence"] = 1.0
             state["classification_reason"] = "配置为直接聊天模式"
             state["classification_method"] = "direct_chat"
-            state["routing_mode"] = QUESTION_ROUTING_MODE
+            state["routing_mode"] = routing_mode
             state["current_step"] = "direct_chat_init"
             state["execution_path"].append("init_direct_chat")
             
@@ -146,7 +154,8 @@ class CituLangGraphAgent:
     def _classify_question_node(self, state: AgentState) -> AgentState:
         """问题分类节点 - 支持渐进式分类策略"""
         try:
-            from app_config import QUESTION_ROUTING_MODE
+            # 从state中获取路由模式,而不是从配置文件读取
+            routing_mode = state.get("routing_mode", "hybrid")
             
             print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
             
@@ -155,20 +164,20 @@ class CituLangGraphAgent:
             if context_type:
                 print(f"[CLASSIFY_NODE] 检测到上下文类型: {context_type}")
             
-            # 使用渐进式分类策略
-            classification_result = self.classifier.classify(state["question"], context_type)
+            # 使用渐进式分类策略,传递路由模式
+            classification_result = self.classifier.classify(state["question"], context_type, routing_mode)
             
             # 更新状态
             state["question_type"] = classification_result.question_type
             state["classification_confidence"] = classification_result.confidence
             state["classification_reason"] = classification_result.reason
             state["classification_method"] = classification_result.method
-            state["routing_mode"] = QUESTION_ROUTING_MODE
+            state["routing_mode"] = routing_mode
             state["current_step"] = "classified"
             state["execution_path"].append("classify")
             
             print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
-            print(f"[CLASSIFY_NODE] 路由模式: {QUESTION_ROUTING_MODE}, 分类方法: {classification_result.method}")
+            print(f"[CLASSIFY_NODE] 路由模式: {routing_mode}, 分类方法: {classification_result.method}")
             
             return state
             
@@ -370,13 +379,14 @@ class CituLangGraphAgent:
                     }
                 elif state.get("summary"):
                     # 正常的数据库查询结果,有摘要的情况
-                    # 不将summary复制到response,让response保持为空
+                    # 将summary的值同时赋给response字段(为将来移除summary字段做准备)
                     state["final_response"] = {
                         "success": True,
                         "type": "DATABASE",
+                        "response": state["summary"],  # 新增:将summary的值赋给response
                         "sql": state.get("sql"),
                         "query_result": state.get("query_result"),  # 获取query_result字段
-                        "summary": state["summary"],
+                        "summary": state["summary"],  # 暂时保留summary字段
                         "execution_path": state["execution_path"],
                         "classification_info": {
                             "confidence": state["classification_confidence"],
@@ -462,7 +472,7 @@ class CituLangGraphAgent:
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             return "CHAT"
     
-    def process_question(self, question: str, session_id: str = None, context_type: str = None) -> Dict[str, Any]:
+    def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
         """
         统一的问题处理入口
         
@@ -470,6 +480,7 @@ class CituLangGraphAgent:
             question: 用户问题
             session_id: 会话ID
             context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
+            routing_mode: 路由模式,可选,用于覆盖配置文件设置
             
         Returns:
             Dict包含完整的处理结果
@@ -478,12 +489,17 @@ class CituLangGraphAgent:
             print(f"[CITU_AGENT] 开始处理问题: {question}")
             if context_type:
                 print(f"[CITU_AGENT] 上下文类型: {context_type}")
+            if routing_mode:
+                print(f"[CITU_AGENT] 使用指定路由模式: {routing_mode}")
+            
+            # 动态创建workflow(基于路由模式)
+            workflow = self._create_workflow(routing_mode)
             
             # 初始化状态
-            initial_state = self._create_initial_state(question, session_id, context_type)
+            initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)
             
             # 执行工作流
-            final_state = self.workflow.invoke(
+            final_state = workflow.invoke(
                 initial_state,
                 config={
                     "configurable": {"session_id": session_id}
@@ -506,12 +522,17 @@ class CituLangGraphAgent:
                 "execution_path": ["error"]
             }
     
-    def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None) -> AgentState:
+    def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
         """创建初始状态 - 支持渐进式分类"""
-        try:
-            from app_config import QUESTION_ROUTING_MODE
-        except ImportError:
-            QUESTION_ROUTING_MODE = "hybrid"
+        # 确定使用的路由模式
+        if routing_mode:
+            effective_routing_mode = routing_mode
+        else:
+            try:
+                from app_config import QUESTION_ROUTING_MODE
+                effective_routing_mode = QUESTION_ROUTING_MODE
+            except ImportError:
+                effective_routing_mode = "hybrid"
         
         return AgentState(
             # 输入信息
@@ -553,7 +574,7 @@ class CituLangGraphAgent:
             debug_info={},
             
             # 路由模式
-            routing_mode=QUESTION_ROUTING_MODE
+            routing_mode=effective_routing_mode
         )
     
     def _extract_original_question(self, question: str) -> str:
@@ -597,7 +618,7 @@ class CituLangGraphAgent:
                 return {
                     "status": "healthy" if test_result.get("success") else "degraded",
                     "test_result": test_result.get("success", False),
-                    "workflow_compiled": self.workflow is not None,
+                    "workflow_compiled": True,  # 动态创建,始终可用
                     "tools_count": len(self.tools),
                     "agent_reuse_enabled": False,
                     "message": "Agent健康检查完成"
@@ -607,7 +628,7 @@ class CituLangGraphAgent:
                 return {
                     "status": "healthy",
                     "test_result": True,
-                    "workflow_compiled": self.workflow is not None,
+                    "workflow_compiled": True,  # 动态创建,始终可用
                     "tools_count": len(self.tools),
                     "agent_reuse_enabled": False,
                     "message": "Agent简单健康检查完成"
@@ -617,7 +638,7 @@ class CituLangGraphAgent:
             return {
                 "status": "unhealthy",
                 "error": str(e),
-                "workflow_compiled": self.workflow is not None,
+                "workflow_compiled": True,  # 动态创建,始终可用
                 "tools_count": len(self.tools) if hasattr(self, 'tools') else 0,
                 "agent_reuse_enabled": False,
                 "message": "Agent健康检查失败"

+ 13 - 7
agent/classifier.py

@@ -133,20 +133,26 @@ class QuestionClassifier:
             "平台", "系统", "AI", "助手", "谢谢", "再见"
         ]
 
-    def classify(self, question: str, context_type: Optional[str] = None) -> ClassificationResult:
+    def classify(self, question: str, context_type: Optional[str] = None, routing_mode: Optional[str] = None) -> ClassificationResult:
         """
         主分类方法:支持渐进式分类策略
         
         Args:
             question: 当前问题
             context_type: 上下文类型 ("DATABASE" 或 "CHAT"),可选
+            routing_mode: 路由模式,可选,用于覆盖配置文件设置
         """
-        try:
-            from app_config import QUESTION_ROUTING_MODE
-            print(f"[CLASSIFIER] 使用路由模式: {QUESTION_ROUTING_MODE}")
-        except ImportError:
-            QUESTION_ROUTING_MODE = "hybrid"
-            print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
+        # 确定使用的路由模式
+        if routing_mode:
+            QUESTION_ROUTING_MODE = routing_mode
+            print(f"[CLASSIFIER] 使用传入的路由模式: {QUESTION_ROUTING_MODE}")
+        else:
+            try:
+                from app_config import QUESTION_ROUTING_MODE
+                print(f"[CLASSIFIER] 使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
+            except ImportError:
+                QUESTION_ROUTING_MODE = "hybrid"
+                print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
         
         # 根据路由模式选择分类策略
         if QUESTION_ROUTING_MODE == "database_direct":

+ 32 - 3
citu_app.py

@@ -445,11 +445,22 @@ def ask_agent():
     conversation_id_input = req.get("conversation_id", None)
     continue_conversation = req.get("continue_conversation", False)
     
+    # 新增:路由模式参数解析和验证
+    api_routing_mode = req.get("routing_mode", None)
+    VALID_ROUTING_MODES = ["database_direct", "chat_direct", "hybrid", "llm_only"]
+    
     if not question:
         return jsonify(bad_request_response(
             response_text="缺少必需参数:question",
             missing_params=["question"]
         )), 400
+    
+    # 验证routing_mode参数
+    if api_routing_mode and api_routing_mode not in VALID_ROUTING_MODES:
+        return jsonify(bad_request_response(
+            response_text=f"无效的routing_mode参数值: {api_routing_mode},支持的值: {VALID_ROUTING_MODES}",
+            invalid_params=["routing_mode"]
+        )), 400
 
     try:
         # 1. 获取登录用户ID(修正:在函数中获取session信息)
@@ -554,7 +565,22 @@ def ask_agent():
             enhanced_question = question
             print(f"[AGENT_API] 新对话,无上下文")
         
-        # 7. 现有Agent处理逻辑(保持不变)
+        # 7. 确定最终使用的路由模式(优先级逻辑)
+        if api_routing_mode:
+            # API传了参数,优先使用
+            effective_routing_mode = api_routing_mode
+            print(f"[AGENT_API] 使用API指定的路由模式: {effective_routing_mode}")
+        else:
+            # API没传参数,使用配置文件
+            try:
+                from app_config import QUESTION_ROUTING_MODE
+                effective_routing_mode = QUESTION_ROUTING_MODE
+                print(f"[AGENT_API] 使用配置文件路由模式: {effective_routing_mode}")
+            except ImportError:
+                effective_routing_mode = "hybrid"
+                print(f"[AGENT_API] 配置文件读取失败,使用默认路由模式: {effective_routing_mode}")
+        
+        # 8. 现有Agent处理逻辑(修改为传递路由模式)
         try:
             agent = get_citu_langraph_agent()
         except Exception as e:
@@ -567,7 +593,8 @@ def ask_agent():
         agent_result = agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
             session_id=browser_session_id,
-            context_type=context_type  # 传递上下文类型
+            context_type=context_type,  # 传递上下文类型
+            routing_mode=effective_routing_mode  # 新增:传递路由模式
         )
         
         # 8. 处理Agent结果
@@ -632,7 +659,9 @@ def ask_agent():
                 from_cache=False,
                 conversation_status=conversation_status["status"],
                 conversation_message=conversation_status["message"],
-                requested_conversation_id=conversation_status.get("requested_id")
+                requested_conversation_id=conversation_status.get("requested_id"),
+                routing_mode_used=effective_routing_mode,  # 新增:实际使用的路由模式
+                routing_mode_source="api" if api_routing_mode else "config"  # 新增:路由模式来源
             ))
         else:
             # 错误处理(修正:确保使用现有的错误响应格式)