瀏覽代碼

修复了多个bug,现在准备修改动态修改workflow.

wangxq 1 周之前
父節點
當前提交
a71cd78844
共有 4 個文件被更改,包括 177 次插入21 次删除
  1. 18 6
      agent/citu_agent.py
  2. 137 13
      agent/classifier.py
  3. 3 0
      agent/state.py
  4. 19 2
      citu_app.py

+ 18 - 6
agent/citu_agent.py

@@ -144,13 +144,19 @@ class CituLangGraphAgent:
             return state
     
     def _classify_question_node(self, state: AgentState) -> AgentState:
-        """问题分类节点 - 支持路由模式"""
+        """问题分类节点 - 支持渐进式分类策略"""
         try:
             from app_config import QUESTION_ROUTING_MODE
             
             print(f"[CLASSIFY_NODE] 开始分类问题: {state['question']}")
             
-            classification_result = self.classifier.classify(state["question"])
+            # 获取上下文类型(如果有的话)
+            context_type = state.get("context_type")
+            if context_type:
+                print(f"[CLASSIFY_NODE] 检测到上下文类型: {context_type}")
+            
+            # 使用渐进式分类策略
+            classification_result = self.classifier.classify(state["question"], context_type)
             
             # 更新状态
             state["question_type"] = classification_result.question_type
@@ -456,22 +462,25 @@ class CituLangGraphAgent:
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             return "CHAT"
     
-    def process_question(self, question: str, session_id: str = None) -> Dict[str, Any]:
+    def process_question(self, question: str, session_id: str = None, context_type: str = None) -> Dict[str, Any]:
         """
         统一的问题处理入口
         
         Args:
             question: 用户问题
             session_id: 会话ID
+            context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
             
         Returns:
             Dict包含完整的处理结果
         """
         try:
             print(f"[CITU_AGENT] 开始处理问题: {question}")
+            if context_type:
+                print(f"[CITU_AGENT] 上下文类型: {context_type}")
             
             # 初始化状态
-            initial_state = self._create_initial_state(question, session_id)
+            initial_state = self._create_initial_state(question, session_id, context_type)
             
             # 执行工作流
             final_state = self.workflow.invoke(
@@ -497,8 +506,8 @@ class CituLangGraphAgent:
                 "execution_path": ["error"]
             }
     
-    def _create_initial_state(self, question: str, session_id: str = None) -> AgentState:
-        """创建初始状态 - 支持路由模式"""
+    def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None) -> AgentState:
+        """创建初始状态 - 支持渐进式分类"""
         try:
             from app_config import QUESTION_ROUTING_MODE
         except ImportError:
@@ -509,6 +518,9 @@ class CituLangGraphAgent:
             question=question,
             session_id=session_id,
             
+            # 上下文信息
+            context_type=context_type,
+            
             # 分类结果 (初始值,会在分类节点或直接模式初始化节点中更新)
             question_type="UNCERTAIN",
             classification_confidence=0.0,

+ 137 - 13
agent/classifier.py

@@ -1,6 +1,6 @@
 # agent/classifier.py
 import re
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional
 from dataclasses import dataclass
 
 @dataclass
@@ -44,7 +44,7 @@ class QuestionClassifier:
                 "服务区", "档口", "商铺", "收费站", "高速公路",
                 "驿美", "驿购",  # 业务系统名称
                 "北区", "南区", "西区", "东区", "两区",  # 物理分区
-                "停车区"
+                "停车区", "公司", "管理公司", "运营公司", "驿美运营公司"  # 公司相关
             ],
             "支付业务": [
                 "微信支付", "支付宝支付", "现金支付", "行吧支付", "金豆支付",
@@ -64,6 +64,12 @@ class QuestionClassifier:
             "地理路线": [
                 "大广", "昌金", "昌栗", "线路", "路段", "路线",
                 "高速线路", "公路线路"
+            ],
+            "系统查询指示词": [
+                "当前系统", "当前数据库", "当前数据",
+                "本系统", "系统中", "数据库中", "数据中",
+                "现有数据", "已有数据", "存储的数据",
+                "平台数据", "我们的数据库", "这个系统"
             ]
         }
         
@@ -113,12 +119,27 @@ class QuestionClassifier:
             "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能",
             "教程", "指南", "手册"
         ]
