# agent/classifier.py
import math
import re
from dataclasses import dataclass
from typing import Dict, Any, List
  5. @dataclass
  6. class ClassificationResult:
  7. question_type: str
  8. confidence: float
  9. reason: str
  10. method: str
  11. class QuestionClassifier:
  12. """
  13. 多策略融合的问题分类器
  14. 策略:规则优先 + LLM fallback
  15. """
  16. def __init__(self):
  17. # 从配置文件加载阈值参数
  18. try:
  19. from agent.config import get_current_config, get_nested_config
  20. config = get_current_config()
  21. self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.8)
  22. self.low_confidence_threshold = get_nested_config(config, "classification.low_confidence_threshold", 0.4)
  23. self.max_confidence = get_nested_config(config, "classification.max_confidence", 0.9)
  24. self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.5)
  25. self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.1)
  26. self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
  27. self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
  28. print("[CLASSIFIER] 从配置文件加载分类器参数完成")
  29. except ImportError:
  30. # 配置文件不可用时的默认值
  31. self.high_confidence_threshold = 0.8
  32. self.low_confidence_threshold = 0.4
  33. self.max_confidence = 0.9
  34. self.base_confidence = 0.5
  35. self.confidence_increment = 0.1
  36. self.llm_fallback_confidence = 0.5
  37. self.uncertain_confidence = 0.2
  38. print("[CLASSIFIER] 配置文件不可用,使用默认分类器参数")
  39. self.db_keywords = {
  40. "数据类": [
  41. "收入", "销量", "数量", "平均", "总计", "统计", "合计", "累计",
  42. "营业额", "利润", "成本", "费用", "金额", "价格", "单价"
  43. ],
  44. "分析类": [
  45. "分组", "排行", "排名", "增长率", "趋势", "对比", "比较", "占比",
  46. "百分比", "比例", "环比", "同比", "最大", "最小", "最高", "最低"
  47. ],
  48. "时间类": [
  49. "今天", "昨天", "本月", "上月", "去年", "季度", "年度", "月份",
  50. "本年", "上年", "本周", "上周", "近期", "最近"
  51. ],
  52. "业务类": [
  53. "客户", "订单", "产品", "商品", "用户", "会员", "供应商", "库存",
  54. "部门", "员工", "项目", "合同", "发票", "账单"
  55. ]
  56. }
  57. # SQL关键词
  58. self.sql_patterns = [
  59. r"\b(select|from|where|group by|order by|having|join)\b",
  60. r"\b(查询|统计|汇总|计算|分析)\b",
  61. r"\b(表|字段|数据库)\b"
  62. ]
  63. # 聊天关键词
  64. self.chat_keywords = [
  65. "你好", "谢谢", "再见", "怎么样", "如何", "为什么", "什么是",
  66. "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能"
  67. ]
  68. def classify(self, question: str) -> ClassificationResult:
  69. """
  70. 主分类方法:规则优先 + LLM fallback
  71. """
  72. # 第一步:规则分类
  73. rule_result = self._rule_based_classify(question)
  74. if rule_result.confidence >= self.high_confidence_threshold:
  75. return rule_result
  76. # 第二步:LLM分类(针对不确定的情况)
  77. if rule_result.confidence <= self.low_confidence_threshold:
  78. llm_result = self._llm_classify(question)
  79. # 如果LLM也不确定,返回不确定状态
  80. if llm_result.confidence <= self.low_confidence_threshold:
  81. return ClassificationResult(
  82. question_type="UNCERTAIN",
  83. confidence=max(rule_result.confidence, llm_result.confidence),
  84. reason=f"规则和LLM都不确定: {rule_result.reason} | {llm_result.reason}",
  85. method="hybrid_uncertain"
  86. )
  87. return llm_result
  88. return rule_result
  89. def _rule_based_classify(self, question: str) -> ClassificationResult:
  90. """基于规则的分类"""
  91. question_lower = question.lower()
  92. # 检查数据库相关关键词
  93. db_score = 0
  94. matched_keywords = []
  95. for category, keywords in self.db_keywords.items():
  96. for keyword in keywords:
  97. if keyword in question_lower:
  98. db_score += 1
  99. matched_keywords.append(f"{category}:{keyword}")
  100. # 检查SQL模式
  101. sql_patterns_matched = []
  102. for pattern in self.sql_patterns:
  103. if re.search(pattern, question_lower, re.IGNORECASE):
  104. db_score += 2 # SQL模式权重更高
  105. sql_patterns_matched.append(pattern)
  106. # 检查聊天关键词
  107. chat_score = 0
  108. chat_keywords_matched = []
  109. for keyword in self.chat_keywords:
  110. if keyword in question_lower:
  111. chat_score += 1
  112. chat_keywords_matched.append(keyword)
  113. # 计算置信度和分类
  114. total_score = db_score + chat_score
  115. if db_score > chat_score and db_score >= 1:
  116. confidence = min(self.max_confidence, self.base_confidence + (db_score * self.confidence_increment))
  117. return ClassificationResult(
  118. question_type="DATABASE",
  119. confidence=confidence,
  120. reason=f"匹配数据库关键词: {matched_keywords}, SQL模式: {sql_patterns_matched}",
  121. method="rule_based"
  122. )
  123. elif chat_score > db_score and chat_score >= 1:
  124. confidence = min(self.max_confidence, self.base_confidence + (chat_score * self.confidence_increment))
  125. return ClassificationResult(
  126. question_type="CHAT",
  127. confidence=confidence,
  128. reason=f"匹配聊天关键词: {chat_keywords_matched}",
  129. method="rule_based"
  130. )
  131. else:
  132. # 没有明确匹配
  133. return ClassificationResult(
  134. question_type="UNCERTAIN",
  135. confidence=self.uncertain_confidence,
  136. reason="没有匹配到明确的关键词模式",
  137. method="rule_based"
  138. )
  139. def _llm_classify(self, question: str) -> ClassificationResult:
  140. """基于LLM的分类"""
  141. try:
  142. from common.utils import get_current_llm_config
  143. from customllm.qianwen_chat import QianWenChat
  144. llm_config = get_current_llm_config()
  145. llm = QianWenChat(config=llm_config)
  146. # 分类提示词
  147. classification_prompt = f"""
  148. 请判断以下问题是否需要查询数据库。
  149. 问题: {question}
  150. 判断标准:
  151. 1. 如果问题涉及数据查询、统计、分析、报表等,返回 "DATABASE"
  152. 2. 如果问题是一般性咨询、概念解释、操作指导、闲聊等,返回 "CHAT"
  153. 请只返回 "DATABASE" 或 "CHAT",并在下一行简要说明理由。
  154. 格式:
  155. 分类: [DATABASE/CHAT]
  156. 理由: [简要说明]
  157. 置信度: [0.0-1.0之间的数字]
  158. """
  159. prompt = [
  160. llm.system_message("你是一个专业的问题分类助手,能准确判断问题类型。"),
  161. llm.user_message(classification_prompt)
  162. ]
  163. response = llm.submit_prompt(prompt)
  164. # 解析响应
  165. return self._parse_llm_response(response)
  166. except Exception as e:
  167. print(f"[WARNING] LLM分类失败: {str(e)}")
  168. return ClassificationResult(
  169. question_type="UNCERTAIN",
  170. confidence=self.llm_fallback_confidence,
  171. reason=f"LLM分类异常: {str(e)}",
  172. method="llm_error"
  173. )
  174. def _parse_llm_response(self, response: str) -> ClassificationResult:
  175. """解析LLM响应"""
  176. try:
  177. lines = response.strip().split('\n')
  178. question_type = "UNCERTAIN"
  179. reason = "LLM响应解析失败"
  180. confidence = self.llm_fallback_confidence
  181. for line in lines:
  182. line = line.strip()
  183. if line.startswith("分类:") or line.startswith("Classification:"):
  184. type_part = line.split(":", 1)[1].strip().upper()
  185. if "DATABASE" in type_part:
  186. question_type = "DATABASE"
  187. elif "CHAT" in type_part:
  188. question_type = "CHAT"
  189. elif line.startswith("理由:") or line.startswith("Reason:"):
  190. reason = line.split(":", 1)[1].strip()
  191. elif line.startswith("置信度:") or line.startswith("Confidence:"):
  192. try:
  193. conf_str = line.split(":", 1)[1].strip()
  194. confidence = float(conf_str)
  195. except:
  196. confidence = self.llm_fallback_confidence
  197. return ClassificationResult(
  198. question_type=question_type,
  199. confidence=confidence,
  200. reason=reason,
  201. method="llm_based"
  202. )
  203. except Exception as e:
  204. return ClassificationResult(
  205. question_type="UNCERTAIN",
  206. confidence=self.llm_fallback_confidence,
  207. reason=f"响应解析失败: {str(e)}",
  208. method="llm_parse_error"
  209. )