qs_agent.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import asyncio
  2. import json
  3. import logging
  4. import time
  5. from datetime import datetime
  6. from pathlib import Path
  7. from typing import List, Dict, Any, Optional
  8. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  9. from schema_tools.validators import FileCountValidator
  10. from schema_tools.analyzers import MDFileAnalyzer, ThemeExtractor
  11. from schema_tools.utils.logger import setup_logging
  12. from core.vanna_llm_factory import create_vanna_instance
  13. class QuestionSQLGenerationAgent:
  14. """Question-SQL生成Agent"""
  15. def __init__(self,
  16. output_dir: str,
  17. table_list_file: str,
  18. business_context: str,
  19. db_name: str = None):
  20. """
  21. 初始化Agent
  22. Args:
  23. output_dir: 输出目录(包含DDL和MD文件)
  24. table_list_file: 表清单文件路径
  25. business_context: 业务上下文
  26. db_name: 数据库名称(用于输出文件命名)
  27. """
  28. self.output_dir = Path(output_dir)
  29. self.table_list_file = table_list_file
  30. self.business_context = business_context
  31. self.db_name = db_name or "db"
  32. self.config = SCHEMA_TOOLS_CONFIG
  33. self.logger = logging.getLogger("schema_tools.QSAgent")
  34. # 初始化组件
  35. self.validator = FileCountValidator()
  36. self.md_analyzer = MDFileAnalyzer(output_dir)
  37. # vanna实例和主题提取器将在需要时初始化
  38. self.vn = None
  39. self.theme_extractor = None
  40. # 中间结果存储
  41. self.intermediate_results = []
  42. self.intermediate_file = None
    async def generate(self) -> Dict[str, Any]:
        """
        Generate Question-SQL pairs for every extracted analysis theme.

        Pipeline: validate DDL/MD files -> read MD docs -> init LLM components
        -> extract themes -> generate pairs per theme (serial or parallel,
        per config) -> save results and metadata -> report.

        Returns:
            A report dict with success flag, theme/question counts, the output
            file path, and the execution time in seconds.

        Raises:
            ValueError: if the DDL/MD file validation fails.
            Exception: any failure is re-raised after saving partial results.
        """
        start_time = time.time()

        try:
            self.logger.info("🚀 开始生成Question-SQL训练数据")

            # 1. Validate that DDL/MD files match the table list.
            self.logger.info("📋 验证文件数量...")
            validation_result = self.validator.validate(self.table_list_file, str(self.output_dir))
            if not validation_result.is_valid:
                self.logger.error(f"❌ 文件验证失败: {validation_result.error}")
                if validation_result.missing_ddl:
                    self.logger.error(f"缺失DDL文件: {validation_result.missing_ddl}")
                if validation_result.missing_md:
                    self.logger.error(f"缺失MD文件: {validation_result.missing_md}")
                raise ValueError(f"文件验证失败: {validation_result.error}")

            self.logger.info(f"✅ 文件验证通过: {validation_result.table_count}个表")

            # 2. Read the content of all MD files.
            self.logger.info("📖 读取MD文件...")
            md_contents = await self.md_analyzer.read_all_md_files()

            # 3. Initialize LLM-related components (lazy, idempotent).
            self._initialize_llm_components()

            # 4. Extract analysis themes from the MD contents.
            self.logger.info("🎯 提取分析主题...")
            themes = await self.theme_extractor.extract_themes(md_contents)
            self.logger.info(f"✅ 成功提取 {len(themes)} 个分析主题")
            for i, theme in enumerate(themes):
                # Supports both new ('topic_name') and old ('name') theme formats.
                topic_name = theme.get('topic_name', theme.get('name', ''))
                description = theme.get('description', '')
                self.logger.info(f" {i+1}. {topic_name}: {description}")

            # 5. Initialize the intermediate-results file.
            self._init_intermediate_file()

            # 6. Process each theme.
            all_qs_pairs = []
            failed_themes = []

            # Configuration decides parallel vs serial processing.
            max_concurrent = self.config['qs_generation'].get('max_concurrent_themes', 1)
            if max_concurrent > 1:
                results = await self._process_themes_parallel(themes, md_contents, max_concurrent)
            else:
                results = await self._process_themes_serial(themes, md_contents)

            # 7. Collect per-theme results into one flat list.
            for result in results:
                if result['success']:
                    all_qs_pairs.extend(result['qs_pairs'])
                else:
                    failed_themes.append(result['theme_name'])

            # 8. Save the final results.
            output_file = await self._save_final_results(all_qs_pairs)

            # 8.5 Generate the metadata.txt file.
            await self._generate_metadata_file(themes)

            # 9. Clean up the intermediate file — only when every theme succeeded.
            if not failed_themes:
                self._cleanup_intermediate_file()

            # 10. Build the summary report.
            end_time = time.time()
            report = {
                'success': True,
                'total_themes': len(themes),
                'successful_themes': len(themes) - len(failed_themes),
                'failed_themes': failed_themes,
                'total_questions': len(all_qs_pairs),
                'output_file': str(output_file),
                'execution_time': end_time - start_time
            }

            self._print_summary(report)
            return report

        except Exception as e:
            self.logger.exception("❌ Question-SQL生成失败")

            # Persist whatever was generated so far for later recovery.
            if self.intermediate_results:
                recovery_file = self._save_intermediate_results()
                self.logger.warning(f"⚠️ 中间结果已保存到: {recovery_file}")

            raise
  120. def _initialize_llm_components(self):
  121. """初始化LLM相关组件"""
  122. if not self.vn:
  123. self.logger.info("初始化LLM组件...")
  124. self.vn = create_vanna_instance()
  125. self.theme_extractor = ThemeExtractor(self.vn, self.business_context)
  126. async def _process_themes_serial(self, themes: List[Dict], md_contents: str) -> List[Dict]:
  127. """串行处理主题"""
  128. results = []
  129. for i, theme in enumerate(themes):
  130. self.logger.info(f"处理主题 {i+1}/{len(themes)}: {theme.get('topic_name', theme.get('name', ''))}")
  131. result = await self._process_single_theme(theme, md_contents)
  132. results.append(result)
  133. # 检查是否需要继续
  134. if not result['success'] and not self.config['qs_generation']['continue_on_theme_error']:
  135. self.logger.error(f"主题处理失败,停止处理")
  136. break
  137. return results
  138. async def _process_themes_parallel(self, themes: List[Dict], md_contents: str, max_concurrent: int) -> List[Dict]:
  139. """并行处理主题"""
  140. semaphore = asyncio.Semaphore(max_concurrent)
  141. async def process_with_semaphore(theme):
  142. async with semaphore:
  143. return await self._process_single_theme(theme, md_contents)
  144. tasks = [process_with_semaphore(theme) for theme in themes]
  145. results = await asyncio.gather(*tasks, return_exceptions=True)
  146. # 处理异常结果
  147. processed_results = []
  148. for i, result in enumerate(results):
  149. if isinstance(result, Exception):
  150. theme_name = themes[i].get('topic_name', themes[i].get('name', ''))
  151. self.logger.error(f"主题 '{theme_name}' 处理异常: {result}")
  152. processed_results.append({
  153. 'success': False,
  154. 'theme_name': theme_name,
  155. 'error': str(result)
  156. })
  157. else:
  158. processed_results.append(result)
  159. return processed_results
  160. async def _process_single_theme(self, theme: Dict, md_contents: str) -> Dict:
  161. """处理单个主题"""
  162. theme_name = theme.get('topic_name', theme.get('name', ''))
  163. try:
  164. self.logger.info(f"🔍 开始处理主题: {theme_name}")
  165. # 构建prompt
  166. prompt = self._build_qs_generation_prompt(theme, md_contents)
  167. # 调用LLM生成
  168. response = await self._call_llm(prompt)
  169. # 解析响应
  170. qs_pairs = self._parse_qs_response(response)
  171. # 验证和清理
  172. validated_pairs = self._validate_qs_pairs(qs_pairs, theme['name'])
  173. # 保存中间结果
  174. await self._save_theme_results(theme_name, validated_pairs)
  175. self.logger.info(f"✅ 主题 '{theme_name}' 处理成功,生成 {len(validated_pairs)} 个问题")
  176. return {
  177. 'success': True,
  178. 'theme_name': theme_name,
  179. 'qs_pairs': validated_pairs
  180. }
  181. except Exception as e:
  182. self.logger.error(f"❌ 处理主题 '{theme_name}' 失败: {e}")
  183. return {
  184. 'success': False,
  185. 'theme_name': theme_name,
  186. 'error': str(e),
  187. 'qs_pairs': []
  188. }
    def _build_qs_generation_prompt(self, theme: Dict, md_contents: str) -> str:
        """Build the LLM prompt requesting Question-SQL pairs for one theme.

        The prompt embeds the business context, theme metadata, the full MD
        schema documentation, and strict output-format instructions (a fenced
        JSON list of {question, sql} objects, PostgreSQL syntax).

        Args:
            theme: Theme dict (new format: 'topic_name'/'focus_areas';
                   old format: 'name'/'keywords').
            md_contents: Concatenated MD documentation of the schema.

        Returns:
            The fully rendered prompt string.
        """
        questions_count = self.config['qs_generation']['questions_per_theme']

        # Supports both the new and old theme formats.
        topic_name = theme.get('topic_name', theme.get('name', ''))
        description = theme.get('description', '')
        focus_areas = theme.get('focus_areas', theme.get('keywords', []))
        related_tables = theme.get('related_tables', [])

        prompt = f"""你是一位业务数据分析师,正在为{self.business_context}设计数据查询。
当前分析主题:{topic_name}
主题描述:{description}
关注领域:{', '.join(focus_areas)}
相关表:{', '.join(related_tables)}
数据库表结构信息:
{md_contents}
请为这个主题生成 {questions_count} 个业务问题和对应的SQL查询。
要求:
1. 问题应该从业务角度出发,贴合主题要求,具有实际分析价值
2. SQL必须使用PostgreSQL语法
3. 考虑实际业务逻辑(如软删除使用 delete_ts IS NULL 条件)
4. 使用中文别名提高可读性(使用 AS 指定列别名)
5. 问题应该多样化,覆盖不同的分析角度
6. 包含时间筛选、分组统计、排序、限制等不同类型的查询
7. SQL语句末尾必须以分号结束
输出JSON格式(注意SQL中的双引号需要转义):
```json
[
{{
"question": "具体的业务问题?",
"sql": "SELECT column AS 中文名 FROM table WHERE condition;"
}}
]
```
生成的问题应该包括但不限于:
- 趋势分析(按时间维度)
- 对比分析(不同维度对比)
- 排名统计(TOP N)
- 汇总统计(总量、平均值等)
- 明细查询(特定条件的详细数据)"""
        return prompt
  229. async def _call_llm(self, prompt: str) -> str:
  230. """调用LLM"""
  231. try:
  232. response = await asyncio.to_thread(
  233. self.vn.chat_with_llm,
  234. question=prompt,
  235. system_prompt="你是一个专业的数据分析师,精通PostgreSQL语法,擅长设计有业务价值的数据查询。请严格按照JSON格式输出。"
  236. )
  237. if not response or not response.strip():
  238. raise ValueError("LLM返回空响应")
  239. return response.strip()
  240. except Exception as e:
  241. self.logger.error(f"LLM调用失败: {e}")
  242. raise
  243. def _parse_qs_response(self, response: str) -> List[Dict[str, str]]:
  244. """解析Question-SQL响应"""
  245. try:
  246. # 提取JSON部分
  247. import re
  248. json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
  249. if json_match:
  250. json_str = json_match.group(1)
  251. else:
  252. json_str = response
  253. # 解析JSON
  254. qs_pairs = json.loads(json_str)
  255. if not isinstance(qs_pairs, list):
  256. raise ValueError("响应不是列表格式")
  257. return qs_pairs
  258. except json.JSONDecodeError as e:
  259. self.logger.error(f"JSON解析失败: {e}")
  260. self.logger.debug(f"原始响应: {response}")
  261. raise ValueError(f"无法解析LLM响应为JSON格式: {e}")
  262. def _validate_qs_pairs(self, qs_pairs: List[Dict], theme_name: str) -> List[Dict[str, str]]:
  263. """验证和清理Question-SQL对"""
  264. validated = []
  265. for i, pair in enumerate(qs_pairs):
  266. if not isinstance(pair, dict):
  267. self.logger.warning(f"跳过无效格式的项 {i+1}")
  268. continue
  269. question = pair.get('question', '').strip()
  270. sql = pair.get('sql', '').strip()
  271. if not question or not sql:
  272. self.logger.warning(f"跳过空问题或SQL的项 {i+1}")
  273. continue
  274. # 确保SQL以分号结束
  275. if not sql.endswith(';'):
  276. sql += ';'
  277. validated.append({
  278. 'question': question,
  279. 'sql': sql
  280. })
  281. self.logger.info(f"主题 '{theme_name}': 验证通过 {len(validated)}/{len(qs_pairs)} 个问题")
  282. return validated
  283. def _init_intermediate_file(self):
  284. """初始化中间结果文件"""
  285. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  286. self.intermediate_file = self.output_dir / f"qs_intermediate_{timestamp}.json"
  287. self.intermediate_results = []
  288. self.logger.info(f"中间结果文件: {self.intermediate_file}")
  289. async def _save_theme_results(self, theme_name: str, qs_pairs: List[Dict]):
  290. """保存单个主题的结果"""
  291. theme_result = {
  292. "theme": theme_name,
  293. "timestamp": datetime.now().isoformat(),
  294. "questions_count": len(qs_pairs),
  295. "questions": qs_pairs
  296. }
  297. self.intermediate_results.append(theme_result)
  298. # 立即保存到中间文件
  299. if self.config['qs_generation']['save_intermediate']:
  300. try:
  301. with open(self.intermediate_file, 'w', encoding='utf-8') as f:
  302. json.dump(self.intermediate_results, f, ensure_ascii=False, indent=2)
  303. self.logger.debug(f"中间结果已更新: {self.intermediate_file}")
  304. except Exception as e:
  305. self.logger.warning(f"保存中间结果失败: {e}")
  306. def _save_intermediate_results(self) -> Path:
  307. """异常时保存中间结果"""
  308. if not self.intermediate_results:
  309. return None
  310. recovery_file = self.output_dir / f"qs_recovery_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
  311. try:
  312. with open(recovery_file, 'w', encoding='utf-8') as f:
  313. json.dump({
  314. "status": "interrupted",
  315. "timestamp": datetime.now().isoformat(),
  316. "completed_themes": len(self.intermediate_results),
  317. "results": self.intermediate_results
  318. }, f, ensure_ascii=False, indent=2)
  319. return recovery_file
  320. except Exception as e:
  321. self.logger.error(f"保存恢复文件失败: {e}")
  322. return None
  323. async def _save_final_results(self, all_qs_pairs: List[Dict]) -> Path:
  324. """保存最终结果"""
  325. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  326. output_file = self.output_dir / f"{self.config['qs_generation']['output_file_prefix']}_{self.db_name}_{timestamp}_pair.json"
  327. try:
  328. with open(output_file, 'w', encoding='utf-8') as f:
  329. json.dump(all_qs_pairs, f, ensure_ascii=False, indent=2)
  330. self.logger.info(f"✅ 最终结果已保存到: {output_file}")
  331. return output_file
  332. except Exception as e:
  333. self.logger.error(f"保存最终结果失败: {e}")
  334. raise
  335. def _cleanup_intermediate_file(self):
  336. """清理中间文件"""
  337. if self.intermediate_file and self.intermediate_file.exists():
  338. try:
  339. self.intermediate_file.unlink()
  340. self.logger.info("已清理中间文件")
  341. except Exception as e:
  342. self.logger.warning(f"清理中间文件失败: {e}")
    def _print_summary(self, report: Dict):
        """Log a human-readable summary of the generation run.

        Args:
            report: The report dict produced by generate() — expects keys
                total_themes, successful_themes, failed_themes,
                total_questions, output_file, execution_time.
        """
        self.logger.info("=" * 60)
        self.logger.info("📊 生成总结")
        self.logger.info(f" ✅ 总主题数: {report['total_themes']}")
        self.logger.info(f" ✅ 成功主题: {report['successful_themes']}")
        if report['failed_themes']:
            self.logger.info(f" ❌ 失败主题: {len(report['failed_themes'])}")
            for theme in report['failed_themes']:
                self.logger.info(f" - {theme}")
        self.logger.info(f" 📝 总问题数: {report['total_questions']}")
        self.logger.info(f" 📁 输出文件: {report['output_file']}")
        self.logger.info(f" ⏱️ 执行时间: {report['execution_time']:.2f}秒")
        self.logger.info("=" * 60)
    async def _generate_metadata_file(self, themes: List[Dict]):
        """Write metadata.txt containing a CREATE TABLE plus one INSERT per theme.

        The file holds PostgreSQL statements for a `metadata` table recording
        each analysis theme's name, description, related tables, keywords and
        focus areas (the latter three rendered as PostgreSQL text-array literals).

        Args:
            themes: List of theme dicts (new format: 'topic_name'/'focus_areas';
                    old format: 'name'/'keywords').

        Returns:
            The metadata file path, or None when writing fails
            (the failure is logged, never raised).
        """
        metadata_file = self.output_dir / "metadata.txt"
        try:
            with open(metadata_file, 'w', encoding='utf-8') as f:
                f.write("-- Schema Tools生成的主题元数据\n")
                f.write(f"-- 业务背景: {self.business_context}\n")
                f.write(f"-- 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"-- 数据库: {self.db_name}\n\n")
                f.write("-- 创建表(如果不存在)\n")
                f.write("CREATE TABLE IF NOT EXISTS metadata (\n")
                f.write(" id SERIAL PRIMARY KEY,\n")
                f.write(" topic_name VARCHAR(100) NOT NULL,\n")
                f.write(" description TEXT,\n")
                f.write(" related_tables TEXT[],\n")
                f.write(" keywords TEXT[],\n")
                f.write(" focus_areas TEXT[],\n")
                f.write(" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\n")
                f.write(");\n\n")
                f.write("-- 插入主题数据\n")
                for theme in themes:
                    # Supports both the new and old theme formats.
                    topic_name = theme.get('topic_name', theme.get('name', ''))
                    description = theme.get('description', '')
                    # Render related_tables as a PostgreSQL array literal '{a,b}'.
                    # NOTE(review): array elements are joined without escaping
                    # commas/quotes/braces — a table or keyword name containing
                    # such characters would yield a malformed literal; confirm
                    # upstream inputs never contain them.
                    related_tables = theme.get('related_tables', [])
                    if isinstance(related_tables, list):
                        tables_str = '{' + ','.join(related_tables) + '}'
                    else:
                        tables_str = '{}'
                    # Render keywords the same way.
                    keywords = theme.get('keywords', [])
                    if isinstance(keywords, list):
                        keywords_str = '{' + ','.join(keywords) + '}'
                    else:
                        keywords_str = '{}'
                    # Render focus_areas the same way.
                    focus_areas = theme.get('focus_areas', [])
                    if isinstance(focus_areas, list):
                        focus_areas_str = '{' + ','.join(focus_areas) + '}'
                    else:
                        focus_areas_str = '{}'
                    # Emit the INSERT statement; single-quoted scalar values are
                    # escaped via _escape_sql_string.
                    f.write("INSERT INTO metadata(topic_name, description, related_tables, keywords, focus_areas) VALUES\n")
                    f.write("(\n")
                    f.write(f" '{self._escape_sql_string(topic_name)}',\n")
                    f.write(f" '{self._escape_sql_string(description)}',\n")
                    f.write(f" '{tables_str}',\n")
                    f.write(f" '{keywords_str}',\n")
                    f.write(f" '{focus_areas_str}'\n")
                    f.write(");\n\n")
            self.logger.info(f"✅ metadata.txt文件已生成: {metadata_file}")
            return metadata_file
        except Exception as e:
            self.logger.error(f"生成metadata.txt文件失败: {e}")
            return None
  413. def _escape_sql_string(self, value: str) -> str:
  414. """转义SQL字符串中的特殊字符"""
  415. if not value:
  416. return ""
  417. # 转义单引号
  418. return value.replace("'", "''")