# agent/classifier.py
import re
from typing import Dict, Any, List, Optional
from dataclasses import dataclass


@dataclass
class ClassificationResult:
    question_type: str
    confidence: float
    reason: str
    method: str

class QuestionClassifier:
    """
    Enhanced question classifier: intent classification driven by the
    highway service area business context.
    """

    def __init__(self):
        # Load threshold parameters from the config file
        try:
            from agent.config import get_current_config, get_nested_config
            config = get_current_config()
            self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.7)
            self.low_confidence_threshold = get_nested_config(config, "classification.low_confidence_threshold", 0.4)
            self.max_confidence = get_nested_config(config, "classification.max_confidence", 0.9)
            self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.4)
            self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.08)
            self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
            self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
            print("[CLASSIFIER] Loaded classifier parameters from the config file")
        except ImportError:
            self.high_confidence_threshold = 0.7
            self.low_confidence_threshold = 0.4
            self.max_confidence = 0.9
            self.base_confidence = 0.4
            self.confidence_increment = 0.08
            self.llm_fallback_confidence = 0.5
            self.uncertain_confidence = 0.2
            print("[CLASSIFIER] Config file unavailable, using default classifier parameters")
        # Cutoff between low and high confidence; referenced by _progressive_classify
        # when deciding whether the conversation context should assist the classification
        self.medium_confidence_threshold = 0.6
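        # Illustrative sketch only (an assumption, not taken from agent/config): the
        # dotted keys used above imply that get_current_config() returns a nested dict
        # shaped roughly like:
        #
        #     {
        #         "classification": {
        #             "high_confidence_threshold": 0.7,
        #             "low_confidence_threshold": 0.4,
        #             "max_confidence": 0.9,
        #             "base_confidence": 0.4,
        #             "confidence_increment": 0.08,
        #             "llm_fallback_confidence": 0.5,
        #             "uncertain_confidence": 0.2,
        #         }
        #     }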
        # Precise keywords grounded in the highway service area business
        self.strong_business_keywords = {
            "核心业务实体": [
                "服务区", "档口", "商铺", "收费站", "高速公路",
                "驿美", "驿购",  # business system names
                "北区", "南区", "西区", "东区", "两区",  # physical zones
                "停车区", "公司", "管理公司", "运营公司", "驿美运营公司"  # company-related
            ],
            "支付业务": [
                "微信支付", "支付宝支付", "现金支付", "行吧支付", "金豆支付",
                "支付金额", "订单数量", "营业额", "收入", "营业收入",
                "微信", "支付宝", "现金", "行吧", "金豆",  # short forms
                "wx", "zfb", "rmb", "xs", "jd"  # system field names
            ],
            "经营品类": [
                "餐饮", "小吃", "便利店", "整体租赁",
                "驿美餐饮", "品牌", "经营品类", "商业品类"
            ],
            "车流业务": [
                "车流量", "车辆数量", "客车", "货车",
                "过境", "危化品", "城际", "车辆统计",
                "流量统计", "车型分布"
            ],
            "地理路线": [
                "大广", "昌金", "昌栗", "线路", "路段", "路线",
                "高速线路", "公路线路"
            ],
            "系统查询指示词": [
                "当前系统", "当前数据库", "当前数据",
                "本系统", "系统中", "数据库中", "数据中",
                "现有数据", "已有数据", "存储的数据",
                "平台数据", "我们的数据库", "这个系统"
            ]
        }
        # Query-intent words (secondary signal)
        self.query_intent_keywords = [
            "统计", "查询", "分析", "排行", "排名",
            "报表", "报告", "汇总", "计算", "对比",
            "趋势", "占比", "百分比", "比例",
            "最大", "最小", "最高", "最低", "平均",
            "总计", "合计", "累计", "求和", "求平均",
            "生成", "导出", "显示", "列出"
        ]
        # Non-business entity words (their presence leans the question toward CHAT).
        # Keywords are matched against question.lower(), so they are kept lowercase.
        self.non_business_keywords = [
            # Produce / food
            "荔枝", "苹果", "西瓜", "水果", "蔬菜", "大米", "小麦",
            "橙子", "香蕉", "葡萄", "草莓", "樱桃", "桃子", "梨",
            # Technical concepts
            "人工智能", "机器学习", "编程", "算法", "深度学习",
            "ai", "神经网络", "模型训练", "数据挖掘",
            # Identity questions
            "你是谁", "你是什么", "你叫什么", "你的名字",
            "什么模型", "大模型", "ai助手", "助手", "机器人",
            # Weather
            "天气", "气温", "下雨", "晴天", "阴天", "温度",
            "天气预报", "气候", "降雨", "雪天",
            # Other everyday topics
            "怎么做饭", "如何减肥", "健康", "医疗", "病症",
            "历史", "地理", "文学", "电影", "音乐", "体育",
            "娱乐", "游戏", "小说", "新闻", "政治"
        ]
        # SQL keywords (database operations at the technical level).
        # Note: \b does not match between CJK characters, so the Chinese terms are
        # matched without word boundaries.
        self.sql_patterns = [
            r"\b(select|from|where|group by|order by|having|join)\b",
            r"(数据库|表名|字段名)|\b(sql)\b"
        ]
        # Chat keywords (platform features and help)
        self.chat_keywords = [
            "你好", "谢谢", "再见", "怎么样", "如何", "为什么", "什么是",
            "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能",
            "教程", "指南", "手册"
        ]
        # Follow-up keywords (used to detect follow-up questions)
        self.follow_up_keywords = [
            "还有", "详细", "具体", "更多", "继续", "再", "也",
            "那么", "另外", "其他", "以及", "还", "进一步",
            "深入", "补充", "额外", "此外", "同时", "并且"
        ]
        # Topic-switch keywords (obvious changes of topic, lowercase to match question.lower())
        self.topic_switch_keywords = [
            "你好", "你是", "介绍", "功能", "帮助", "使用方法",
            "平台", "系统", "ai", "助手", "谢谢", "再见"
        ]

    def classify(self, question: str, context_type: Optional[str] = None, routing_mode: Optional[str] = None) -> ClassificationResult:
        """
        Main classification entry point, supporting a progressive classification strategy.

        Args:
            question: the current question
            context_type: optional context type ("DATABASE" or "CHAT")
            routing_mode: optional routing mode, overriding the config file setting
        """
        # Determine the routing mode to use
        if routing_mode:
            QUESTION_ROUTING_MODE = routing_mode
            print(f"[CLASSIFIER] Using routing mode passed by caller: {QUESTION_ROUTING_MODE}")
        else:
            try:
                from app_config import QUESTION_ROUTING_MODE
                print(f"[CLASSIFIER] Using routing mode from config file: {QUESTION_ROUTING_MODE}")
            except ImportError:
                QUESTION_ROUTING_MODE = "hybrid"
                print(f"[CLASSIFIER] Config import failed, using default routing mode: {QUESTION_ROUTING_MODE}")
        # Select the classification strategy according to the routing mode
        if QUESTION_ROUTING_MODE == "database_direct":
            return ClassificationResult(
                question_type="DATABASE",
                confidence=1.0,
                reason="Configured for direct database query mode",
                method="direct_database"
            )
        elif QUESTION_ROUTING_MODE == "chat_direct":
            return ClassificationResult(
                question_type="CHAT",
                confidence=1.0,
                reason="Configured for direct chat mode",
                method="direct_chat"
            )
        elif QUESTION_ROUTING_MODE == "llm_only":
            return self._enhanced_llm_classify(question)
        else:
            # hybrid mode: use the progressive classification strategy
            return self._progressive_classify(question, context_type)

    def _progressive_classify(self, question: str, context_type: Optional[str] = None) -> ClassificationResult:
        """
        Progressive classification strategy:
        1. Classify based on the question alone first.
        2. If confidence is insufficient and context is available, let the context assist.
        3. Detect topic switches to avoid inheriting the wrong context.
        """
        print(f"[CLASSIFIER] Progressive classification - question: {question}")
        if context_type:
            print(f"[CLASSIFIER] Context type: {context_type}")
        # Step 1: classify based on the question alone
        primary_result = self._hybrid_classify(question)
        print(f"[CLASSIFIER] Primary result: {primary_result.question_type}, confidence: {primary_result.confidence}")
        # Without context, return the primary result directly
        if not context_type:
            print("[CLASSIFIER] No context available, using the primary result")
            return primary_result
        # If confidence is high enough, use the primary result directly
        if primary_result.confidence >= self.high_confidence_threshold:
            print(f"[CLASSIFIER] High confidence ({primary_result.confidence} >= {self.high_confidence_threshold}), using the primary result")
            return primary_result
        # Detect an obvious topic switch
        if self._is_topic_switch(question):
            print("[CLASSIFIER] Topic switch detected, ignoring context")
            return primary_result
        # With low confidence, let the context assist
        if primary_result.confidence < self.medium_confidence_threshold:
            print(f"[CLASSIFIER] Low confidence ({primary_result.confidence} < {self.medium_confidence_threshold}), considering context assistance")
            # Check whether this is a follow-up question
            if self._is_follow_up_question(question):
                print(f"[CLASSIFIER] Follow-up question detected, inheriting context type: {context_type}")
                return ClassificationResult(
                    question_type=context_type,
                    confidence=0.75,  # assign a medium confidence
                    reason=f"Follow-up question, inheriting context type. Original classification: {primary_result.reason}",
                    method="progressive_context_inherit"
                )
        # Medium confidence or any other case: keep the primary result
        print("[CLASSIFIER] Keeping the primary result")
        return primary_result

    def _is_follow_up_question(self, question: str) -> bool:
        """Check whether the question is a follow-up question."""
        question_lower = question.lower()
        # Check follow-up keywords
        for keyword in self.follow_up_keywords:
            if keyword in question_lower:
                return True
        # Short questions starting with these characters are usually follow-ups
        if question.strip().startswith(('还', '再', '那', '这', '有')) and len(question.strip()) < 15:
            return True
        return False

    def _is_topic_switch(self, question: str) -> bool:
        """Check whether the question is an obvious topic switch."""
        question_lower = question.lower()
        # Check topic-switch keywords
        for keyword in self.topic_switch_keywords:
            if keyword in question_lower:
                return True
        # Check greeting patterns
        greeting_patterns = [
            r"^(你好|您好|hi|hello)",
            r"(你是|您是).*(什么|谁|哪)",
            r"(介绍|说明).*(功能|平台|系统)"
        ]
        for pattern in greeting_patterns:
            if re.search(pattern, question_lower):
                return True
        return False

    def _hybrid_classify(self, question: str) -> ClassificationResult:
        """
        Hybrid classification mode: rule-based pre-screening + enhanced LLM classification.
        This is the logic of the original classify method.
        """
        # Step 1: rule-based pre-screening
        rule_result = self._rule_based_classify(question)
        # If the rule-based result has high confidence, use it directly
        if rule_result.confidence >= self.high_confidence_threshold:
            return rule_result
        # Step 2: enhanced LLM classification
        llm_result = self._enhanced_llm_classify(question)
        # Return whichever result has the higher confidence
        if llm_result.confidence > rule_result.confidence:
            return llm_result
        else:
            return rule_result

    def _rule_based_classify(self, question: str) -> ClassificationResult:
        """Rule-based pre-classification."""
        question_lower = question.lower()
        # Check non-business entity words
        non_business_matched = []
        for keyword in self.non_business_keywords:
            if keyword in question_lower:
                non_business_matched.append(keyword)
        # If any non-business entity word is present, classify as CHAT directly
        if non_business_matched:
            return ClassificationResult(
                question_type="CHAT",
                confidence=0.85,
                reason=f"Contains non-business entity words: {non_business_matched}",
                method="rule_based_non_business"
            )
        # Check strong business keywords
        business_score = 0
        business_matched = []
        for category, keywords in self.strong_business_keywords.items():
            if category == "系统查询指示词":  # system-query indicator words are handled separately
                continue
            for keyword in keywords:
                if keyword in question_lower:
                    business_score += 2  # business entity words carry a higher weight
                    business_matched.append(f"{category}:{keyword}")
        # Check system-query indicator words
        system_indicator_score = 0
        system_matched = []
        for keyword in self.strong_business_keywords.get("系统查询指示词", []):
            if keyword in question_lower:
                system_indicator_score += 1
                system_matched.append(f"系统查询指示词:{keyword}")
        # Check query-intent words
        intent_score = 0
        intent_matched = []
        for keyword in self.query_intent_keywords:
            if keyword in question_lower:
                intent_score += 1
                intent_matched.append(keyword)
        # Check SQL patterns
        sql_patterns_matched = []
        for pattern in self.sql_patterns:
            if re.search(pattern, question_lower, re.IGNORECASE):
                business_score += 3  # SQL patterns carry the highest weight
                sql_patterns_matched.append(pattern)
        # Check chat keywords
        chat_score = 0
        chat_matched = []
        for keyword in self.chat_keywords:
            if keyword in question_lower:
                chat_score += 1
                chat_matched.append(keyword)
        # Combined scoring for system-query indicator words
        if system_indicator_score > 0 and business_score > 0:
            # System indicator + business entity = strong combined signal
            business_score += 3  # combination bonus
            business_matched.extend(system_matched)
        elif system_indicator_score > 0:
            # System indicator alone = moderate business leaning
            business_score += 1
            business_matched.extend(system_matched)
        # Classification decision logic
        total_business_score = business_score + intent_score
        # Strong business signal: business entities plus query intent
        if business_score >= 2 and intent_score >= 1:
            confidence = min(self.max_confidence, 0.8 + (total_business_score * 0.05))
            return ClassificationResult(
                question_type="DATABASE",
                confidence=confidence,
                reason=f"Strong business signal - entities: {business_matched}, query intent: {intent_matched}, SQL: {sql_patterns_matched}",
                method="rule_based_strong_business"
            )
        # Medium business signal: several business entity words
        elif business_score >= 4:
            confidence = min(self.max_confidence, 0.7 + (business_score * 0.03))
            return ClassificationResult(
                question_type="DATABASE",
                confidence=confidence,
                reason=f"Medium business signal - entities: {business_matched}",
                method="rule_based_medium_business"
            )
        # Chat signal
        elif chat_score >= 1 and business_score == 0:
            confidence = min(self.max_confidence, self.base_confidence + (chat_score * self.confidence_increment))
            return ClassificationResult(
                question_type="CHAT",
                confidence=confidence,
                reason=f"Chat signal: {chat_matched}",
                method="rule_based_chat"
            )
        # Uncertain
        else:
            return ClassificationResult(
                question_type="UNCERTAIN",
                confidence=self.uncertain_confidence,
                reason=f"Rule-based classification is inconclusive - business score: {business_score}, intent score: {intent_score}, chat score: {chat_score}",
                method="rule_based_uncertain"
            )
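
    # Illustrative walk-through of the rule scoring above (the sample question is an
    # assumption, not taken from the original source): for "统计本月各服务区的营业额",
    # the business entities "服务区" (+2) and "营业额" (+2) give business_score = 4 and
    # the intent word "统计" gives intent_score = 1, so the strong-business branch
    # (business_score >= 2 and intent_score >= 1) fires with
    # confidence = min(0.9, 0.8 + 5 * 0.05) = 0.9 and the question routes to DATABASE.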

    def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
        """Enhanced LLM classification with detailed business context in the prompt."""
        try:
            from common.vanna_instance import get_vanna_instance
            vn = get_vanna_instance()
            # Build a classification prompt that carries the business context
            # (kept in Chinese because it is sent to the LLM alongside Chinese user questions)
            classification_prompt = f"""
请判断以下用户问题是否需要查询我们的高速公路服务区管理数据库。
用户问题:{question}
=== 数据库业务范围 ===
本系统是高速公路服务区商业管理系统,包含以下业务数据:
核心业务实体:
- 服务区(bss_service_area):服务区基础信息、位置、状态,如"鄱阳湖服务区"、"信丰西服务区"
- 档口/商铺(bss_branch):档口信息、品类(餐饮/小吃/便利店)、品牌,如"驿美餐饮"、"加水机"
- 营业数据(bss_business_day_data):每日支付金额、订单数量,包含微信、支付宝、现金等支付方式
- 车流量(bss_car_day_count):按车型统计的日流量数据,包含客车、货车、过境、危化品等
- 公司信息(bss_company):服务区管理公司,如"驿美运营公司"
关键业务指标:
- 支付方式:微信支付(wx)、支付宝支付(zfb)、现金支付(rmb)、行吧支付(xs)、金豆支付(jd)
- 营业数据:支付金额、订单数量、营业额、收入统计
- 车流统计:按车型(客车/货车/过境/危化品/城际)的流量分析
- 经营分析:餐饮、小吃、便利店、整体租赁等品类收入
- 地理分区:北区、南区、西区、东区、两区
高速线路:
- 线路信息:大广、昌金、昌栗等高速线路
- 路段管理:按线路统计服务区分布
=== 判断标准 ===
1. **DATABASE类型** - 需要查询数据库:
- 涉及上述业务实体和指标的查询、统计、分析、报表
- 包含业务相关的时间查询,如"本月服务区营业额"、"上月档口收入"
- 例如:"本月营业额统计"、"档口收入排行"、"车流量分析"、"支付方式占比"
2. **CHAT类型** - 不需要查询数据库:
- 生活常识:水果蔬菜上市时间、动植物知识、天气等
- 身份询问:你是谁、什么模型、AI助手等
- 技术概念:人工智能、编程、算法等
- 平台使用:功能介绍、操作帮助、使用教程等
- 例如:"荔枝几月份上市"、"今天天气如何"、"你是什么AI"、"怎么使用平台"
**重要提示:**
- 只有涉及高速公路服务区业务数据的问题才分类为DATABASE
- 即使包含时间词汇(如"月份"、"时间"),也要看是否与我们的业务数据相关
- 农产品上市时间、生活常识等都应分类为CHAT
请基于问题与我们高速公路服务区业务数据的相关性来分类。
格式:
分类: [DATABASE/CHAT]
理由: [详细说明问题与业务数据的相关性,具体分析涉及哪些业务实体或为什么不相关]
置信度: [0.0-1.0之间的数字]
"""
            # Domain-specific system prompt (also kept in Chinese)
            system_prompt = """你是一个专业的业务问题分类助手,专门负责高速公路服务区管理系统的问题分类。你具有以下特长:
1. 深度理解高速公路服务区业务领域和数据范围
2. 准确区分业务数据查询需求和一般性问题
3. 基于具体业务上下文进行精准分类,而不仅仅依赖关键词匹配
4. 对边界情况能够给出合理的置信度评估
请严格按照业务相关性进行分类,并提供详细的分类理由。"""
            # Call the LLM through the Vanna instance's chat_with_llm method
            response = vn.chat_with_llm(
                question=classification_prompt,
                system_prompt=system_prompt
            )
            # Parse the response
            return self._parse_llm_response(response)
        except Exception as e:
            print(f"[WARNING] Enhanced LLM classification failed: {str(e)}")
            return ClassificationResult(
                question_type="CHAT",  # default to CHAT on failure, the safer option
                confidence=self.llm_fallback_confidence,
                reason=f"LLM classification error, defaulting to chat: {str(e)}",
                method="llm_error"
            )

    def _parse_llm_response(self, response: str) -> ClassificationResult:
        """Parse the LLM response."""
        try:
            lines = response.strip().split('\n')
            question_type = "CHAT"  # default to CHAT
            reason = "Failed to parse the LLM response"
            confidence = self.llm_fallback_confidence
            for line in lines:
                line = line.strip()
                # Accept both the ASCII ":" and the full-width ":" after each field label
                if line.startswith(("分类:", "分类:", "Classification:")):
                    type_part = re.split(r"[::]", line, maxsplit=1)[1].strip().upper()
                    if "DATABASE" in type_part:
                        question_type = "DATABASE"
                    elif "CHAT" in type_part:
                        question_type = "CHAT"
                elif line.startswith(("理由:", "理由:", "Reason:")):
                    reason = re.split(r"[::]", line, maxsplit=1)[1].strip()
                elif line.startswith(("置信度:", "置信度:", "Confidence:")):
                    try:
                        conf_str = re.split(r"[::]", line, maxsplit=1)[1].strip()
                        confidence = float(conf_str)
                        # Clamp confidence to the valid [0.0, 1.0] range
                        confidence = max(0.0, min(1.0, confidence))
                    except (ValueError, IndexError):
                        confidence = self.llm_fallback_confidence
            return ClassificationResult(
                question_type=question_type,
                confidence=confidence,
                reason=reason,
                method="enhanced_llm"
            )
        except Exception as e:
            return ClassificationResult(
                question_type="CHAT",  # default to CHAT when parsing fails
                confidence=self.llm_fallback_confidence,
                reason=f"Response parsing failed: {str(e)}",
                method="llm_parse_error"
            )
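

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes only that this
    # file can be run on its own (agent.config, app_config and the Vanna instance may be
    # absent, in which case the classifier falls back to its built-in defaults and the
    # rule-based path). The sample question is purely illustrative.
    classifier = QuestionClassifier()
    result = classifier.classify("统计本月各服务区的营业额", routing_mode="hybrid")
    print(result.question_type, result.confidence, result.method)
    print(result.reason)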