Kaynağa Gözat

增加QUESTION_ROUTING_MODE的配置.

wangxq 2 hafta önce
ebeveyn
işleme
82725324c5
4 değiştirilmiş dosya ile 175 ekleme ve 39 silme
  1. 124 34
      agent/citu_agent.py
  2. 38 1
      agent/classifier.py
  3. 7 3
      agent/state.py
  4. 6 1
      app_config.py

+ 124 - 34
agent/citu_agent.py

@@ -35,39 +35,119 @@ class CituLangGraphAgent:
         print("[CITU_AGENT] LangGraph Agent with Direct Tools初始化完成")
     
     def _create_workflow(self) -> StateGraph:
-        """创建LangGraph工作流"""
-        workflow = StateGraph(AgentState)
-        
-        # 添加节点
-        workflow.add_node("classify_question", self._classify_question_node)
-        workflow.add_node("agent_chat", self._agent_chat_node)
-        workflow.add_node("agent_database", self._agent_database_node)
-        workflow.add_node("format_response", self._format_response_node)
-        
-        # 设置入口点
-        workflow.set_entry_point("classify_question")
+        """根据路由模式创建不同的工作流"""
+        try:
+            from app_config import QUESTION_ROUTING_MODE
+            print(f"[CITU_AGENT] 创建工作流,路由模式: {QUESTION_ROUTING_MODE}")
+        except ImportError:
+            QUESTION_ROUTING_MODE = "hybrid"
+            print(f"[CITU_AGENT] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
         
-        # 添加条件边:分类后的路由
-        # 完全信任QuestionClassifier的决策,不再进行二次判断
-        workflow.add_conditional_edges(
-            "classify_question",
-            self._route_after_classification,
-            {
-                "DATABASE": "agent_database",
-                "CHAT": "agent_chat"  # CHAT分支处理所有非DATABASE的情况(包括UNCERTAIN)
-            }
-        )
+        workflow = StateGraph(AgentState)
         
-        # 添加边
-        workflow.add_edge("agent_chat", "format_response")
-        workflow.add_edge("agent_database", "format_response")
-        workflow.add_edge("format_response", END)
+        # 根据路由模式创建不同的工作流
+        if QUESTION_ROUTING_MODE == "database_direct":
+            # 直接数据库模式:跳过分类,直接进入数据库处理
+            workflow.add_node("init_direct_database", self._init_direct_database_node)
+            workflow.add_node("agent_database", self._agent_database_node)
+            workflow.add_node("format_response", self._format_response_node)
+            
+            workflow.set_entry_point("init_direct_database")
+            workflow.add_edge("init_direct_database", "agent_database")
+            workflow.add_edge("agent_database", "format_response")
+            workflow.add_edge("format_response", END)
+            
+        elif QUESTION_ROUTING_MODE == "chat_direct":
+            # 直接聊天模式:跳过分类,直接进入聊天处理
+            workflow.add_node("init_direct_chat", self._init_direct_chat_node)
+            workflow.add_node("agent_chat", self._agent_chat_node)
+            workflow.add_node("format_response", self._format_response_node)
+            
+            workflow.set_entry_point("init_direct_chat")
+            workflow.add_edge("init_direct_chat", "agent_chat")
+            workflow.add_edge("agent_chat", "format_response")
+            workflow.add_edge("format_response", END)
+            
+        else:
+            # 其他模式(hybrid、llm_only,以及任何未识别的取值):使用原有的分类工作流
+            workflow.add_node("classify_question", self._classify_question_node)
+            workflow.add_node("agent_chat", self._agent_chat_node)
+            workflow.add_node("agent_database", self._agent_database_node)
+            workflow.add_node("format_response", self._format_response_node)
+            
+            workflow.set_entry_point("classify_question")
+            
+            # 添加条件边:分类后的路由
+            workflow.add_conditional_edges(
+                "classify_question",
+                self._route_after_classification,
+                {
+                    "DATABASE": "agent_database",
+                    "CHAT": "agent_chat"
+                }
+            )
+            
+            workflow.add_edge("agent_chat", "format_response")
+            workflow.add_edge("agent_database", "format_response")
+            workflow.add_edge("format_response", END)
         
         return workflow.compile()
     
+    def _init_direct_database_node(self, state: AgentState) -> AgentState:
+        """初始化直接数据库模式的状态"""
+        try:
+            from app_config import QUESTION_ROUTING_MODE
+            
+            # 设置直接数据库模式的分类状态
+            state["question_type"] = "DATABASE"
+            state["classification_confidence"] = 1.0
+            state["classification_reason"] = "配置为直接数据库查询模式"
+            state["classification_method"] = "direct_database"
+            state["routing_mode"] = QUESTION_ROUTING_MODE
+            state["current_step"] = "direct_database_init"
+            state["execution_path"].append("init_direct_database")
+            
+            print(f"[DIRECT_DATABASE] 直接数据库模式初始化完成")
+            
+            return state
+            
+        except Exception as e:
+            print(f"[ERROR] 直接数据库模式初始化异常: {str(e)}")
+            state["error"] = f"直接数据库模式初始化失败: {str(e)}"
+            state["error_code"] = 500
+            state["execution_path"].append("init_direct_database_error")
+            return state
+
+    def _init_direct_chat_node(self, state: AgentState) -> AgentState:
+        """初始化直接聊天模式的状态"""
+        try:
+            from app_config import QUESTION_ROUTING_MODE
+            
+            # 设置直接聊天模式的分类状态
+            state["question_type"] = "CHAT"
+            state["classification_confidence"] = 1.0
+            state["classification_reason"] = "配置为直接聊天模式"
+            state["classification_method"] = "direct_chat"
+            state["routing_mode"] = QUESTION_ROUTING_MODE
+            state["current_step"] = "direct_chat_init"
+            state["execution_path"].append("init_direct_chat")
+            
+            print(f"[DIRECT_CHAT] 直接聊天模式初始化完成")
+            
+            return state
+            
+        except Exception as e:
+            print(f"[ERROR] 直接聊天模式初始化异常: {str(e)}")
+            state["error"] = f"直接聊天模式初始化失败: {str(e)}"
+            state["error_code"] = 500
+            state["execution_path"].append("init_direct_chat_error")
+            return state
+    
     def _classify_question_node(self, state: AgentState) -> AgentState:
-        """问题分类节点"""
+        """问题分类节点 - 支持路由模式"""
         try:
+            from app_config import QUESTION_ROUTING_MODE
+            
             print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
             
             classification_result = self.classifier.classify(state["question"])
@@ -77,10 +157,12 @@ class CituLangGraphAgent:
             state["classification_confidence"] = classification_result.confidence
             state["classification_reason"] = classification_result.reason
             state["classification_method"] = classification_result.method
+            state["routing_mode"] = QUESTION_ROUTING_MODE
             state["current_step"] = "classified"
             state["execution_path"].append("classify")
             
             print(f"[CLASSIFY_NODE] 分类结果: {classification_result.question_type}, 置信度: {classification_result.confidence}")
+            print(f"[CLASSIFY_NODE] 路由模式: {QUESTION_ROUTING_MODE}, 分类方法: {classification_result.method}")
             
             return state
             
@@ -90,7 +172,7 @@ class CituLangGraphAgent:
             state["error_code"] = 500
             state["execution_path"].append("classify_error")
             return state
-    
+        
     def _agent_database_node(self, state: AgentState) -> AgentState:
         """数据库Agent节点 - 直接工具调用模式"""
         try:
@@ -407,14 +489,19 @@ class CituLangGraphAgent:
             }
     
     def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
