# agent/classifier.py
import math
import os
import re
from dataclasses import dataclass
from typing import Dict, Any, List, Optional

from core.logging import get_agent_logger
  6. @dataclass
  7. class ClassificationResult:
  8. question_type: str
  9. confidence: float
  10. reason: str
  11. method: str
  12. class QuestionClassifier:
  13. """
  14. 增强版问题分类器:基于高速公路服务区业务上下文的智能分类
  15. """
  16. def __init__(self):
  17. # 初始化日志
  18. self.logger = get_agent_logger("Classifier")
  19. # 从配置文件加载阈值参数
  20. try:
  21. from agent.config import get_current_config, get_nested_config
  22. config = get_current_config()
  23. self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.7)
  24. self.low_confidence_threshold = get_nested_config(config, "classification.low_confidence_threshold", 0.4)
  25. self.max_confidence = get_nested_config(config, "classification.max_confidence", 0.9)
  26. self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.4)
  27. self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.08)
  28. self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
  29. self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
  30. self.medium_confidence_threshold = get_nested_config(config, "classification.medium_confidence_threshold", 0.6)
  31. self.logger.info("从配置文件加载分类器参数完成")
  32. except ImportError:
  33. self.high_confidence_threshold = 0.7
  34. self.low_confidence_threshold = 0.4
  35. self.max_confidence = 0.9
  36. self.base_confidence = 0.4
  37. self.confidence_increment = 0.08
  38. self.llm_fallback_confidence = 0.5
  39. self.uncertain_confidence = 0.2
  40. self.medium_confidence_threshold = 0.6
  41. self.logger.warning("配置文件不可用,使用默认分类器参数")
  42. # 加载词典配置(新增逻辑)
  43. self._load_dict_config()
  44. def _load_dict_config(self):
  45. """加载分类器词典配置"""
  46. try:
  47. from agent.config import get_classifier_dict_config
  48. dict_config = get_classifier_dict_config()
  49. # 加载关键词列表
  50. self.strong_business_keywords = dict_config.strong_business_keywords
  51. self.query_intent_keywords = dict_config.query_intent_keywords
  52. self.non_business_keywords = dict_config.non_business_keywords
  53. self.sql_patterns = dict_config.sql_patterns
  54. self.chat_keywords = dict_config.chat_keywords
  55. # 加载权重配置
  56. self.weights = dict_config.weights
  57. # 加载其他配置
  58. self.metadata = dict_config.metadata
  59. total_keywords = (
  60. sum(len(keywords) for keywords in self.strong_business_keywords.values()) +
  61. len(self.query_intent_keywords) +
  62. len(self.non_business_keywords) +
  63. len(self.sql_patterns) +
  64. len(self.chat_keywords)
  65. )
  66. self.logger.info(f"从YAML配置文件加载词典完成,共加载 {total_keywords} 个关键词")
  67. except Exception as e:
  68. self.logger.warning(f"加载YAML词典配置失败: {str(e)},使用代码中的备用配置")
  69. self._load_default_dict()
  70. def _load_default_dict(self):
  71. """YAML配置加载失败时的处理"""
  72. error_msg = "YAML词典配置文件加载失败,无法初始化分类器"
  73. self.logger.error(error_msg)
  74. # 初始化空的weights字典,使用代码中的默认值
  75. self.weights = {}
  76. raise RuntimeError(error_msg)
  77. def classify(self, question: str, context_type: Optional[str] = None, routing_mode: Optional[str] = None) -> ClassificationResult:
  78. """
  79. 主分类方法:简化为混合分类策略
  80. Args:
  81. question: 当前问题
  82. context_type: 上下文类型(保留参数兼容性,但不使用)
  83. routing_mode: 路由模式,可选,用于覆盖配置文件设置
  84. """
  85. # 确定使用的路由模式
  86. if routing_mode:
  87. QUESTION_ROUTING_MODE = routing_mode
  88. self.logger.info(f"使用传入的路由模式: {QUESTION_ROUTING_MODE}")
  89. else:
  90. try:
  91. from app_config import QUESTION_ROUTING_MODE
  92. self.logger.info(f"使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
  93. except ImportError:
  94. QUESTION_ROUTING_MODE = "hybrid"
  95. self.logger.info(f"配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
  96. # 根据路由模式选择分类策略
  97. if QUESTION_ROUTING_MODE == "database_direct":
  98. return ClassificationResult(
  99. question_type="DATABASE",
  100. confidence=1.0,
  101. reason="配置为直接数据库查询模式",
  102. method="direct_database"
  103. )
  104. elif QUESTION_ROUTING_MODE == "chat_direct":
  105. return ClassificationResult(
  106. question_type="CHAT",
  107. confidence=1.0,
  108. reason="配置为直接聊天模式",
  109. method="direct_chat"
  110. )
  111. elif QUESTION_ROUTING_MODE == "llm_only":
  112. return self._enhanced_llm_classify(question)
  113. else:
  114. # hybrid模式:直接使用混合分类策略(规则+LLM)
  115. return self._hybrid_classify(question)
  116. def _hybrid_classify(self, question: str) -> ClassificationResult:
  117. """
  118. 混合分类模式:规则预筛选 + 增强LLM分类
  119. 这是原来的 classify 方法逻辑
  120. """
  121. # 第一步:规则预筛选
  122. rule_result = self._rule_based_classify(question)
  123. # 如果规则分类有高置信度,直接使用
  124. if rule_result.confidence >= self.high_confidence_threshold:
  125. return rule_result
  126. # 否则:使用增强的LLM分类
  127. llm_result = self._enhanced_llm_classify(question)
  128. # 选择置信度更高的结果
  129. if llm_result.confidence > rule_result.confidence:
  130. return llm_result
  131. else:
  132. return rule_result
  133. def _extract_current_question_for_rule_classification(self, question: str) -> str:
  134. """
  135. 从enhanced_question中提取[CURRENT]部分用于规则分类
  136. 如果没有[CURRENT]标签,返回原问题
  137. Args:
  138. question: 可能包含上下文的完整问题
  139. Returns:
  140. str: 用于规则分类的当前问题
  141. """
  142. try:
  143. # 处理None或非字符串输入
  144. if question is None:
  145. self.logger.warning("输入问题为None,返回空字符串")
  146. return ""
  147. if not isinstance(question, str):
  148. self.logger.warning(f"输入问题类型错误: {type(question)},转换为字符串")
  149. question = str(question)
  150. # 检查是否为enhanced_question格式
  151. if "\n[CURRENT]\n" in question:
  152. current_start = question.find("\n[CURRENT]\n")
  153. if current_start != -1:
  154. current_question = question[current_start + len("\n[CURRENT]\n"):].strip()
  155. self.logger.info(f"规则分类从[CURRENT]标签提取到问题: {current_question}")
  156. return current_question
  157. # 如果不是enhanced_question格式,直接返回原问题
  158. stripped_question = question.strip()
  159. self.logger.info(f"规则分类未找到[CURRENT]标签,使用完整问题: {stripped_question}")
  160. return stripped_question
  161. except Exception as e:
  162. self.logger.warning(f"提取当前问题失败: {str(e)},返回空字符串")
  163. return ""
  164. def _rule_based_classify(self, question: str) -> ClassificationResult:
  165. """基于规则的预分类"""
  166. # 提取当前问题用于规则判断,避免上下文干扰
  167. current_question = self._extract_current_question_for_rule_classification(question)
  168. question_lower = current_question.lower()
  169. # 检查非业务实体词
  170. non_business_matched = []
  171. for keyword in self.non_business_keywords:
  172. if keyword in question_lower:
  173. non_business_matched.append(keyword)
  174. # 如果包含非业务实体词,直接分类为CHAT
  175. if non_business_matched:
  176. return ClassificationResult(
  177. question_type="CHAT",
  178. confidence=self.weights.get('non_business_confidence', 0.85), # 使用YAML配置的置信度
  179. reason=f"包含非业务实体词: {non_business_matched}",
  180. method="rule_based_non_business"
  181. )
  182. # 检查强业务关键词
  183. business_score = 0
  184. business_matched = []
  185. for category, keywords in self.strong_business_keywords.items():
  186. if category == "系统查询指示词": # 系统指示词单独处理
  187. continue
  188. for keyword in keywords:
  189. if keyword in question_lower:
  190. business_score += self.weights.get('business_entity', 2) # 使用YAML配置的权重
  191. business_matched.append(f"{category}:{keyword}")
  192. # 检查系统查询指示词
  193. system_indicator_score = 0
  194. system_matched = []
  195. for keyword in self.strong_business_keywords.get("系统查询指示词", []):
  196. if keyword in question_lower:
  197. system_indicator_score += self.weights.get('system_indicator', 1) # 使用YAML配置的权重
  198. system_matched.append(f"系统查询指示词:{keyword}")
  199. # 检查查询意图词
  200. intent_score = 0
  201. intent_matched = []
  202. for keyword in self.query_intent_keywords:
  203. if keyword in question_lower:
  204. intent_score += self.weights.get('query_intent', 1) # 使用YAML配置的权重
  205. intent_matched.append(keyword)
  206. # 检查SQL模式
  207. sql_patterns_matched = []
  208. for pattern in self.sql_patterns:
  209. if re.search(pattern, question_lower, re.IGNORECASE):
  210. business_score += self.weights.get('sql_pattern', 3) # 使用YAML配置的权重
  211. sql_patterns_matched.append(pattern)
  212. # 检查聊天关键词
  213. chat_score = 0
  214. chat_matched = []
  215. for keyword in self.chat_keywords:
  216. if keyword in question_lower:
  217. chat_score += self.weights.get('chat_keyword', 1) # 使用YAML配置的权重
  218. chat_matched.append(keyword)
  219. # 系统指示词组合评分逻辑
  220. if system_indicator_score > 0 and business_score > 0:
  221. # 系统指示词 + 业务实体 = 强组合效应
  222. business_score += self.weights.get('combination_bonus', 3) # 使用YAML配置的组合加分权重
  223. business_matched.extend(system_matched)
  224. elif system_indicator_score > 0:
  225. # 仅有系统指示词 = 中等业务倾向
  226. business_score += self.weights.get('system_indicator', 1) # 使用YAML配置的权重
  227. business_matched.extend(system_matched)
  228. # 分类决策逻辑
  229. total_business_score = business_score + intent_score
  230. # 强业务特征:包含业务实体 + 查询意图
  231. min_business_score = self.weights.get('strong_business_min_score', 2)
  232. min_intent_score = self.weights.get('strong_business_min_intent', 1)
  233. if business_score >= min_business_score and intent_score >= min_intent_score:
  234. base_conf = self.weights.get('strong_business_base', 0.8)
  235. increment = self.weights.get('strong_business_increment', 0.05)
  236. confidence = min(self.max_confidence, base_conf + (total_business_score * increment))
  237. return ClassificationResult(
  238. question_type="DATABASE",
  239. confidence=confidence,
  240. reason=f"强业务特征 - 业务实体: {business_matched}, 查询意图: {intent_matched}, SQL: {sql_patterns_matched}",
  241. method="rule_based_strong_business"
  242. )
  243. # 中等业务特征:包含多个业务实体词
  244. elif business_score >= self.weights.get('medium_business_min_score', 4):
  245. base_conf = self.weights.get('medium_business_base', 0.7)
  246. increment = self.weights.get('medium_business_increment', 0.03)
  247. confidence = min(self.max_confidence, base_conf + (business_score * increment))
  248. return ClassificationResult(
  249. question_type="DATABASE",
  250. confidence=confidence,
  251. reason=f"中等业务特征 - 业务实体: {business_matched}",
  252. method="rule_based_medium_business"
  253. )
  254. # 聊天特征
  255. elif chat_score >= self.weights.get('chat_min_score', 1) and business_score == 0:
  256. base_conf = self.weights.get('chat_base_confidence', 0.4)
  257. increment = self.weights.get('chat_confidence_increment', 0.08)
  258. confidence = min(self.max_confidence, base_conf + (chat_score * increment))
  259. return ClassificationResult(
  260. question_type="CHAT",
  261. confidence=confidence,
  262. reason=f"聊天特征: {chat_matched}",
  263. method="rule_based_chat"
  264. )
  265. # 不确定情况
  266. else:
  267. return ClassificationResult(
  268. question_type="UNCERTAIN",
  269. confidence=self.uncertain_confidence,
  270. reason=f"规则分类不确定 - 业务分:{business_score}, 意图分:{intent_score}, 聊天分:{chat_score}",
  271. method="rule_based_uncertain"
  272. )
  273. def _load_business_context(self) -> str:
  274. """从文件中加载数据库业务范围描述"""
  275. try:
  276. import os
  277. current_dir = os.path.dirname(os.path.abspath(__file__))
  278. prompt_file = os.path.join(current_dir, "tools", "db_query_decision_prompt.txt")
  279. with open(prompt_file, 'r', encoding='utf-8') as f:
  280. content = f.read().strip()
  281. if not content:
  282. raise ValueError("业务上下文文件为空")
  283. return content
  284. except FileNotFoundError:
  285. error_msg = f"无法找到业务上下文文件: {prompt_file}"
  286. self.logger.error(error_msg)
  287. raise FileNotFoundError(error_msg)
  288. except Exception as e:
  289. error_msg = f"读取业务上下文文件失败: {str(e)}"
  290. self.logger.error(error_msg)
  291. raise RuntimeError(error_msg)
  292. def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
  293. """增强的LLM分类:包含详细的业务上下文"""
  294. try:
  295. from common.vanna_instance import get_vanna_instance
  296. vn = get_vanna_instance()
  297. # 动态加载业务上下文(如果失败会抛出异常)
  298. business_context = self._load_business_context()
  299. # 构建包含业务上下文的分类提示词
  300. classification_prompt = f"""
  301. 请判断以下用户问题是否需要查询我们的数据库。
  302. 用户问题:{question}
  303. {business_context}
  304. === 判断标准 ===
  305. 1. **DATABASE类型** - 需要查询数据库:
  306. - 涉及上述业务实体和指标的查询、统计、分析、报表
  307. - 包含业务相关的时间查询
  308. - 例如:业务数据统计、收入排行、流量分析、占比分析等
  309. 2. **CHAT类型** - 不需要查询数据库:
  310. - 生活常识:水果蔬菜上市时间、动植物知识、天气等
  311. - 身份询问:你是谁、什么模型、AI助手等
  312. - 技术概念:人工智能、编程、算法等
  313. - 平台使用:功能介绍、操作帮助、使用教程等
  314. - 旅游出行:旅游景点、酒店、机票、高铁、的士等
  315. - 情绪:开心、伤心、无聊、生气、孤独、累了、烦恼、心情、难过、抑郁
  316. - 商业:股票、基金、理财、投资、经济、通货膨胀、上市
  317. - 哲学:人生意义、价值观、道德、信仰、宗教、爱情
  318. - 政策:政策、法规、法律、条例、指南、手册、规章制度、实施细则
  319. - 地理:全球、中国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
  320. - 体育:足球、NBA、篮球、乒乓球、冠军、夺冠
  321. - 文学:小说、新闻、政治、战争、足球、NBA、篮球、乒乓球、冠军、夺冠
  322. - 娱乐:游戏、小说、新闻、政治、战争、足球、NBA、篮球、乒乓球、冠军、夺冠、电影、电视剧、音乐、舞蹈、绘画、书法、摄影、雕塑、建筑、设计、
  323. - 健康:健康、医疗、病症、健康、饮食、睡眠、心理、养生、减肥、美容、护肤
  324. - 其他:高考、人生意义、价值观、道德、信仰、宗教、爱情、全球、全国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
  325. - 例如:"荔枝几月份上市"、"今天天气如何"、"你是什么AI"、"怎么使用平台"
  326. **重要提示:**
  327. - 只有涉及高速公路服务区业务数据的问题才分类为DATABASE
  328. - 只要不是涉及高速公路服务区业务数据的问题都应分类为CHAT
  329. 请基于问题与我们高速公路服务区业务数据的相关性来分类。
  330. 格式:
  331. 分类: [DATABASE/CHAT]
  332. 理由: [详细说明问题与业务数据的相关性,具体分析涉及哪些业务实体或为什么不相关]
  333. 置信度: [0.0-1.0之间的数字]
  334. """
  335. # 专业的系统提示词
  336. system_prompt = """你是一个专业的业务问题分类助手。你具有以下特长:
  337. 1. 深度理解业务领域和数据范围
  338. 2. 准确区分业务数据查询需求和一般性问题
  339. 3. 基于具体业务上下文进行精准分类,而不仅仅依赖关键词匹配
  340. 4. 对边界情况能够给出合理的置信度评估
  341. 请严格按照业务相关性进行分类,并提供详细的分类理由。"""
  342. # 使用 Vanna 实例的 chat_with_llm 方法
  343. response = vn.chat_with_llm(
  344. question=classification_prompt,
  345. system_prompt=system_prompt
  346. )
  347. # 解析响应
  348. return self._parse_llm_response(response)
  349. except (FileNotFoundError, RuntimeError) as e:
  350. # 业务上下文加载失败,返回错误状态
  351. self.logger.error(f"LLM分类失败,业务上下文不可用: {str(e)}")
  352. return ClassificationResult(
  353. question_type="CHAT", # 失败时默认为CHAT,更安全
  354. confidence=self.weights.get('llm_error_confidence', 0.1), # 使用YAML配置的低置信度
  355. reason=f"业务上下文加载失败,无法进行准确分类: {str(e)}",
  356. method="llm_context_error"
  357. )
  358. except Exception as e:
  359. self.logger.warning(f"增强LLM分类失败: {str(e)}")
  360. return ClassificationResult(
  361. question_type="CHAT", # 失败时默认为CHAT,更安全
  362. confidence=self.llm_fallback_confidence,
  363. reason=f"LLM分类异常,默认为聊天: {str(e)}",
  364. method="llm_error"
  365. )
  366. def _parse_llm_response(self, response: str) -> ClassificationResult:
  367. """解析LLM响应"""
  368. try:
  369. lines = response.strip().split('\n')
  370. question_type = "CHAT" # 默认为CHAT
  371. reason = "LLM响应解析失败"
  372. confidence = self.llm_fallback_confidence
  373. for line in lines:
  374. line = line.strip()
  375. if line.startswith("分类:") or line.startswith("Classification:"):
  376. type_part = line.split(":", 1)[1].strip().upper()
  377. if "DATABASE" in type_part:
  378. question_type = "DATABASE"
  379. elif "CHAT" in type_part:
  380. question_type = "CHAT"
  381. elif line.startswith("理由:") or line.startswith("Reason:"):
  382. reason = line.split(":", 1)[1].strip()
  383. elif line.startswith("置信度:") or line.startswith("Confidence:"):
  384. try:
  385. conf_str = line.split(":", 1)[1].strip()
  386. confidence = float(conf_str)
  387. # 确保置信度在合理范围内
  388. confidence = max(0.0, min(1.0, confidence))
  389. except:
  390. confidence = self.llm_fallback_confidence
  391. return ClassificationResult(
  392. question_type=question_type,
  393. confidence=confidence,
  394. reason=reason,
  395. method="enhanced_llm"
  396. )
  397. except Exception as e:
  398. return ClassificationResult(
  399. question_type="CHAT", # 解析失败时默认为CHAT
  400. confidence=self.llm_fallback_confidence,
  401. reason=f"响应解析失败: {str(e)}",
  402. method="llm_parse_error"
  403. )