|
@@ -1,6 +1,6 @@
|
|
|
# agent/classifier.py
|
|
|
import re
|
|
|
-from typing import Dict, Any, List
|
|
|
+from typing import Dict, Any, List, Optional
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
|
|
@@ -44,7 +44,7 @@ class QuestionClassifier:
|
|
|
"服务区", "档口", "商铺", "收费站", "高速公路",
|
|
|
"驿美", "驿购", # 业务系统名称
|
|
|
"北区", "南区", "西区", "东区", "两区", # 物理分区
|
|
|
- "停车区"
|
|
|
+ "停车区", "公司", "管理公司", "运营公司", "驿美运营公司" # 公司相关
|
|
|
],
|
|
|
"支付业务": [
|
|
|
"微信支付", "支付宝支付", "现金支付", "行吧支付", "金豆支付",
|
|
@@ -64,6 +64,12 @@ class QuestionClassifier:
|
|
|
"地理路线": [
|
|
|
"大广", "昌金", "昌栗", "线路", "路段", "路线",
|
|
|
"高速线路", "公路线路"
|
|
|
+ ],
|
|
|
+ "系统查询指示词": [
|
|
|
+ "当前系统", "当前数据库", "当前数据",
|
|
|
+ "本系统", "系统中", "数据库中", "数据中",
|
|
|
+ "现有数据", "已有数据", "存储的数据",
|
|
|
+ "平台数据", "我们的数据库", "这个系统"
|
|
|
]
|
|
|
}
|
|
|
|
|
@@ -113,12 +119,27 @@ class QuestionClassifier:
|
|
|
"介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能",
|
|
|
"教程", "指南", "手册"
|
|
|
]
|
|
|
-
|
|
|
-# 修改 agent/classifier.py 中的 classify 方法
|
|
|
+
|
|
|
+ # 追问关键词(用于检测追问型问题)
|
|
|
+ self.follow_up_keywords = [
|
|
|
+ "还有", "详细", "具体", "更多", "继续", "再", "也",
|
|
|
+ "那么", "另外", "其他", "以及", "还", "进一步",
|
|
|
+ "深入", "补充", "额外", "此外", "同时", "并且"
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 话题切换关键词(明显的话题转换)
|
|
|
+ self.topic_switch_keywords = [
|
|
|
+ "你好", "你是", "介绍", "功能", "帮助", "使用方法",
|
|
|
+ "平台", "系统", "AI", "助手", "谢谢", "再见"
|
|
|
+ ]
|
|
|
|
|
|
- def classify(self, question: str) -> ClassificationResult:
|
|
|
+ def classify(self, question: str, context_type: Optional[str] = None) -> ClassificationResult:
|
|
|
"""
|
|
|
- 主分类方法:根据配置的路由模式进行分类
|
|
|
+ 主分类方法:支持渐进式分类策略
|
|
|
+
|
|
|
+ Args:
|
|
|
+ question: 当前问题
|
|
|
+ context_type: 上下文类型 ("DATABASE" 或 "CHAT"),可选
|
|
|
"""
|
|
|
try:
|
|
|
from app_config import QUESTION_ROUTING_MODE
|
|
@@ -128,11 +149,7 @@ class QuestionClassifier:
|
|
|
print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
|
|
|
|
|
|
# 根据路由模式选择分类策略
|
|
|
- if QUESTION_ROUTING_MODE == "hybrid":
|
|
|
- return self._hybrid_classify(question)
|
|
|
- elif QUESTION_ROUTING_MODE == "llm_only":
|
|
|
- return self._enhanced_llm_classify(question)
|
|
|
- elif QUESTION_ROUTING_MODE == "database_direct":
|
|
|
+ if QUESTION_ROUTING_MODE == "database_direct":
|
|
|
return ClassificationResult(
|
|
|
question_type="DATABASE",
|
|
|
confidence=1.0,
|
|
@@ -146,9 +163,96 @@ class QuestionClassifier:
|
|
|
reason="配置为直接聊天模式",
|
|
|
method="direct_chat"
|
|
|
)
|
|
|
+ elif QUESTION_ROUTING_MODE == "llm_only":
|
|
|
+ return self._enhanced_llm_classify(question)
|
|
|
else:
|
|
|
- print(f"[WARNING] 未知的路由模式: {QUESTION_ROUTING_MODE},使用默认hybrid模式")
|
|
|
- return self._hybrid_classify(question)
|
|
|
+ # hybrid模式:使用渐进式分类策略
|
|
|
+ return self._progressive_classify(question, context_type)
|
|
|
+
|
|
|
+ def _progressive_classify(self, question: str, context_type: Optional[str] = None) -> ClassificationResult:
|
|
|
+ """
|
|
|
+ 渐进式分类策略:
|
|
|
+ 1. 首先只基于问题本身分类
|
|
|
+ 2. 如果置信度不够且有上下文,考虑上下文辅助
|
|
|
+ 3. 检测话题切换,避免错误继承
|
|
|
+ """
|
|
|
+ print(f"[CLASSIFIER] 渐进式分类 - 问题: {question}")
|
|
|
+ if context_type:
|
|
|
+ print(f"[CLASSIFIER] 上下文类型: {context_type}")
|
|
|
+
|
|
|
+ # 第一步:只基于问题本身分类
|
|
|
+ primary_result = self._hybrid_classify(question)
|
|
|
+ print(f"[CLASSIFIER] 主分类结果: {primary_result.question_type}, 置信度: {primary_result.confidence}")
|
|
|
+
|
|
|
+ # 如果没有上下文,直接返回主分类结果
|
|
|
+ if not context_type:
|
|
|
+ print(f"[CLASSIFIER] 无上下文,使用主分类结果")
|
|
|
+ return primary_result
|
|
|
+
|
|
|
+ # 如果置信度足够高,直接使用主分类结果
|
|
|
+ if primary_result.confidence >= self.high_confidence_threshold:
|
|
|
+ print(f"[CLASSIFIER] 高置信度({primary_result.confidence}≥{self.high_confidence_threshold}),使用主分类结果")
|
|
|
+ return primary_result
|
|
|
+
|
|
|
+ # 检测明显的话题切换
|
|
|
+ if self._is_topic_switch(question):
|
|
|
+ print(f"[CLASSIFIER] 检测到话题切换,忽略上下文")
|
|
|
+ return primary_result
|
|
|
+
|
|
|
+ # 如果置信度较低,考虑上下文辅助
|
|
|
+ if primary_result.confidence < self.medium_confidence_threshold:
|
|
|
+ print(f"[CLASSIFIER] 低置信度({primary_result.confidence}<{self.medium_confidence_threshold}),考虑上下文辅助")
|
|
|
+
|
|
|
+ # 检测是否为追问型问题
|
|
|
+ if self._is_follow_up_question(question):
|
|
|
+ print(f"[CLASSIFIER] 检测到追问型问题,继承上下文类型: {context_type}")
|
|
|
+ return ClassificationResult(
|
|
|
+ question_type=context_type,
|
|
|
+ confidence=0.75, # 给予中等置信度
|
|
|
+ reason=f"追问型问题,继承上下文类型。原分类: {primary_result.reason}",
|
|
|
+ method="progressive_context_inherit"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 中等置信度或其他情况,保持主分类结果
|
|
|
+ print(f"[CLASSIFIER] 保持主分类结果")
|
|
|
+ return primary_result
|
|
|
+
|
|
|
+ def _is_follow_up_question(self, question: str) -> bool:
|
|
|
+ """检测是否为追问型问题"""
|
|
|
+ question_lower = question.lower()
|
|
|
+
|
|
|
+ # 检查追问关键词
|
|
|
+ for keyword in self.follow_up_keywords:
|
|
|
+ if keyword in question_lower:
|
|
|
+ return True
|
|
|
+
|
|
|
+ # 检查问号开头的短问题(通常是追问)
|
|
|
+ if question.strip().startswith(('还', '再', '那', '这', '有')) and len(question.strip()) < 15:
|
|
|
+ return True
|
|
|
+
|
|
|
+ return False
|
|
|
+
|
|
|
+ def _is_topic_switch(self, question: str) -> bool:
|
|
|
+ """检测是否为明显的话题切换"""
|
|
|
+ question_lower = question.lower()
|
|
|
+
|
|
|
+ # 检查话题切换关键词
|
|
|
+ for keyword in self.topic_switch_keywords:
|
|
|
+ if keyword in question_lower:
|
|
|
+ return True
|
|
|
+
|
|
|
+ # 检查问候语模式
|
|
|
+ greeting_patterns = [
|
|
|
+ r"^(你好|您好|hi|hello)",
|
|
|
+ r"(你是|您是).*(什么|谁|哪)",
|
|
|
+ r"(介绍|说明).*(功能|平台|系统)"
|
|
|
+ ]
|
|
|
+
|
|
|
+ for pattern in greeting_patterns:
|
|
|
+ if re.search(pattern, question_lower):
|
|
|
+ return True
|
|
|
+
|
|
|
+ return False
|
|
|
|
|
|
def _hybrid_classify(self, question: str) -> ClassificationResult:
|
|
|
"""
|
|
@@ -195,11 +299,21 @@ class QuestionClassifier:
|
|
|
business_matched = []
|
|
|
|
|
|
for category, keywords in self.strong_business_keywords.items():
|
|
|
+ if category == "系统查询指示词": # 系统指示词单独处理
|
|
|
+ continue
|
|
|
for keyword in keywords:
|
|
|
if keyword in question_lower:
|
|
|
business_score += 2 # 业务实体词权重更高
|
|
|
business_matched.append(f"{category}:{keyword}")
|
|
|
|
|
|
+ # 检查系统查询指示词
|
|
|
+ system_indicator_score = 0
|
|
|
+ system_matched = []
|
|
|
+ for keyword in self.strong_business_keywords.get("系统查询指示词", []):
|
|
|
+ if keyword in question_lower:
|
|
|
+ system_indicator_score += 1
|
|
|
+ system_matched.append(f"系统查询指示词:{keyword}")
|
|
|
+
|
|
|
# 检查查询意图词
|
|
|
intent_score = 0
|
|
|
intent_matched = []
|
|
@@ -223,6 +337,16 @@ class QuestionClassifier:
|
|
|
chat_score += 1
|
|
|
chat_matched.append(keyword)
|
|
|
|
|
|
+ # 系统指示词组合评分逻辑
|
|
|
+ if system_indicator_score > 0 and business_score > 0:
|
|
|
+ # 系统指示词 + 业务实体 = 强组合效应
|
|
|
+ business_score += 3 # 组合加分
|
|
|
+ business_matched.extend(system_matched)
|
|
|
+ elif system_indicator_score > 0:
|
|
|
+ # 仅有系统指示词 = 中等业务倾向
|
|
|
+ business_score += 1
|
|
|
+ business_matched.extend(system_matched)
|
|
|
+
|
|
|
# 分类决策逻辑
|
|
|
total_business_score = business_score + intent_score
|
|
|
|