dict_loader.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # agent/dict_loader.py
  2. """
  3. 分类器词典配置加载器
  4. 负责从YAML文件加载分类器词典配置,并提供数据转换和验证功能
  5. """
import os
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import yaml

from core.logging import get_agent_logger
  12. # 初始化日志 [[memory:3840221]]
  13. logger = get_agent_logger("DictLoader")
  14. @dataclass
  15. class ClassifierDictConfig:
  16. """分类器词典配置数据类"""
  17. strong_business_keywords: Dict[str, List[str]]
  18. query_intent_keywords: List[str]
  19. non_business_keywords: List[str]
  20. sql_patterns: List[str]
  21. chat_keywords: List[str]
  22. weights: Dict[str, float]
  23. metadata: Dict[str, Any]
  24. class DictLoader:
  25. """分类器词典配置加载器"""
  26. def __init__(self, dict_file: str = None):
  27. """
  28. 初始化加载器
  29. Args:
  30. dict_file: 词典配置文件路径,默认为agent/classifier_dict.yaml
  31. """
  32. if dict_file is None:
  33. current_dir = os.path.dirname(os.path.abspath(__file__))
  34. dict_file = os.path.join(current_dir, "classifier_dict.yaml")
  35. self.dict_file = dict_file
  36. self.config_cache = None
  37. def load_config(self, force_reload: bool = False) -> ClassifierDictConfig:
  38. """
  39. 加载词典配置
  40. Args:
  41. force_reload: 是否强制重新加载,默认使用缓存
  42. Returns:
  43. ClassifierDictConfig: 词典配置对象
  44. Raises:
  45. FileNotFoundError: 配置文件不存在
  46. ValueError: 配置文件格式错误
  47. """
  48. if self.config_cache is not None and not force_reload:
  49. return self.config_cache
  50. try:
  51. logger.info(f"加载词典配置文件: {self.dict_file}")
  52. with open(self.dict_file, 'r', encoding='utf-8') as f:
  53. yaml_data = yaml.safe_load(f)
  54. # 验证配置文件
  55. self._validate_config(yaml_data)
  56. # 转换数据格式
  57. config = self._convert_config(yaml_data)
  58. # 缓存配置
  59. self.config_cache = config
  60. logger.info("词典配置加载成功")
  61. return config
  62. except FileNotFoundError:
  63. error_msg = f"词典配置文件不存在: {self.dict_file}"
  64. logger.error(error_msg)
  65. raise FileNotFoundError(error_msg)
  66. except yaml.YAMLError as e:
  67. error_msg = f"词典配置文件YAML格式错误: {str(e)}"
  68. logger.error(error_msg)
  69. raise ValueError(error_msg)
  70. except Exception as e:
  71. error_msg = f"词典配置加载失败: {str(e)}"
  72. logger.error(error_msg)
  73. raise ValueError(error_msg)
  74. def _validate_config(self, yaml_data: Dict[str, Any]) -> None:
  75. """验证配置文件格式和必要字段"""
  76. required_sections = [
  77. 'strong_business_keywords',
  78. 'query_intent_keywords',
  79. 'non_business_keywords',
  80. 'sql_patterns',
  81. 'chat_keywords',
  82. 'weights'
  83. ]
  84. for section in required_sections:
  85. if section not in yaml_data:
  86. raise ValueError(f"配置文件缺少必要部分: {section}")
  87. # 验证权重配置
  88. required_weights = [
  89. 'business_entity',
  90. 'system_indicator',
  91. 'query_intent',
  92. 'sql_pattern',
  93. 'chat_keyword',
  94. 'non_business_confidence',
  95. 'high_confidence_threshold',
  96. 'max_confidence',
  97. 'llm_fallback_confidence',
  98. 'uncertain_confidence',
  99. 'llm_error_confidence'
  100. ]
  101. for weight in required_weights:
  102. if weight not in yaml_data['weights']:
  103. raise ValueError(f"权重配置缺少: {weight}")
  104. logger.debug("配置文件验证通过")
  105. def _convert_config(self, yaml_data: Dict[str, Any]) -> ClassifierDictConfig:
  106. """将YAML数据转换为ClassifierDictConfig对象"""
  107. # 转换强业务关键词(保持字典结构)
  108. strong_business_keywords = {}
  109. for category, data in yaml_data['strong_business_keywords'].items():
  110. if isinstance(data, dict) and 'keywords' in data:
  111. strong_business_keywords[category] = data['keywords']
  112. else:
  113. # 兼容简单格式
  114. strong_business_keywords[category] = data
  115. # 转换查询意图关键词
  116. query_intent_data = yaml_data['query_intent_keywords']
  117. if isinstance(query_intent_data, dict) and 'keywords' in query_intent_data:
  118. query_intent_keywords = query_intent_data['keywords']
  119. else:
  120. query_intent_keywords = query_intent_data
  121. # 转换非业务实体词(展平为列表)
  122. non_business_keywords = self._flatten_non_business_keywords(
  123. yaml_data['non_business_keywords']
  124. )
  125. # 转换SQL模式
  126. sql_patterns = []
  127. patterns_data = yaml_data['sql_patterns']
  128. if isinstance(patterns_data, dict) and 'patterns' in patterns_data:
  129. for pattern_info in patterns_data['patterns']:
  130. if isinstance(pattern_info, dict):
  131. sql_patterns.append(pattern_info['pattern'])
  132. else:
  133. sql_patterns.append(pattern_info)
  134. else:
  135. sql_patterns = patterns_data
  136. # 转换其他关键词列表
  137. chat_keywords = self._extract_keywords_list(yaml_data['chat_keywords'])
  138. return ClassifierDictConfig(
  139. strong_business_keywords=strong_business_keywords,
  140. query_intent_keywords=query_intent_keywords,
  141. non_business_keywords=non_business_keywords,
  142. sql_patterns=sql_patterns,
  143. chat_keywords=chat_keywords,
  144. weights=yaml_data['weights'],
  145. metadata=yaml_data.get('metadata', {})
  146. )
  147. def _flatten_non_business_keywords(self, non_business_data: Dict[str, Any]) -> List[str]:
  148. """将分类的非业务词展平为列表"""
  149. flattened = []
  150. # 跳过description字段
  151. for category, keywords in non_business_data.items():
  152. if category == 'description':
  153. continue
  154. if isinstance(keywords, list):
  155. flattened.extend(keywords)
  156. return flattened
  157. def _extract_keywords_list(self, data: Any) -> List[str]:
  158. """从可能包含description的数据中提取关键词列表"""
  159. if isinstance(data, dict) and 'keywords' in data:
  160. return data['keywords']
  161. elif isinstance(data, list):
  162. return data
  163. else:
  164. return []
  165. # 全局加载器实例
  166. _dict_loader = None
  167. def get_dict_loader() -> DictLoader:
  168. """获取全局词典加载器实例"""
  169. global _dict_loader
  170. if _dict_loader is None:
  171. _dict_loader = DictLoader()
  172. return _dict_loader
  173. def load_classifier_dict_config(force_reload: bool = False) -> ClassifierDictConfig:
  174. """
  175. 加载分类器词典配置(便捷函数)
  176. Args:
  177. force_reload: 是否强制重新加载
  178. Returns:
  179. ClassifierDictConfig: 词典配置对象
  180. """
  181. loader = get_dict_loader()
  182. return loader.load_config(force_reload)