classifier.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. # agent/classifier.py
  2. import re
  3. from typing import Dict, Any, List, Optional
  4. from dataclasses import dataclass
  5. from core.logging import get_agent_logger
  6. @dataclass
  7. class ClassificationResult:
  8. question_type: str
  9. confidence: float
  10. reason: str
  11. method: str
  12. class QuestionClassifier:
  13. """
  14. 增强版问题分类器:基于高速公路服务区业务上下文的智能分类
  15. """
  16. def __init__(self):
  17. # 初始化日志
  18. self.logger = get_agent_logger("Classifier")
  19. # 初始化默认参数(作为后备)
  20. self.high_confidence_threshold = 0.7
  21. self.max_confidence = 0.9
  22. self.llm_fallback_confidence = 0.5
  23. self.uncertain_confidence = 0.2
  24. # 加载词典配置(新增逻辑)
  25. self._load_dict_config()
  26. def _load_dict_config(self):
  27. """加载分类器词典配置"""
  28. try:
  29. from agent.config import get_classifier_dict_config
  30. dict_config = get_classifier_dict_config()
  31. # 加载关键词列表
  32. self.strong_business_keywords = dict_config.strong_business_keywords
  33. self.query_intent_keywords = dict_config.query_intent_keywords
  34. self.non_business_keywords = dict_config.non_business_keywords
  35. self.sql_patterns = dict_config.sql_patterns
  36. self.chat_keywords = dict_config.chat_keywords
  37. # 加载权重配置
  38. self.weights = dict_config.weights
  39. # 从YAML权重配置中加载分类器参数(优先使用YAML配置)
  40. self.high_confidence_threshold = self.weights.get('high_confidence_threshold', self.high_confidence_threshold)
  41. self.max_confidence = self.weights.get('max_confidence', self.max_confidence)
  42. self.llm_fallback_confidence = self.weights.get('llm_fallback_confidence', self.llm_fallback_confidence)
  43. self.uncertain_confidence = self.weights.get('uncertain_confidence', self.uncertain_confidence)
  44. # 加载其他配置
  45. self.metadata = dict_config.metadata
  46. total_keywords = (
  47. sum(len(keywords) for keywords in self.strong_business_keywords.values()) +
  48. len(self.query_intent_keywords) +
  49. len(self.non_business_keywords) +
  50. len(self.sql_patterns) +
  51. len(self.chat_keywords)
  52. )
  53. self.logger.info(f"从YAML配置文件加载词典完成,共加载 {total_keywords} 个关键词")
  54. self.logger.info(f"从YAML配置文件加载分类器参数完成:高置信度阈值={self.high_confidence_threshold}")
  55. self.logger.debug(f"所有分类器参数:high_threshold={self.high_confidence_threshold}, max_conf={self.max_confidence}, llm_fallback={self.llm_fallback_confidence}")
  56. except Exception as e:
  57. self.logger.warning(f"加载YAML词典配置失败: {str(e)},使用代码中的备用配置")
  58. self._load_default_dict()
  59. def _load_default_dict(self):
  60. """YAML配置加载失败时的处理"""
  61. error_msg = "YAML词典配置文件加载失败,无法初始化分类器"
  62. self.logger.error(error_msg)
  63. # 初始化空的weights字典,使用代码中的默认值
  64. self.weights = {}
  65. raise RuntimeError(error_msg)
  66. def classify(self, question: str, context_type: Optional[str] = None, routing_mode: Optional[str] = None) -> ClassificationResult:
  67. """
  68. 主分类方法:简化为混合分类策略
  69. Args:
  70. question: 当前问题
  71. context_type: 上下文类型(保留参数兼容性,但不使用)
  72. routing_mode: 路由模式,可选,用于覆盖配置文件设置
  73. """
  74. # 确定使用的路由模式
  75. if routing_mode:
  76. QUESTION_ROUTING_MODE = routing_mode
  77. self.logger.info(f"使用传入的路由模式: {QUESTION_ROUTING_MODE}")
  78. else:
  79. try:
  80. from app_config import QUESTION_ROUTING_MODE
  81. self.logger.info(f"使用配置文件路由模式: {QUESTION_ROUTING_MODE}")
  82. except ImportError:
  83. QUESTION_ROUTING_MODE = "hybrid"
  84. self.logger.info(f"配置导入失败,使用默认路由模式: {QUESTION_ROUTING_MODE}")
  85. # 根据路由模式选择分类策略
  86. if QUESTION_ROUTING_MODE == "database_direct":
  87. return ClassificationResult(
  88. question_type="DATABASE",
  89. confidence=1.0,
  90. reason="配置为直接数据库查询模式",
  91. method="direct_database"
  92. )
  93. elif QUESTION_ROUTING_MODE == "chat_direct":
  94. return ClassificationResult(
  95. question_type="CHAT",
  96. confidence=1.0,
  97. reason="配置为直接聊天模式",
  98. method="direct_chat"
  99. )
  100. elif QUESTION_ROUTING_MODE == "llm_only":
  101. return self._enhanced_llm_classify(question)
  102. else:
  103. # hybrid模式:直接使用混合分类策略(规则+LLM)
  104. return self._hybrid_classify(question)
  105. def _hybrid_classify(self, question: str) -> ClassificationResult:
  106. """
  107. 混合分类模式:规则预筛选 + 增强LLM分类
  108. 这是原来的 classify 方法逻辑
  109. """
  110. # 第一步:规则预筛选
  111. rule_result = self._rule_based_classify(question)
  112. # 如果规则分类有高置信度,直接使用
  113. if rule_result.confidence >= self.high_confidence_threshold:
  114. return rule_result
  115. # 否则:使用增强的LLM分类
  116. llm_result = self._enhanced_llm_classify(question)
  117. # 选择置信度更高的结果
  118. if llm_result.confidence > rule_result.confidence:
  119. return llm_result
  120. else:
  121. return rule_result
  122. def _extract_current_question_for_rule_classification(self, question: str) -> str:
  123. """
  124. 从enhanced_question中提取[CURRENT]部分用于规则分类
  125. 如果没有[CURRENT]标签,返回原问题
  126. Args:
  127. question: 可能包含上下文的完整问题
  128. Returns:
  129. str: 用于规则分类的当前问题
  130. """
  131. try:
  132. # 处理None或非字符串输入
  133. if question is None:
  134. self.logger.warning("输入问题为None,返回空字符串")
  135. return ""
  136. if not isinstance(question, str):
  137. self.logger.warning(f"输入问题类型错误: {type(question)},转换为字符串")
  138. question = str(question)
  139. # 检查是否为enhanced_question格式
  140. if "\n[CURRENT]\n" in question:
  141. current_start = question.find("\n[CURRENT]\n")
  142. if current_start != -1:
  143. current_question = question[current_start + len("\n[CURRENT]\n"):].strip()
  144. self.logger.info(f"规则分类从[CURRENT]标签提取到问题: {current_question}")
  145. return current_question
  146. # 如果不是enhanced_question格式,直接返回原问题
  147. stripped_question = question.strip()
  148. self.logger.info(f"规则分类未找到[CURRENT]标签,使用完整问题: {stripped_question}")
  149. return stripped_question
  150. except Exception as e:
  151. self.logger.warning(f"提取当前问题失败: {str(e)},返回空字符串")
  152. return ""
  153. def _rule_based_classify(self, question: str) -> ClassificationResult:
  154. """基于规则的预分类"""
  155. # 提取当前问题用于规则判断,避免上下文干扰
  156. current_question = self._extract_current_question_for_rule_classification(question)
  157. question_lower = current_question.lower()
  158. # 检查非业务实体词
  159. non_business_matched = []
  160. for keyword in self.non_business_keywords:
  161. if keyword in question_lower:
  162. non_business_matched.append(keyword)
  163. # 如果包含非业务实体词,直接分类为CHAT
  164. if non_business_matched:
  165. return ClassificationResult(
  166. question_type="CHAT",
  167. confidence=self.weights.get('non_business_confidence', 0.85), # 使用YAML配置的置信度
  168. reason=f"包含非业务实体词: {non_business_matched}",
  169. method="rule_based_non_business"
  170. )
  171. # 检查强业务关键词
  172. business_score = 0
  173. business_matched = []
  174. for category, keywords in self.strong_business_keywords.items():
  175. if category == "系统查询指示词": # 系统指示词单独处理
  176. continue
  177. for keyword in keywords:
  178. if keyword in question_lower:
  179. business_score += self.weights.get('business_entity', 2) # 使用YAML配置的权重
  180. business_matched.append(f"{category}:{keyword}")
  181. # 检查系统查询指示词
  182. system_indicator_score = 0
  183. system_matched = []
  184. for keyword in self.strong_business_keywords.get("系统查询指示词", []):
  185. if keyword in question_lower:
  186. system_indicator_score += self.weights.get('system_indicator', 1) # 使用YAML配置的权重
  187. system_matched.append(f"系统查询指示词:{keyword}")
  188. # 检查查询意图词
  189. intent_score = 0
  190. intent_matched = []
  191. for keyword in self.query_intent_keywords:
  192. if keyword in question_lower:
  193. intent_score += self.weights.get('query_intent', 1) # 使用YAML配置的权重
  194. intent_matched.append(keyword)
  195. # 检查SQL模式
  196. sql_patterns_matched = []
  197. for pattern in self.sql_patterns:
  198. if re.search(pattern, question_lower, re.IGNORECASE):
  199. business_score += self.weights.get('sql_pattern', 3) # 使用YAML配置的权重
  200. sql_patterns_matched.append(pattern)
  201. # 检查聊天关键词
  202. chat_score = 0
  203. chat_matched = []
  204. for keyword in self.chat_keywords:
  205. if keyword in question_lower:
  206. chat_score += self.weights.get('chat_keyword', 1) # 使用YAML配置的权重
  207. chat_matched.append(keyword)
  208. # 系统指示词组合评分逻辑
  209. if system_indicator_score > 0 and business_score > 0:
  210. # 系统指示词 + 业务实体 = 强组合效应
  211. business_score += self.weights.get('combination_bonus', 3) # 使用YAML配置的组合加分权重
  212. business_matched.extend(system_matched)
  213. elif system_indicator_score > 0:
  214. # 仅有系统指示词 = 中等业务倾向
  215. business_score += self.weights.get('system_indicator', 1) # 使用YAML配置的权重
  216. business_matched.extend(system_matched)
  217. # 分类决策逻辑
  218. total_business_score = business_score + intent_score
  219. # 强业务特征:包含业务实体 + 查询意图
  220. min_business_score = self.weights.get('strong_business_min_score', 2)
  221. min_intent_score = self.weights.get('strong_business_min_intent', 1)
  222. if business_score >= min_business_score and intent_score >= min_intent_score:
  223. base_conf = self.weights.get('strong_business_base', 0.8)
  224. increment = self.weights.get('strong_business_increment', 0.05)
  225. confidence = min(self.max_confidence, base_conf + (total_business_score * increment))
  226. return ClassificationResult(
  227. question_type="DATABASE",
  228. confidence=confidence,
  229. reason=f"强业务特征 - 业务实体: {business_matched}, 查询意图: {intent_matched}, SQL: {sql_patterns_matched}",
  230. method="rule_based_strong_business"
  231. )
  232. # 中等业务特征:包含多个业务实体词
  233. elif business_score >= self.weights.get('medium_business_min_score', 4):
  234. base_conf = self.weights.get('medium_business_base', 0.7)
  235. increment = self.weights.get('medium_business_increment', 0.03)
  236. confidence = min(self.max_confidence, base_conf + (business_score * increment))
  237. return ClassificationResult(
  238. question_type="DATABASE",
  239. confidence=confidence,
  240. reason=f"中等业务特征 - 业务实体: {business_matched}",
  241. method="rule_based_medium_business"
  242. )
  243. # 聊天特征
  244. elif chat_score >= self.weights.get('chat_min_score', 1) and business_score == 0:
  245. base_conf = self.weights.get('chat_base_confidence', 0.4)
  246. increment = self.weights.get('chat_confidence_increment', 0.08)
  247. confidence = min(self.max_confidence, base_conf + (chat_score * increment))
  248. return ClassificationResult(
  249. question_type="CHAT",
  250. confidence=confidence,
  251. reason=f"聊天特征: {chat_matched}",
  252. method="rule_based_chat"
  253. )
  254. # 不确定情况
  255. else:
  256. return ClassificationResult(
  257. question_type="UNCERTAIN",
  258. confidence=self.uncertain_confidence,
  259. reason=f"规则分类不确定 - 业务分:{business_score}, 意图分:{intent_score}, 聊天分:{chat_score}",
  260. method="rule_based_uncertain"
  261. )
  262. def _load_business_context(self) -> str:
  263. """从文件中加载数据库业务范围描述"""
  264. try:
  265. import os
  266. current_dir = os.path.dirname(os.path.abspath(__file__))
  267. prompt_file = os.path.join(current_dir, "tools", "db_query_decision_prompt.txt")
  268. with open(prompt_file, 'r', encoding='utf-8') as f:
  269. content = f.read().strip()
  270. if not content:
  271. raise ValueError("业务上下文文件为空")
  272. return content
  273. except FileNotFoundError:
  274. error_msg = f"无法找到业务上下文文件: {prompt_file}"
  275. self.logger.error(error_msg)
  276. raise FileNotFoundError(error_msg)
  277. except Exception as e:
  278. error_msg = f"读取业务上下文文件失败: {str(e)}"
  279. self.logger.error(error_msg)
  280. raise RuntimeError(error_msg)
  281. def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
  282. """增强的LLM分类:包含详细的业务上下文"""
  283. try:
  284. from common.vanna_instance import get_vanna_instance
  285. vn = get_vanna_instance()
  286. # 动态加载业务上下文(如果失败会抛出异常)
  287. business_context = self._load_business_context()
  288. # 构建包含业务上下文的分类提示词
  289. classification_prompt = f"""
  290. 请判断以下用户问题是否需要查询我们的数据库。
  291. 用户问题:{question}
  292. {business_context}
  293. === 判断标准 ===
  294. 1. **DATABASE类型** - 需要查询数据库:
  295. - 涉及上述业务实体和指标的查询、统计、分析、报表
  296. - 包含业务相关的时间查询
  297. - 例如:业务数据统计、收入排行、流量分析、占比分析等
  298. 2. **CHAT类型** - 不需要查询数据库:
  299. - 生活常识:水果蔬菜上市时间、动植物知识、天气等
  300. - 身份询问:你是谁、什么模型、AI助手等
  301. - 技术概念:人工智能、编程、算法等
  302. - 平台使用:功能介绍、操作帮助、使用教程等
  303. - 旅游出行:旅游景点、酒店、机票、高铁、的士等
  304. - 情绪:开心、伤心、无聊、生气、孤独、累了、烦恼、心情、难过、抑郁
  305. - 商业:股票、基金、理财、投资、经济、通货膨胀、上市
  306. - 哲学:人生意义、价值观、道德、信仰、宗教、爱情
  307. - 政策:政策、法规、法律、条例、指南、手册、规章制度、实施细则
  308. - 地理:全球、中国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
  309. - 体育:足球、NBA、篮球、乒乓球、冠军、夺冠
  310. - 文学:小说、新闻、政治、战争、足球、NBA、篮球、乒乓球、冠军、夺冠
  311. - 娱乐:游戏、小说、新闻、政治、战争、足球、NBA、篮球、乒乓球、冠军、夺冠、电影、电视剧、音乐、舞蹈、绘画、书法、摄影、雕塑、建筑、设计、
  312. - 健康:健康、医疗、病症、健康、饮食、睡眠、心理、养生、减肥、美容、护肤
  313. - 其他:高考、人生意义、价值观、道德、信仰、宗教、爱情、全球、全国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
  314. - 例如:"荔枝几月份上市"、"今天天气如何"、"你是什么AI"、"怎么使用平台"
  315. **重要提示:**
  316. - 只有涉及高速公路服务区业务数据的问题才分类为DATABASE
  317. - 只要不是涉及高速公路服务区业务数据的问题都应分类为CHAT
  318. 请基于问题与我们高速公路服务区业务数据的相关性来分类。
  319. 格式:
  320. 分类: [DATABASE/CHAT]
  321. 理由: [详细说明问题与业务数据的相关性,具体分析涉及哪些业务实体或为什么不相关]
  322. 置信度: [0.0-1.0之间的数字]
  323. """
  324. # 专业的系统提示词
  325. system_prompt = """你是一个专业的业务问题分类助手。你具有以下特长:
  326. 1. 深度理解业务领域和数据范围
  327. 2. 准确区分业务数据查询需求和一般性问题
  328. 3. 基于具体业务上下文进行精准分类,而不仅仅依赖关键词匹配
  329. 4. 对边界情况能够给出合理的置信度评估
  330. 请严格按照业务相关性进行分类,并提供详细的分类理由。"""
  331. # 使用 Vanna 实例的 chat_with_llm 方法
  332. response = vn.chat_with_llm(
  333. question=classification_prompt,
  334. system_prompt=system_prompt
  335. )
  336. self.logger.debug(f"LLM原始分类响应信息: {response}")
  337. # 解析响应
  338. return self._parse_llm_response(response)
  339. except (FileNotFoundError, RuntimeError) as e:
  340. # 业务上下文加载失败,返回错误状态
  341. self.logger.error(f"LLM分类失败,业务上下文不可用: {str(e)}")
  342. return ClassificationResult(
  343. question_type="CHAT", # 失败时默认为CHAT,更安全
  344. confidence=self.weights.get('llm_error_confidence', 0.1), # 使用YAML配置的低置信度
  345. reason=f"业务上下文加载失败,无法进行准确分类: {str(e)}",
  346. method="llm_context_error"
  347. )
  348. except Exception as e:
  349. self.logger.warning(f"增强LLM分类失败: {str(e)}")
  350. return ClassificationResult(
  351. question_type="CHAT", # 失败时默认为CHAT,更安全
  352. confidence=self.llm_fallback_confidence,
  353. reason=f"LLM分类异常,默认为聊天: {str(e)}",
  354. method="llm_error"
  355. )
  356. def _parse_llm_response(self, response: str) -> ClassificationResult:
  357. """解析LLM响应"""
  358. try:
  359. lines = response.strip().split('\n')
  360. question_type = "CHAT" # 默认为CHAT
  361. reason = "LLM响应解析失败"
  362. confidence = self.llm_fallback_confidence
  363. for line in lines:
  364. line = line.strip()
  365. if line.startswith("分类:") or line.startswith("Classification:"):
  366. type_part = line.split(":", 1)[1].strip().upper()
  367. if "DATABASE" in type_part:
  368. question_type = "DATABASE"
  369. elif "CHAT" in type_part:
  370. question_type = "CHAT"
  371. elif line.startswith("理由:") or line.startswith("Reason:"):
  372. reason = line.split(":", 1)[1].strip()
  373. elif line.startswith("置信度:") or line.startswith("Confidence:"):
  374. try:
  375. conf_str = line.split(":", 1)[1].strip()
  376. confidence = float(conf_str)
  377. # 确保置信度在合理范围内
  378. confidence = max(0.0, min(1.0, confidence))
  379. except:
  380. confidence = self.llm_fallback_confidence
  381. return ClassificationResult(
  382. question_type=question_type,
  383. confidence=confidence,
  384. reason=reason,
  385. method="enhanced_llm"
  386. )
  387. except Exception as e:
  388. return ClassificationResult(
  389. question_type="CHAT", # 解析失败时默认为CHAT
  390. confidence=self.llm_fallback_confidence,
  391. reason=f"响应解析失败: {str(e)}",
  392. method="llm_parse_error"
  393. )