# agent/classifier.py
  2. import re
  3. from typing import Dict, Any, List
  4. from dataclasses import dataclass
  5. @dataclass
  6. class ClassificationResult:
  7. question_type: str
  8. confidence: float
  9. reason: str
  10. method: str
  11. class QuestionClassifier:
  12. """
  13. 多策略融合的问题分类器
  14. 策略:规则优先 + LLM fallback
  15. """
  16. def __init__(self):
  17. # 从配置文件加载阈值参数
  18. try:
  19. from agent.config import get_current_config, get_nested_config
  20. config = get_current_config()
  21. self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.8)
  22. self.low_confidence_threshold = get_nested_config(config, "classification.low_confidence_threshold", 0.4)
  23. self.max_confidence = get_nested_config(config, "classification.max_confidence", 0.9)
  24. self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.5)
  25. self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.1)
  26. self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
  27. self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
  28. print("[CLASSIFIER] 从配置文件加载分类器参数完成")
  29. except ImportError:
  30. # 配置文件不可用时的默认值
  31. self.high_confidence_threshold = 0.8
  32. self.low_confidence_threshold = 0.4
  33. self.max_confidence = 0.9
  34. self.base_confidence = 0.5
  35. self.confidence_increment = 0.1
  36. self.llm_fallback_confidence = 0.5
  37. self.uncertain_confidence = 0.2
  38. print("[CLASSIFIER] 配置文件不可用,使用默认分类器参数")
  39. # 移除了 LLM 实例存储,现在使用 Vanna 实例
  40. self.db_keywords = {
  41. "数据类": [
  42. "收入", "销量", "数量", "平均", "总计", "统计", "合计", "累计",
  43. "营业额", "利润", "成本", "费用", "金额", "价格", "单价", "服务区", "多少个"
  44. ],
  45. "分析类": [
  46. "分组", "排行", "排名", "增长率", "趋势", "对比", "比较", "占比",
  47. "百分比", "比例", "环比", "同比", "最大", "最小", "最高", "最低"
  48. ],
  49. "时间类": [
  50. "今天", "昨天", "本月", "上月", "去年", "季度", "年度", "月份",
  51. "本年", "上年", "本周", "上周", "近期", "最近"
  52. ],
  53. "业务类": [
  54. "客户", "订单", "产品", "商品", "用户", "会员", "供应商", "库存",
  55. "部门", "员工", "项目", "合同", "发票", "账单"
  56. ]
  57. }
  58. # SQL关键词
  59. self.sql_patterns = [
  60. r"\b(select|from|where|group by|order by|having|join)\b",
  61. r"\b(查询|统计|汇总|计算|分析|有多少)\b",
  62. r"\b(表|字段|数据库)\b"
  63. ]
  64. # 聊天关键词
  65. self.chat_keywords = [
  66. "你好", "谢谢", "再见", "怎么样", "如何", "为什么", "什么是",
  67. "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能"
  68. ]
  69. def classify(self, question: str) -> ClassificationResult:
  70. """
  71. 主分类方法:规则优先 + LLM fallback
  72. """
  73. # 第一步:规则分类
  74. rule_result = self._rule_based_classify(question)
  75. if rule_result.confidence >= self.high_confidence_threshold:
  76. return rule_result
  77. # 第二步:LLM分类(针对不确定的情况)
  78. if rule_result.confidence <= self.low_confidence_threshold:
  79. llm_result = self._llm_classify(question)
  80. # 如果LLM也不确定,返回不确定状态
  81. if llm_result.confidence <= self.low_confidence_threshold:
  82. return ClassificationResult(
  83. question_type="UNCERTAIN",
  84. confidence=max(rule_result.confidence, llm_result.confidence),
  85. reason=f"规则和LLM都不确定: {rule_result.reason} | {llm_result.reason}",
  86. method="hybrid_uncertain"
  87. )
  88. return llm_result
  89. return rule_result
  90. def _rule_based_classify(self, question: str) -> ClassificationResult:
  91. """基于规则的分类"""
  92. question_lower = question.lower()
  93. # 检查数据库相关关键词
  94. db_score = 0
  95. matched_keywords = []
  96. for category, keywords in self.db_keywords.items():
  97. for keyword in keywords:
  98. if keyword in question_lower:
  99. db_score += 1
  100. matched_keywords.append(f"{category}:{keyword}")
  101. # 检查SQL模式
  102. sql_patterns_matched = []
  103. for pattern in self.sql_patterns:
  104. if re.search(pattern, question_lower, re.IGNORECASE):
  105. db_score += 2 # SQL模式权重更高
  106. sql_patterns_matched.append(pattern)
  107. # 检查聊天关键词
  108. chat_score = 0
  109. chat_keywords_matched = []
  110. for keyword in self.chat_keywords:
  111. if keyword in question_lower:
  112. chat_score += 1
  113. chat_keywords_matched.append(keyword)
  114. # 计算置信度和分类
  115. total_score = db_score + chat_score
  116. if db_score > chat_score and db_score >= 1:
  117. confidence = min(self.max_confidence, self.base_confidence + (db_score * self.confidence_increment))
  118. return ClassificationResult(
  119. question_type="DATABASE",
  120. confidence=confidence,
  121. reason=f"匹配数据库关键词: {matched_keywords}, SQL模式: {sql_patterns_matched}",
  122. method="rule_based"
  123. )
  124. elif chat_score > db_score and chat_score >= 1:
  125. confidence = min(self.max_confidence, self.base_confidence + (chat_score * self.confidence_increment))
  126. return ClassificationResult(
  127. question_type="CHAT",
  128. confidence=confidence,
  129. reason=f"匹配聊天关键词: {chat_keywords_matched}",
  130. method="rule_based"
  131. )
  132. else:
  133. # 没有明确匹配
  134. return ClassificationResult(
  135. question_type="UNCERTAIN",
  136. confidence=self.uncertain_confidence,
  137. reason="没有匹配到明确的关键词模式",
  138. method="rule_based"
  139. )
  140. def _llm_classify(self, question: str) -> ClassificationResult:
  141. """基于LLM的分类"""
  142. try:
  143. # 使用 Vanna 实例进行分类
  144. from common.vanna_instance import get_vanna_instance
  145. vn = get_vanna_instance()
  146. # 分类提示词
  147. classification_prompt = f"""
  148. 请判断以下问题是否需要查询数据库。
  149. 问题: {question}
  150. 判断标准:
  151. 1. 如果问题涉及数据查询、统计、分析、报表等,返回 "DATABASE"
  152. 2. 如果问题是一般性咨询、概念解释、操作指导、闲聊等,返回 "CHAT"
  153. 请只返回 "DATABASE" 或 "CHAT",并在下一行简要说明理由。
  154. 格式:
  155. 分类: [DATABASE/CHAT]
  156. 理由: [简要说明]
  157. 置信度: [0.0-1.0之间的数字]
  158. """
  159. # 分类专用的系统提示词
  160. system_prompt = "你是一个专业的问题分类助手,能准确判断问题类型。请严格按照要求的格式返回分类结果。"
  161. # 使用 Vanna 实例的 chat_with_llm 方法
  162. response = vn.chat_with_llm(
  163. question=classification_prompt,
  164. system_prompt=system_prompt
  165. )
  166. # 解析响应
  167. return self._parse_llm_response(response)
  168. except Exception as e:
  169. print(f"[WARNING] LLM分类失败: {str(e)}")
  170. return ClassificationResult(
  171. question_type="UNCERTAIN",
  172. confidence=self.llm_fallback_confidence,
  173. reason=f"LLM分类异常: {str(e)}",
  174. method="llm_error"
  175. )
  176. def _parse_llm_response(self, response: str) -> ClassificationResult:
  177. """解析LLM响应"""
  178. try:
  179. lines = response.strip().split('\n')
  180. question_type = "UNCERTAIN"
  181. reason = "LLM响应解析失败"
  182. confidence = self.llm_fallback_confidence
  183. for line in lines:
  184. line = line.strip()
  185. if line.startswith("分类:") or line.startswith("Classification:"):
  186. type_part = line.split(":", 1)[1].strip().upper()
  187. if "DATABASE" in type_part:
  188. question_type = "DATABASE"
  189. elif "CHAT" in type_part:
  190. question_type = "CHAT"
  191. elif line.startswith("理由:") or line.startswith("Reason:"):
  192. reason = line.split(":", 1)[1].strip()
  193. elif line.startswith("置信度:") or line.startswith("Confidence:"):
  194. try:
  195. conf_str = line.split(":", 1)[1].strip()
  196. confidence = float(conf_str)
  197. except:
  198. confidence = self.llm_fallback_confidence
  199. return ClassificationResult(
  200. question_type=question_type,
  201. confidence=confidence,
  202. reason=reason,
  203. method="llm_based"
  204. )
  205. except Exception as e:
  206. return ClassificationResult(
  207. question_type="UNCERTAIN",
  208. confidence=self.llm_fallback_confidence,
  209. reason=f"响应解析失败: {str(e)}",
  210. method="llm_parse_error"
  211. )