# theme_extractor.py
import asyncio
import json
import logging
import re
from typing import List, Dict, Any

from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  6. class ThemeExtractor:
  7. """主题提取器"""
  8. def __init__(self, vn, business_context: str):
  9. """
  10. 初始化主题提取器
  11. Args:
  12. vn: vanna实例
  13. business_context: 业务上下文
  14. """
  15. self.vn = vn
  16. self.business_context = business_context
  17. self.logger = logging.getLogger("schema_tools.ThemeExtractor")
  18. self.config = SCHEMA_TOOLS_CONFIG
  19. async def extract_themes(self, md_contents: str) -> List[Dict[str, Any]]:
  20. """
  21. 从MD内容中提取分析主题
  22. Args:
  23. md_contents: 所有MD文件的组合内容
  24. Returns:
  25. 主题列表
  26. """
  27. theme_count = self.config['qs_generation']['theme_count']
  28. prompt = self._build_theme_extraction_prompt(md_contents, theme_count)
  29. try:
  30. # 调用LLM提取主题
  31. response = await self._call_llm(prompt)
  32. # 解析响应
  33. themes = self._parse_theme_response(response)
  34. self.logger.info(f"成功提取 {len(themes)} 个分析主题")
  35. return themes
  36. except Exception as e:
  37. self.logger.error(f"主题提取失败: {e}")
  38. raise
  39. def _build_theme_extraction_prompt(self, md_contents: str, theme_count: int) -> str:
  40. """构建主题提取的prompt"""
  41. prompt = f"""你是一位经验丰富的业务数据分析师,正在分析{self.business_context}的数据库。
  42. 以下是数据库中所有表的详细结构说明:
  43. {md_contents}
  44. 基于对这些表结构的理解,请从业务分析的角度提出 {theme_count} 个数据查询分析主题。
  45. 要求:
  46. 1. 每个主题应该有明确的业务价值和分析目标
  47. 2. 主题之间应该有所区别,覆盖不同的业务领域
  48. 3. 你需要自行决定每个主题应该涉及哪些表(使用实际存在的表名)
  49. 4. 主题应该体现实际业务场景的数据分析需求
  50. 5. 考虑时间维度、对比分析、排名统计等多种分析角度
  51. 6. 在选择业务实体时,请忽略以下技术性字段:
  52. - id、主键ID等标识字段
  53. - create_time、created_at、create_ts等创建时间字段
  54. - update_time、updated_at、update_ts等更新时间字段
  55. - delete_time、deleted_at、delete_ts等删除时间字段
  56. - version、版本号等版本控制字段
  57. - created_by、updated_by、deleted_by等操作人字段
  58. 7. 重点关注具有业务含义的实体字段和指标
  59. 请以JSON格式输出:
  60. ```json
  61. {{
  62. "themes": [
  63. {{
  64. "topic_name": "日营业数据分析",
  65. "description": "基于 bss_business_day_data 表,分析每个服务区和档口每天的营业收入、订单数量、支付方式等",
  66. "related_tables": ["bss_business_day_data", "bss_branch", "bss_service_area"],
  67. "biz_entities": ["服务区", "档口", "支付方式", "营收"],
  68. "biz_metrics": ["收入趋势", "服务区对比", "支付方式分布"]
  69. }}
  70. ]
  71. }}
  72. ```
  73. 请确保:
  74. - topic_name 简洁明了(10字以内)
  75. - description 详细说明分析目标和价值(50字左右)
  76. - related_tables 列出该主题需要用到的表名(数组格式)
  77. - biz_entities 列出3-5个主要业务实体(表的维度字段或非数值型字段,如服务区、公司、车辆等)
  78. - biz_metrics 列出3-5个主要业务指标名称(统计指标,如收入趋势、对比分析等)"""
  79. return prompt
  80. async def _call_llm(self, prompt: str) -> str:
  81. """调用LLM"""
  82. try:
  83. # 使用vanna的chat_with_llm方法
  84. response = await asyncio.to_thread(
  85. self.vn.chat_with_llm,
  86. question=prompt,
  87. system_prompt="你是一个专业的数据分析师,擅长从业务角度设计数据分析主题和查询方案。请严格按照要求的JSON格式输出。"
  88. )
  89. if not response or not response.strip():
  90. raise ValueError("LLM返回空响应")
  91. return response.strip()
  92. except Exception as e:
  93. self.logger.error(f"LLM调用失败: {e}")
  94. raise
  95. def _parse_theme_response(self, response: str) -> List[Dict[str, Any]]:
  96. """解析LLM的主题响应"""
  97. try:
  98. # 提取JSON部分
  99. import re
  100. json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
  101. if json_match:
  102. json_str = json_match.group(1)
  103. else:
  104. # 尝试直接解析
  105. json_str = response
  106. # 解析JSON
  107. data = json.loads(json_str)
  108. themes = data.get('themes', [])
  109. # 验证和标准化主题格式
  110. validated_themes = []
  111. for theme in themes:
  112. # 兼容旧格式(name -> topic_name)
  113. if 'name' in theme and 'topic_name' not in theme:
  114. theme['topic_name'] = theme['name']
  115. # 验证必需字段
  116. required_fields = ['topic_name', 'description', 'related_tables']
  117. if all(key in theme for key in required_fields):
  118. # 确保related_tables是数组
  119. if isinstance(theme['related_tables'], str):
  120. theme['related_tables'] = [theme['related_tables']]
  121. # 确保biz_entities存在且是数组
  122. if 'biz_entities' not in theme:
  123. # 从description中提取业务实体
  124. theme['biz_entities'] = self._extract_biz_entities_from_description(theme['description'])
  125. elif isinstance(theme['biz_entities'], str):
  126. theme['biz_entities'] = [theme['biz_entities']]
  127. # 确保biz_metrics存在且是数组
  128. if 'biz_metrics' not in theme:
  129. # 从description中提取业务指标
  130. theme['biz_metrics'] = self._extract_biz_metrics_from_description(theme['description'])
  131. elif isinstance(theme['biz_metrics'], str):
  132. theme['biz_metrics'] = [theme['biz_metrics']]
  133. validated_themes.append(theme)
  134. else:
  135. self.logger.warning(f"主题格式不完整,跳过: {theme.get('topic_name', 'Unknown')}")
  136. return validated_themes
  137. except json.JSONDecodeError as e:
  138. self.logger.error(f"JSON解析失败: {e}")
  139. self.logger.debug(f"原始响应: {response}")
  140. raise ValueError(f"无法解析LLM响应为JSON格式: {e}")
  141. except Exception as e:
  142. self.logger.error(f"解析主题响应失败: {e}")
  143. raise
  144. def _extract_biz_entities_from_description(self, description: str) -> List[str]:
  145. """从描述中提取业务实体(简单实现)"""
  146. # 定义常见的业务实体关键词
  147. entity_keywords = [
  148. "服务区", "档口", "商品", "公司", "分公司", "车辆", "支付方式",
  149. "订单", "客户", "营收", "路段", "区域", "品牌", "品类"
  150. ]
  151. # 从描述中查找出现的实体关键词
  152. found_entities = []
  153. for entity in entity_keywords:
  154. if entity in description:
  155. found_entities.append(entity)
  156. # 如果找到的太少,返回默认值
  157. if len(found_entities) < 3:
  158. found_entities = ["业务实体", "数据对象", "分析主体"]
  159. return found_entities[:5] # 最多返回5个
  160. def _extract_biz_metrics_from_description(self, description: str) -> List[str]:
  161. """从描述中提取业务指标(简单实现)"""
  162. # 定义常见的业务指标关键词
  163. metrics_keywords = [
  164. "收入趋势", "营业额对比", "支付方式分布", "服务区对比", "增长率",
  165. "占比分析", "排名统计", "效率评估", "流量分析", "转化率"
  166. ]
  167. # 从描述中查找出现的指标关键词
  168. found_metrics = []
  169. for metric in metrics_keywords:
  170. if any(word in description for word in metric.split()):
  171. found_metrics.append(metric)
  172. # 如果找到的太少,返回默认值
  173. if len(found_metrics) < 3:
  174. found_metrics = ["数据统计", "趋势分析", "对比分析"]
  175. return found_metrics[:5] # 最多返回5个