-    
-# 修改 agent/classifier.py 中的 classify 方法
+        
+        # 追问关键词(用于检测追问型问题)
+        self.follow_up_keywords = [
+            "还有", "详细", "具体", "更多", "继续", "再", "也",
+            "那么", "另外", "其他", "以及", "还", "进一步",
+            "深入", "补充", "额外", "此外", "同时", "并且"
+        ]
+        
+        # 话题切换关键词(明显的话题转换)
+        self.topic_switch_keywords = [
+            "你好", "你是", "介绍", "功能", "帮助", "使用方法",
+            "平台", "系统", "AI", "助手", "谢谢", "再见"
+        ]
 
-    def classify(self, question: str) -> ClassificationResult:
+    def classify(self, question: str, context_type: Optional[str] = None) -> ClassificationResult:
         """
-        主分类方法:根据配置的路由模式进行分类
+        主分类方法:支持渐进式分类策略
+        
+        Args:
+            question: 当前问题
+            context_type: 上下文类型 ("DATABASE" 或 "CHAT"),可选
         """
         try:
             from app_config import QUESTION_ROUTING_MODE
@@ -128,11 +149,7 @@ class QuestionClassifier:
             print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
         
         # 根据路由模式选择分类策略
-        if QUESTION_ROUTING_MODE == "hybrid":
-            return self._hybrid_classify(question)
-        elif QUESTION_ROUTING_MODE == "llm_only":
-            return self._enhanced_llm_classify(question)
-        elif QUESTION_ROUTING_MODE == "database_direct":
+        if QUESTION_ROUTING_MODE == "database_direct":
             return ClassificationResult(
                 question_type="DATABASE",
                 confidence=1.0,
@@ -146,9 +163,96 @@ class QuestionClassifier:
                 reason="配置为直接聊天模式",
                 method="direct_chat"
             )
+        elif QUESTION_ROUTING_MODE == "llm_only":
+            return self._enhanced_llm_classify(question)
         else:
