classifier.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # agent/classifier.py
  2. import re
  3. from typing import Dict, Any, List
  4. from dataclasses import dataclass
  5. @dataclass
  6. class ClassificationResult:
  7. question_type: str
  8. confidence: float
  9. reason: str
  10. method: str
  11. class QuestionClassifier:
  12. """
  13. 增强版问题分类器:基于高速公路服务区业务上下文的智能分类
  14. """
  15. def __init__(self):
  16. # 从配置文件加载阈值参数
  17. try:
  18. from agent.config import get_current_config, get_nested_config
  19. config = get_current_config()
  20. self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.7)
  21. self.low_confidence_threshold = get_nested_config(config, "classification.low_confidence_threshold", 0.4)
  22. self.max_confidence = get_nested_config(config, "classification.max_confidence", 0.9)
  23. self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.4)
  24. self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.08)
  25. self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
  26. self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
  27. print("[CLASSIFIER] 从配置文件加载分类器参数完成")
  28. except ImportError:
  29. self.high_confidence_threshold = 0.7
  30. self.low_confidence_threshold = 0.4
  31. self.max_confidence = 0.9
  32. self.base_confidence = 0.4
  33. self.confidence_increment = 0.08
  34. self.llm_fallback_confidence = 0.5
  35. self.uncertain_confidence = 0.2
  36. print("[CLASSIFIER] 配置文件不可用,使用默认分类器参数")
  37. # 基于高速公路服务区业务的精准关键词
  38. self.strong_business_keywords = {
  39. "核心业务实体": [
  40. "服务区", "档口", "商铺", "收费站", "高速公路",
  41. "驿美", "驿购", # 业务系统名称
  42. "北区", "南区", "西区", "东区", "两区", # 物理分区
  43. "停车区"
  44. ],
  45. "支付业务": [
  46. "微信支付", "支付宝支付", "现金支付", "行吧支付", "金豆支付",
  47. "支付金额", "订单数量", "营业额", "收入", "营业收入",
  48. "微信", "支付宝", "现金", "行吧", "金豆", # 简化形式
  49. "wx", "zfb", "rmb", "xs", "jd" # 系统字段名
  50. ],
  51. "经营品类": [
  52. "餐饮", "小吃", "便利店", "整体租赁",
  53. "驿美餐饮", "品牌", "经营品类", "商业品类"
  54. ],
  55. "车流业务": [
  56. "车流量", "车辆数量", "客车", "货车",
  57. "过境", "危化品", "城际", "车辆统计",
  58. "流量统计", "车型分布"
  59. ],
  60. "地理路线": [
  61. "大广", "昌金", "昌栗", "线路", "路段", "路线",
  62. "高速线路", "公路线路"
  63. ]
  64. }
  65. # 查询意图词(辅助判断)
  66. self.query_intent_keywords = [
  67. "统计", "查询", "分析", "排行", "排名",
  68. "报表", "报告", "汇总", "计算", "对比",
  69. "趋势", "占比", "百分比", "比例",
  70. "最大", "最小", "最高", "最低", "平均",
  71. "总计", "合计", "累计", "求和", "求平均",
  72. "生成", "导出", "显示", "列出"
  73. ]
  74. # 非业务实体词(包含则倾向CHAT)
  75. self.non_business_keywords = [
  76. # 农产品/食物
  77. "荔枝", "苹果", "西瓜", "水果", "蔬菜", "大米", "小麦",
  78. "橙子", "香蕉", "葡萄", "草莓", "樱桃", "桃子", "梨",
  79. # 技术概念
  80. "人工智能", "机器学习", "编程", "算法", "深度学习",
  81. "AI", "神经网络", "模型训练", "数据挖掘",
  82. # 身份询问
  83. "你是谁", "你是什么", "你叫什么", "你的名字",
  84. "什么模型", "大模型", "AI助手", "助手", "机器人",
  85. # 天气相关
  86. "天气", "气温", "下雨", "晴天", "阴天", "温度",
  87. "天气预报", "气候", "降雨", "雪天",
  88. # 其他生活常识
  89. "怎么做饭", "如何减肥", "健康", "医疗", "病症",
  90. "历史", "地理", "文学", "电影", "音乐", "体育",
  91. "娱乐", "游戏", "小说", "新闻", "政治"
  92. ]
  93. # SQL关键词(技术层面的数据库操作)
  94. self.sql_patterns = [
  95. r"\b(select|from|where|group by|order by|having|join)\b",
  96. r"\b(数据库|表名|字段名|SQL|sql)\b"
  97. ]
  98. # 聊天关键词(平台功能和帮助)
  99. self.chat_keywords = [
  100. "你好", "谢谢", "再见", "怎么样", "如何", "为什么", "什么是",
  101. "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能",
  102. "教程", "指南", "手册"
  103. ]
  104. # 修改 agent/classifier.py 中的 classify 方法
  105. def classify(self, question: str) -> ClassificationResult:
  106. """
  107. 主分类方法:根据配置的路由模式进行分类
  108. """
  109. try:
  110. from app_config import QUESTION_ROUTING_MODE
  111. print(f"[CLASSIFIER] 使用路由模式: {QUESTION_ROUTING_MODE}")
  112. except ImportError:
  113. QUESTION_ROUTING_MODE = "hybrid"
  114. print(f"[CLASSIFIER] 配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
  115. # 根据路由模式选择分类策略
  116. if QUESTION_ROUTING_MODE == "hybrid":
  117. return self._hybrid_classify(question)
  118. elif QUESTION_ROUTING_MODE == "llm_only":
  119. return self._enhanced_llm_classify(question)
  120. elif QUESTION_ROUTING_MODE == "database_direct":
  121. return ClassificationResult(
  122. question_type="DATABASE",
  123. confidence=1.0,
  124. reason="配置为直接数据库查询模式",
  125. method="direct_database"
  126. )
  127. elif QUESTION_ROUTING_MODE == "chat_direct":
  128. return ClassificationResult(
  129. question_type="CHAT",
  130. confidence=1.0,
  131. reason="配置为直接聊天模式",
  132. method="direct_chat"
  133. )
  134. else:
  135. print(f"[WARNING] 未知的路由模式: {QUESTION_ROUTING_MODE},使用默认hybrid模式")
  136. return self._hybrid_classify(question)
  137. def _hybrid_classify(self, question: str) -> ClassificationResult:
  138. """
  139. 混合分类模式:规则预筛选 + 增强LLM分类
  140. 这是原来的 classify 方法逻辑
  141. """
  142. # 第一步:规则预筛选
  143. rule_result = self._rule_based_classify(question)
  144. # 如果规则分类有高置信度,直接使用
  145. if rule_result.confidence >= self.high_confidence_threshold:
  146. return rule_result
  147. # 第二步:使用增强的LLM分类
  148. llm_result = self._enhanced_llm_classify(question)
  149. # 选择置信度更高的结果
  150. if llm_result.confidence > rule_result.confidence:
  151. return llm_result
  152. else:
  153. return rule_result
  154. def _rule_based_classify(self, question: str) -> ClassificationResult:
  155. """基于规则的预分类"""
  156. question_lower = question.lower()
  157. # 检查非业务实体词
  158. non_business_matched = []
  159. for keyword in self.non_business_keywords:
  160. if keyword in question_lower:
  161. non_business_matched.append(keyword)
  162. # 如果包含非业务实体词,直接分类为CHAT
  163. if non_business_matched:
  164. return ClassificationResult(
  165. question_type="CHAT",
  166. confidence=0.85,
  167. reason=f"包含非业务实体词: {non_business_matched}",
  168. method="rule_based_non_business"
  169. )
  170. # 检查强业务关键词
  171. business_score = 0
  172. business_matched = []
  173. for category, keywords in self.strong_business_keywords.items():
  174. for keyword in keywords:
  175. if keyword in question_lower:
  176. business_score += 2 # 业务实体词权重更高
  177. business_matched.append(f"{category}:{keyword}")
  178. # 检查查询意图词
  179. intent_score = 0
  180. intent_matched = []
  181. for keyword in self.query_intent_keywords:
  182. if keyword in question_lower:
  183. intent_score += 1
  184. intent_matched.append(keyword)
  185. # 检查SQL模式
  186. sql_patterns_matched = []
  187. for pattern in self.sql_patterns:
  188. if re.search(pattern, question_lower, re.IGNORECASE):
  189. business_score += 3 # SQL模式权重最高
  190. sql_patterns_matched.append(pattern)
  191. # 检查聊天关键词
  192. chat_score = 0
  193. chat_matched = []
  194. for keyword in self.chat_keywords:
  195. if keyword in question_lower:
  196. chat_score += 1
  197. chat_matched.append(keyword)
  198. # 分类决策逻辑
  199. total_business_score = business_score + intent_score
  200. # 强业务特征:包含业务实体 + 查询意图
  201. if business_score >= 2 and intent_score >= 1:
  202. confidence = min(self.max_confidence, 0.8 + (total_business_score * 0.05))
  203. return ClassificationResult(
  204. question_type="DATABASE",
  205. confidence=confidence,
  206. reason=f"强业务特征 - 业务实体: {business_matched}, 查询意图: {intent_matched}, SQL: {sql_patterns_matched}",
  207. method="rule_based_strong_business"
  208. )
  209. # 中等业务特征:包含多个业务实体词
  210. elif business_score >= 4:
  211. confidence = min(self.max_confidence, 0.7 + (business_score * 0.03))
  212. return ClassificationResult(
  213. question_type="DATABASE",
  214. confidence=confidence,
  215. reason=f"中等业务特征 - 业务实体: {business_matched}",
  216. method="rule_based_medium_business"
  217. )
  218. # 聊天特征
  219. elif chat_score >= 1 and business_score == 0:
  220. confidence = min(self.max_confidence, self.base_confidence + (chat_score * self.confidence_increment))
  221. return ClassificationResult(
  222. question_type="CHAT",
  223. confidence=confidence,
  224. reason=f"聊天特征: {chat_matched}",
  225. method="rule_based_chat"
  226. )
  227. # 不确定情况
  228. else:
  229. return ClassificationResult(
  230. question_type="UNCERTAIN",
  231. confidence=self.uncertain_confidence,
  232. reason=f"规则分类不确定 - 业务分:{business_score}, 意图分:{intent_score}, 聊天分:{chat_score}",
  233. method="rule_based_uncertain"
  234. )
  235. def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
  236. """增强的LLM分类:包含详细的业务上下文"""
  237. try:
  238. from common.vanna_instance import get_vanna_instance
  239. vn = get_vanna_instance()
  240. # 构建包含业务上下文的分类提示词
  241. classification_prompt = f"""
  242. 请判断以下用户问题是否需要查询我们的高速公路服务区管理数据库。
  243. 用户问题:{question}
  244. === 数据库业务范围 ===
  245. 本系统是高速公路服务区商业管理系统,包含以下业务数据:
  246. 核心业务实体:
  247. - 服务区(bss_service_area):服务区基础信息、位置、状态,如"鄱阳湖服务区"、"信丰西服务区"
  248. - 档口/商铺(bss_branch):档口信息、品类(餐饮/小吃/便利店)、品牌,如"驿美餐饮"、"加水机"
  249. - 营业数据(bss_business_day_data):每日支付金额、订单数量,包含微信、支付宝、现金等支付方式
  250. - 车流量(bss_car_day_count):按车型统计的日流量数据,包含客车、货车、过境、危化品等
  251. - 公司信息(bss_company):服务区管理公司,如"驿美运营公司"
  252. 关键业务指标:
  253. - 支付方式:微信支付(wx)、支付宝支付(zfb)、现金支付(rmb)、行吧支付(xs)、金豆支付(jd)
  254. - 营业数据:支付金额、订单数量、营业额、收入统计
  255. - 车流统计:按车型(客车/货车/过境/危化品/城际)的流量分析
  256. - 经营分析:餐饮、小吃、便利店、整体租赁等品类收入
  257. - 地理分区:北区、南区、西区、东区、两区
  258. 高速线路:
  259. - 线路信息:大广、昌金、昌栗等高速线路
  260. - 路段管理:按线路统计服务区分布
  261. === 判断标准 ===
  262. 1. **DATABASE类型** - 需要查询数据库:
  263. - 涉及上述业务实体和指标的查询、统计、分析、报表
  264. - 包含业务相关的时间查询,如"本月服务区营业额"、"上月档口收入"
  265. - 例如:"本月营业额统计"、"档口收入排行"、"车流量分析"、"支付方式占比"
  266. 2. **CHAT类型** - 不需要查询数据库:
  267. - 生活常识:水果蔬菜上市时间、动植物知识、天气等
  268. - 身份询问:你是谁、什么模型、AI助手等
  269. - 技术概念:人工智能、编程、算法等
  270. - 平台使用:功能介绍、操作帮助、使用教程等
  271. - 例如:"荔枝几月份上市"、"今天天气如何"、"你是什么AI"、"怎么使用平台"
  272. **重要提示:**
  273. - 只有涉及高速公路服务区业务数据的问题才分类为DATABASE
  274. - 即使包含时间词汇(如"月份"、"时间"),也要看是否与我们的业务数据相关
  275. - 农产品上市时间、生活常识等都应分类为CHAT
  276. 请基于问题与我们高速公路服务区业务数据的相关性来分类。
  277. 格式:
  278. 分类: [DATABASE/CHAT]
  279. 理由: [详细说明问题与业务数据的相关性,具体分析涉及哪些业务实体或为什么不相关]
  280. 置信度: [0.0-1.0之间的数字]
  281. """
  282. # 专业的系统提示词
  283. system_prompt = """你是一个专业的业务问题分类助手,专门负责高速公路服务区管理系统的问题分类。你具有以下特长:
  284. 1. 深度理解高速公路服务区业务领域和数据范围
  285. 2. 准确区分业务数据查询需求和一般性问题
  286. 3. 基于具体业务上下文进行精准分类,而不仅仅依赖关键词匹配
  287. 4. 对边界情况能够给出合理的置信度评估
  288. 请严格按照业务相关性进行分类,并提供详细的分类理由。"""
  289. # 使用 Vanna 实例的 chat_with_llm 方法
  290. response = vn.chat_with_llm(
  291. question=classification_prompt,
  292. system_prompt=system_prompt
  293. )
  294. # 解析响应
  295. return self._parse_llm_response(response)
  296. except Exception as e:
  297. print(f"[WARNING] 增强LLM分类失败: {str(e)}")
  298. return ClassificationResult(
  299. question_type="CHAT", # 失败时默认为CHAT,更安全
  300. confidence=self.llm_fallback_confidence,
  301. reason=f"LLM分类异常,默认为聊天: {str(e)}",
  302. method="llm_error"
  303. )
  304. def _parse_llm_response(self, response: str) -> ClassificationResult:
  305. """解析LLM响应"""
  306. try:
  307. lines = response.strip().split('\n')
  308. question_type = "CHAT" # 默认为CHAT
  309. reason = "LLM响应解析失败"
  310. confidence = self.llm_fallback_confidence
  311. for line in lines:
  312. line = line.strip()
  313. if line.startswith("分类:") or line.startswith("Classification:"):
  314. type_part = line.split(":", 1)[1].strip().upper()
  315. if "DATABASE" in type_part:
  316. question_type = "DATABASE"
  317. elif "CHAT" in type_part:
  318. question_type = "CHAT"
  319. elif line.startswith("理由:") or line.startswith("Reason:"):
  320. reason = line.split(":", 1)[1].strip()
  321. elif line.startswith("置信度:") or line.startswith("Confidence:"):
  322. try:
  323. conf_str = line.split(":", 1)[1].strip()
  324. confidence = float(conf_str)
  325. # 确保置信度在合理范围内
  326. confidence = max(0.0, min(1.0, confidence))
  327. except:
  328. confidence = self.llm_fallback_confidence
  329. return ClassificationResult(
  330. question_type=question_type,
  331. confidence=confidence,
  332. reason=reason,
  333. method="enhanced_llm"
  334. )
  335. except Exception as e:
  336. return ClassificationResult(
  337. question_type="CHAT", # 解析失败时默认为CHAT
  338. confidence=self.llm_fallback_confidence,
  339. reason=f"响应解析失败: {str(e)}",
  340. method="llm_parse_error"
  341. )