-        """创建初始状态"""
+        """创建初始状态 - 支持路由模式"""
+        try:
+            from app_config import QUESTION_ROUTING_MODE
+        except ImportError:
+            QUESTION_ROUTING_MODE = "hybrid"
+        
         return AgentState(
             # 输入信息
             question=question,
             session_id=session_id,
             
-            # 分类结果
-            question_type="",
+            # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)
+            question_type="UNCERTAIN",
             classification_confidence=0.0,
             classification_reason="",
             classification_method="",
@@ -436,13 +523,16 @@ class CituLangGraphAgent:
             error_code=None,
             
             # 流程控制
-            current_step="start",
-            execution_path=[],
+            current_step="initialized",
+            execution_path=["start"],
             retry_count=0,
-            max_retries=2,
+            max_retries=3,
             
             # 调试信息
-            debug_info={}
+            debug_info={},
+            
+            # 路由模式
+            routing_mode=QUESTION_ROUTING_MODE
         )
     
     def health_check(self) -> Dict[str, Any]:

+ 38 - 1
agent/classifier.py

@@ -114,9 +114,46 @@ class QuestionClassifier:
             "教程", "指南", "手册"
         ]
     
+    # 分类入口:根据 QUESTION_ROUTING_MODE 选择分类策略
+
     def classify(self, question: str) -> ClassificationResult:
         """
-        主分类方法:规则预筛选 + 增强LLM分类
+        主分类方法:根据配置的路由模式进行分类
+        """
+        try:
+            from app_config import QUESTION_ROUTING_MODE
+            print(f"[CLASSIFIER] 使用路由模式: {QUESTION_ROUTING_MODE}")
+        except ImportError:
+            QUESTION_ROUTING_MODE = "hybrid"
+            print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
+        
+        # 根据路由模式选择分类策略
+        if QUESTION_ROUTING_MODE == "hybrid":
+            return self._hybrid_classify(question)
+        elif QUESTION_ROUTING_MODE == "llm_only":
+            return self._enhanced_llm_classify(question)
+        elif QUESTION_ROUTING_MODE == "database_direct":
+            return ClassificationResult(
+                question_type="DATABASE",
+                confidence=1.0,
+                reason="配置为直接数据库查询模式",
+                method="direct_database"
+            )
+        elif QUESTION_ROUTING_MODE == "chat_direct":
+            return ClassificationResult(
+                question_type="CHAT",
+                confidence=1.0,
+                reason="配置为直接聊天模式",
+                method="direct_chat"
+            )
+        else:
+            print(f"[WARNING] 未知的路由模式: {QUESTION_ROUTING_MODE},使用默认hybrid模式")
+            return self._hybrid_classify(question)
+
+    def _hybrid_classify(self, question: str) -> ClassificationResult:
+        """
+        混合分类模式:规则预筛选 + 增强LLM分类
+        这是原来的 classify 方法逻辑
         """
         # 第一步:规则预筛选
         rule_result = self._rule_based_classify(question)