-            print(f"[WARNING] 未知的路由模式: {QUESTION_ROUTING_MODE},使用默认hybrid模式")
-            return self._hybrid_classify(question)
+            # hybrid模式:使用渐进式分类策略
+            return self._progressive_classify(question, context_type)
+
+    def _progressive_classify(self, question: str, context_type: Optional[str] = None) -> ClassificationResult:
+        """
+        渐进式分类策略:
+        1. 首先只基于问题本身分类
+        2. 如果置信度不够且有上下文,考虑上下文辅助
+        3. 检测话题切换,避免错误继承
+        """
+        print(f"[CLASSIFIER] 渐进式分类 - 问题: {question}")
+        if context_type:
+            print(f"[CLASSIFIER] 上下文类型: {context_type}")
+        
+        # 第一步:只基于问题本身分类
+        primary_result = self._hybrid_classify(question)
+        print(f"[CLASSIFIER] 主分类结果: {primary_result.question_type}, 置信度: {primary_result.confidence}")
+        
+        # 如果没有上下文,直接返回主分类结果
+        if not context_type:
+            print(f"[CLASSIFIER] 无上下文,使用主分类结果")
+            return primary_result
+        
+        # 如果置信度足够高,直接使用主分类结果
+        if primary_result.confidence >= self.high_confidence_threshold:
+            print(f"[CLASSIFIER] 高置信度({primary_result.confidence}≥{self.high_confidence_threshold}),使用主分类结果")
+            return primary_result
+        
+        # 检测明显的话题切换
+        if self._is_topic_switch(question):
+            print(f"[CLASSIFIER] 检测到话题切换,忽略上下文")
+            return primary_result
+        
+        # 如果置信度较低,考虑上下文辅助
+        if primary_result.confidence < self.medium_confidence_threshold:
+            print(f"[CLASSIFIER] 低置信度({primary_result.confidence}<{self.medium_confidence_threshold}),考虑上下文辅助")
+            
+            # 检测是否为追问型问题
+            if self._is_follow_up_question(question):
+                print(f"[CLASSIFIER] 检测到追问型问题,继承上下文类型: {context_type}")
+                return ClassificationResult(
+                    question_type=context_type,
+                    confidence=0.75,  # 给予中等置信度
+                    reason=f"追问型问题,继承上下文类型。原分类: {primary_result.reason}",
+                    method="progressive_context_inherit"
+                )
+        
+        # 中等置信度或其他情况,保持主分类结果
+        print(f"[CLASSIFIER] 保持主分类结果")
+        return primary_result
+
+    def _is_follow_up_question(self, question: str) -> bool:
+        """检测是否为追问型问题"""
+        question_lower = question.lower()
+        
+        # 检查追问关键词
+        for keyword in self.follow_up_keywords:
+            if keyword in question_lower:
+                return True
+        
+        # 检查问号开头的短问题(通常是追问)
+        if question.strip().startswith(('还', '再', '那', '这', '有')) and len(question.strip()) < 15:
+            return True
+        
+        return False
+
+    def _is_topic_switch(self, question: str) -> bool:
+        """检测是否为明显的话题切换"""
+        question_lower = question.lower()
+        
+        # 检查话题切换关键词
+        for keyword in self.topic_switch_keywords:
+            if keyword in question_lower:
+                return True
+        
+        # 检查问候语模式
+        greeting_patterns = [
+            r"^(你好|您好|hi|hello)",
+            r"(你是|您是).*(什么|谁|哪)",
+            r"(介绍|说明).*(功能|平台|系统)"
+        ]
+        
+        for pattern in greeting_patterns:
+            if re.search(pattern, question_lower):
+                return True
+        
+        return False
 
     def _hybrid_classify(self, question: str) -> ClassificationResult:
         """
@@ -195,11 +299,21 @@ class QuestionClassifier:
         business_matched = []
         
         for category, keywords in self.strong_business_keywords.items():
+            if category == "系统查询指示词":  # 系统指示词单独处理
+                continue
             for keyword in keywords:
                 if keyword in question_lower:
                     business_score += 2  # 业务实体词权重更高
                     business_matched.append(f"{category}:{keyword}")
         
+        # 检查系统查询指示词
+        system_indicator_score = 0
+        system_matched = []
+        for keyword in self.strong_business_keywords.get("系统查询指示词", []):
+            if keyword in question_lower:
+                system_indicator_score += 1
+                system_matched.append(f"系统查询指示词:{keyword}")
+        
         # 检查查询意图词
         intent_score = 0
         intent_matched = []
@@ -223,6 +337,16 @@ class QuestionClassifier:
                 chat_score += 1
                 chat_matched.append(keyword)
         
+        # 系统指示词组合评分逻辑
+        if system_indicator_score > 0 and business_score > 0:
+            # 系统指示词 + 业务实体 = 强组合效应
+            business_score += 3  # 组合加分
+            business_matched.extend(system_matched)
+        elif system_indicator_score > 0:
+            # 仅有系统指示词 = 中等业务倾向
+            business_score += 1
+            business_matched.extend(system_matched)
+        
         # 分类决策逻辑
         total_business_score = business_score + intent_score
         

+ 3 - 0
agent/state.py

@@ -9,6 +9,9 @@ class AgentState(TypedDict):
     question: str
     session_id: Optional[str]
     
+    # 上下文信息
+    context_type: Optional[str]  # 上下文类型 ("DATABASE" 或 "CHAT")
+    
     # 分类结果
     question_type: Literal["DATABASE", "CHAT", "UNCERTAIN"]
     classification_confidence: float

+ 19 - 2
citu_app.py

@@ -463,9 +463,25 @@ def ask_agent():
             user_id, conversation_id_input, continue_conversation
         )
         
-        # 3. 获取上下文(提前到缓存检查之前)
+        # 3. 获取上下文和上下文类型(提前到缓存检查之前)
         context = redis_conversation_manager.get_context(conversation_id)
         
+        # 获取上下文类型:从最后一条助手消息的metadata中获取类型
+        context_type = None
+        if context:
+            try:
+                # 获取最后一条助手消息的metadata
+                messages = redis_conversation_manager.get_messages(conversation_id, limit=10)
+                for message in reversed(messages):  # 从最新的开始找
+                    if message.get("role") == "assistant":
+                        metadata = message.get("metadata", {})
+                        context_type = metadata.get("type")
+                        if context_type:
+                            print(f"[AGENT_API] 检测到上下文类型: {context_type}")
+                            break
+            except Exception as e:
+                print(f"[WARNING] 获取上下文类型失败: {str(e)}")
+        
         # 4. 检查缓存(新逻辑:放宽使用条件,严控存储条件)
         cached_answer = redis_conversation_manager.get_cached_answer(question, context)
         if cached_answer:
@@ -550,7 +566,8 @@ def ask_agent():
         
         agent_result = agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
-            session_id=browser_session_id
+            session_id=browser_session_id,
+            context_type=context_type  # 传递上下文类型
         )
         
         # 8. 处理Agent结果