+ 7 - 3
agent/state.py

@@ -1,4 +1,5 @@
-# agent/state.py
+# agent/state.py:AgentState 状态定义
+
 from typing import TypedDict, Literal, Optional, List, Dict, Any
 
 class AgentState(TypedDict):
@@ -12,7 +13,7 @@ class AgentState(TypedDict):
     question_type: Literal["DATABASE", "CHAT", "UNCERTAIN"]
     classification_confidence: float
     classification_reason: str
-    classification_method: str  # "rule", "llm", "hybrid"
+    classification_method: str  # "rule_based_*", "enhanced_llm", "direct_database", "direct_chat", etc.
     
     # 数据库查询流程状态
     sql: Optional[str]
@@ -37,4 +38,7 @@ class AgentState(TypedDict):
     max_retries: int
     
     # 调试信息
-    debug_info: Dict[str, Any]
+    debug_info: Dict[str, Any]
+    
+    # 路由模式相关
+    routing_mode: Optional[str]  # 记录使用的路由模式

+ 6 - 1
app_config.py

@@ -157,4 +157,9 @@ ENABLE_ERROR_SQL_PROMPT = True
 RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD = 0.8
 
 # 接口返回查询记录的最大行数
-API_MAX_RETURN_ROWS = 1000
+API_MAX_RETURN_ROWS = 1000
+
+
+# 问题路由模式,可选值:"hybrid"(混合模式:规则预筛选 + LLM分类)、"llm_only"(仅LLM分类)、"database_direct"(跳过分类,直接数据库查询)、"chat_direct"(跳过分类,直接聊天对话)
+# 推荐使用混合模式 "hybrid";未识别的取值在分类时会回退为 hybrid
+QUESTION_ROUTING_MODE = "hybrid"