sql_validation_agent.py 32 KB


  1. import asyncio
  2. import json
  3. import logging
  4. import time
  5. from datetime import datetime
  6. from pathlib import Path
  7. from typing import List, Dict, Any, Optional
  8. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  9. from schema_tools.validators import SQLValidator, SQLValidationResult, ValidationStats
  10. from schema_tools.utils.logger import setup_logging
  11. class SQLValidationAgent:
  12. """SQL验证Agent - 管理SQL验证的完整流程"""
  def __init__(self,
               db_connection: str,
               input_file: str,
               output_dir: Optional[str] = None):
      """Set up paths, configuration, the SQL validator and the optional LLM.

      Args:
          db_connection: Database connection string; it is handed to
              ``SQLValidator`` and stored for the report metadata.
          input_file: Path of the input JSON file holding Question-SQL pairs.
          output_dir: Directory for generated files. When omitted (or empty)
              the directory containing ``input_file`` is used.
      """
      self.db_connection = db_connection
      self.input_file = Path(input_file)
      self.output_dir = Path(output_dir) if output_dir else self.input_file.parent
      self.config = SCHEMA_TOOLS_CONFIG['sql_validation']
      self.logger = logging.getLogger("schema_tools.SQLValidationAgent")

      # Validator used for every SQL check.
      self.validator = SQLValidator(db_connection)

      # LLM instance used for SQL repair; stays None when repair is disabled
      # or when initialisation fails. The default is False because validate()
      # reads the same flag with a default of False: with a True default here
      # a missing key would create an LLM client that is never used.
      self.vn = None
      if self.config.get('enable_sql_repair', False):
          self._initialize_llm()

      # Run statistics, filled in by validate().
      self.total_questions = 0
      self.validation_start_time = None
  38. async def validate(self) -> Dict[str, Any]:
  39. """
  40. 执行SQL验证流程
  41. Returns:
  42. 验证结果报告
  43. """
  44. try:
  45. self.validation_start_time = time.time()
  46. self.logger.info("🚀 开始SQL验证流程")
  47. # 1. 读取输入文件
  48. self.logger.info(f"📖 读取输入文件: {self.input_file}")
  49. questions_sqls = await self._load_questions_sqls()
  50. self.total_questions = len(questions_sqls)
  51. if not questions_sqls:
  52. raise ValueError("输入文件中没有找到有效的Question-SQL对")
  53. self.logger.info(f"✅ 成功读取 {self.total_questions} 个Question-SQL对")
  54. # 2. 提取SQL语句
  55. sqls = [item['sql'] for item in questions_sqls]
  56. # 3. 执行验证
  57. self.logger.info("🔍 开始SQL验证...")
  58. validation_results = await self._validate_sqls_with_batching(sqls)
  59. # 4. 计算统计信息
  60. stats = self.validator.calculate_stats(validation_results)
  61. # 5. 尝试修复失败的SQL(如果启用LLM修复)
  62. file_modification_stats = {'modified': 0, 'deleted': 0, 'failed_modifications': 0}
  63. if self.config.get('enable_sql_repair', False) and self.vn:
  64. self.logger.info("🔧 启用LLM修复功能,开始修复失败的SQL...")
  65. validation_results = await self._attempt_sql_repair(questions_sqls, validation_results)
  66. # 重新计算统计信息(包含修复结果)
  67. stats = self.validator.calculate_stats(validation_results)
  68. # 6. 修改原始JSON文件(如果启用文件修改)
  69. if self.config.get('modify_original_file', False):
  70. self.logger.info("📝 启用文件修改功能,开始修改原始JSON文件...")
  71. file_modification_stats = await self._modify_original_json_file(questions_sqls, validation_results)
  72. else:
  73. self.logger.info("📋 不修改原始文件")
  74. # 7. 生成详细报告
  75. report = await self._generate_report(questions_sqls, validation_results, stats, file_modification_stats)
  76. # 8. 保存验证报告
  77. if self.config['save_validation_report']:
  78. await self._save_validation_report(report)
  79. # 9. 输出结果摘要
  80. self._print_summary(stats, validation_results, file_modification_stats)
  81. return report
  82. except Exception as e:
  83. self.logger.exception("❌ SQL验证流程失败")
  84. raise
  85. async def _load_questions_sqls(self) -> List[Dict[str, str]]:
  86. """读取Question-SQL对"""
  87. try:
  88. with open(self.input_file, 'r', encoding='utf-8') as f:
  89. data = json.load(f)
  90. # 验证数据格式
  91. if not isinstance(data, list):
  92. raise ValueError("输入文件应包含Question-SQL对的数组")
  93. questions_sqls = []
  94. for i, item in enumerate(data):
  95. if not isinstance(item, dict):
  96. self.logger.warning(f"跳过第 {i+1} 项:格式不正确")
  97. continue
  98. if 'question' not in item or 'sql' not in item:
  99. self.logger.warning(f"跳过第 {i+1} 项:缺少question或sql字段")
  100. continue
  101. questions_sqls.append({
  102. 'index': i,
  103. 'question': item['question'],
  104. 'sql': item['sql'].strip()
  105. })
  106. return questions_sqls
  107. except json.JSONDecodeError as e:
  108. raise ValueError(f"输入文件不是有效的JSON格式: {e}")
  109. except Exception as e:
  110. raise ValueError(f"读取输入文件失败: {e}")
  111. async def _validate_sqls_with_batching(self, sqls: List[str]) -> List[SQLValidationResult]:
  112. """使用批处理方式验证SQL"""
  113. batch_size = self.config['batch_size']
  114. all_results = []
  115. # 分批处理
  116. for i in range(0, len(sqls), batch_size):
  117. batch = sqls[i:i + batch_size]
  118. batch_num = i // batch_size + 1
  119. total_batches = (len(sqls) + batch_size - 1) // batch_size
  120. self.logger.info(f"📦 处理批次 {batch_num}/{total_batches} ({len(batch)} 个SQL)")
  121. batch_results = await self.validator.validate_sqls_batch(batch)
  122. all_results.extend(batch_results)
  123. # 显示批次进度
  124. valid_count = sum(1 for r in batch_results if r.valid)
  125. self.logger.info(f"✅ 批次 {batch_num} 完成: {valid_count}/{len(batch)} 有效")
  126. return all_results
  127. async def _generate_report(self,
  128. questions_sqls: List[Dict],
  129. validation_results: List[SQLValidationResult],
  130. stats: ValidationStats,
  131. file_modification_stats: Dict[str, int] = None) -> Dict[str, Any]:
  132. """生成详细验证报告"""
  133. validation_time = time.time() - self.validation_start_time
  134. # 合并问题和验证结果
  135. detailed_results = []
  136. for i, (qs, result) in enumerate(zip(questions_sqls, validation_results)):
  137. detailed_results.append({
  138. 'index': i + 1,
  139. 'question': qs['question'],
  140. 'sql': qs['sql'],
  141. 'valid': result.valid,
  142. 'error_message': result.error_message,
  143. 'execution_time': result.execution_time,
  144. 'retry_count': result.retry_count,
  145. # 添加修复信息
  146. 'repair_attempted': result.repair_attempted,
  147. 'repair_successful': result.repair_successful,
  148. 'repaired_sql': result.repaired_sql,
  149. 'repair_error': result.repair_error
  150. })
  151. # 生成报告
  152. report = {
  153. 'metadata': {
  154. 'input_file': str(self.input_file),
  155. 'validation_time': datetime.now().isoformat(),
  156. 'total_validation_time': validation_time,
  157. 'database_connection': self._mask_connection_string(self.db_connection),
  158. 'config': self.config.copy()
  159. },
  160. 'summary': {
  161. 'total_questions': stats.total_sqls,
  162. 'valid_sqls': stats.valid_sqls,
  163. 'invalid_sqls': stats.invalid_sqls,
  164. 'success_rate': stats.valid_sqls / stats.total_sqls if stats.total_sqls > 0 else 0.0,
  165. 'average_execution_time': stats.avg_time_per_sql,
  166. 'total_retries': stats.retry_count,
  167. # 添加修复统计
  168. 'repair_stats': {
  169. 'attempted': stats.repair_attempted,
  170. 'successful': stats.repair_successful,
  171. 'failed': stats.repair_failed
  172. },
  173. # 添加文件修改统计
  174. 'file_modification_stats': file_modification_stats or {
  175. 'modified': 0, 'deleted': 0, 'failed_modifications': 0
  176. }
  177. },
  178. 'detailed_results': detailed_results
  179. }
  180. return report
  181. async def _save_validation_report(self, report: Dict[str, Any]):
  182. """保存验证报告"""
  183. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  184. # 只保存文本格式摘要(便于查看)
  185. txt_file = self.output_dir / f"{self.config['report_file_prefix']}_{timestamp}_summary.txt"
  186. with open(txt_file, 'w', encoding='utf-8') as f:
  187. f.write(f"SQL验证报告\n")
  188. f.write(f"=" * 50 + "\n\n")
  189. f.write(f"输入文件: {self.input_file}\n")
  190. f.write(f"验证时间: {report['metadata']['validation_time']}\n")
  191. f.write(f"验证耗时: {report['metadata']['total_validation_time']:.2f}秒\n\n")
  192. f.write(f"验证结果摘要:\n")
  193. f.write(f" 总SQL数量: {report['summary']['total_questions']}\n")
  194. f.write(f" 有效SQL: {report['summary']['valid_sqls']}\n")
  195. f.write(f" 无效SQL: {report['summary']['invalid_sqls']}\n")
  196. f.write(f" 成功率: {report['summary']['success_rate']:.2%}\n")
  197. f.write(f" 平均耗时: {report['summary']['average_execution_time']:.3f}秒\n")
  198. f.write(f" 重试次数: {report['summary']['total_retries']}\n\n")
  199. # 添加修复统计
  200. if 'repair_stats' in report['summary']:
  201. repair_stats = report['summary']['repair_stats']
  202. f.write(f"SQL修复统计:\n")
  203. f.write(f" 尝试修复: {repair_stats['attempted']}\n")
  204. f.write(f" 修复成功: {repair_stats['successful']}\n")
  205. f.write(f" 修复失败: {repair_stats['failed']}\n")
  206. if repair_stats['attempted'] > 0:
  207. f.write(f" 修复成功率: {repair_stats['successful'] / repair_stats['attempted']:.2%}\n")
  208. f.write(f"\n")
  209. # 添加文件修改统计
  210. if 'file_modification_stats' in report['summary']:
  211. file_stats = report['summary']['file_modification_stats']
  212. f.write(f"原始文件修改统计:\n")
  213. f.write(f" 修改的SQL: {file_stats['modified']}\n")
  214. f.write(f" 删除的无效项: {file_stats['deleted']}\n")
  215. f.write(f" 修改失败: {file_stats['failed_modifications']}\n")
  216. f.write(f"\n")
  217. # 提取错误详情(显示完整SQL)
  218. error_results = [r for r in report['detailed_results'] if not r['valid'] and not r.get('repair_successful', False)]
  219. if error_results:
  220. f.write(f"错误详情(共{len(error_results)}个):\n")
  221. f.write(f"=" * 50 + "\n")
  222. for i, error_result in enumerate(error_results, 1):
  223. f.write(f"\n{i}. 问题: {error_result['question']}\n")
  224. f.write(f" 错误: {error_result['error_message']}\n")
  225. if error_result['retry_count'] > 0:
  226. f.write(f" 重试: {error_result['retry_count']}次\n")
  227. # 显示修复尝试信息
  228. if error_result.get('repair_attempted', False):
  229. if error_result.get('repair_successful', False):
  230. f.write(f" LLM修复尝试: 成功\n")
  231. f.write(f" 修复后SQL:\n")
  232. f.write(f" {error_result.get('repaired_sql', '')}\n")
  233. else:
  234. f.write(f" LLM修复尝试: 失败\n")
  235. repair_error = error_result.get('repair_error', '未知错误')
  236. f.write(f" 修复失败原因: {repair_error}\n")
  237. else:
  238. f.write(f" LLM修复尝试: 未尝试\n")
  239. f.write(f" 完整SQL:\n")
  240. f.write(f" {error_result['sql']}\n")
  241. f.write(f" {'-' * 40}\n")
  242. # 显示成功修复的SQL
  243. repaired_results = [r for r in report['detailed_results'] if r.get('repair_successful', False)]
  244. if repaired_results:
  245. f.write(f"\n成功修复的SQL(共{len(repaired_results)}个):\n")
  246. f.write(f"=" * 50 + "\n")
  247. for i, repaired_result in enumerate(repaired_results, 1):
  248. f.write(f"\n{i}. 问题: {repaired_result['question']}\n")
  249. f.write(f" 原始错误: {repaired_result['error_message']}\n")
  250. f.write(f" 修复后SQL:\n")
  251. f.write(f" {repaired_result.get('repaired_sql', '')}\n")
  252. f.write(f" {'-' * 40}\n")
  253. self.logger.info(f"📊 验证报告已保存: {txt_file}")
  254. # 如果配置允许,也可以保存JSON格式的详细报告(可选)
  255. if self.config.get('save_detailed_json_report', False):
  256. json_file = self.output_dir / f"{self.config['report_file_prefix']}_{timestamp}_report.json"
  257. with open(json_file, 'w', encoding='utf-8') as f:
  258. json.dump(report, f, ensure_ascii=False, indent=2)
  259. self.logger.info(f"📊 详细JSON报告已保存: {json_file}")
  260. def _mask_connection_string(self, conn_str: str) -> str:
  261. """隐藏连接字符串中的敏感信息"""
  262. import re
  263. # 隐藏密码
  264. return re.sub(r':[^:@]+@', ':***@', conn_str)
  265. def _print_summary(self, stats: ValidationStats, validation_results: List[SQLValidationResult] = None, file_modification_stats: Dict[str, int] = None):
  266. """打印验证结果摘要"""
  267. validation_time = time.time() - self.validation_start_time
  268. self.logger.info("=" * 60)
  269. self.logger.info("📊 SQL验证结果摘要")
  270. self.logger.info(f" 📝 总SQL数量: {stats.total_sqls}")
  271. self.logger.info(f" ✅ 有效SQL: {stats.valid_sqls}")
  272. self.logger.info(f" ❌ 无效SQL: {stats.invalid_sqls}")
  273. self.logger.info(f" 📈 成功率: {stats.valid_sqls / stats.total_sqls:.2%}")
  274. self.logger.info(f" ⏱️ 平均耗时: {stats.avg_time_per_sql:.3f}秒/SQL")
  275. self.logger.info(f" 🔄 重试次数: {stats.retry_count}")
  276. self.logger.info(f" ⏰ 总耗时: {validation_time:.2f}秒")
  277. # 添加修复统计
  278. if stats.repair_attempted > 0:
  279. self.logger.info(f" 🔧 修复尝试: {stats.repair_attempted}")
  280. self.logger.info(f" ✅ 修复成功: {stats.repair_successful}")
  281. self.logger.info(f" ❌ 修复失败: {stats.repair_failed}")
  282. repair_rate = stats.repair_successful / stats.repair_attempted if stats.repair_attempted > 0 else 0.0
  283. self.logger.info(f" 📈 修复成功率: {repair_rate:.2%}")
  284. # 添加文件修改统计
  285. if file_modification_stats:
  286. self.logger.info(f" 📝 文件修改: {file_modification_stats['modified']} 个SQL")
  287. self.logger.info(f" 🗑️ 删除无效项: {file_modification_stats['deleted']} 个")
  288. if file_modification_stats['failed_modifications'] > 0:
  289. self.logger.info(f" ⚠️ 修改失败: {file_modification_stats['failed_modifications']} 个")
  290. self.logger.info("=" * 60)
  291. # 显示部分错误信息
  292. if validation_results:
  293. error_results = [r for r in validation_results if not r.valid]
  294. if error_results:
  295. self.logger.info(f"⚠️ 前5个错误示例:")
  296. for i, error_result in enumerate(error_results[:5], 1):
  297. self.logger.info(f" {i}. {error_result.error_message}")
  298. # 显示SQL的前80个字符
  299. sql_preview = error_result.sql[:80] + '...' if len(error_result.sql) > 80 else error_result.sql
  300. self.logger.info(f" SQL: {sql_preview}")
  301. def _initialize_llm(self):
  302. """初始化LLM实例"""
  303. try:
  304. from core.vanna_llm_factory import create_vanna_instance
  305. self.vn = create_vanna_instance()
  306. self.logger.info("✅ LLM实例初始化成功,SQL修复功能已启用")
  307. except Exception as e:
  308. self.logger.warning(f"⚠️ LLM初始化失败,SQL修复功能将被禁用: {e}")
  309. self.vn = None
  310. async def _attempt_sql_repair(self, questions_sqls: List[Dict], validation_results: List[SQLValidationResult]) -> List[SQLValidationResult]:
  311. """
  312. 尝试修复失败的SQL
  313. Args:
  314. questions_sqls: 问题SQL对列表
  315. validation_results: 验证结果列表
  316. Returns:
  317. 更新后的验证结果列表
  318. """
  319. # 找出需要修复的SQL
  320. failed_indices = []
  321. for i, result in enumerate(validation_results):
  322. if not result.valid:
  323. failed_indices.append(i)
  324. if not failed_indices:
  325. self.logger.info("🎉 所有SQL都有效,无需修复")
  326. return validation_results
  327. self.logger.info(f"🔧 开始修复 {len(failed_indices)} 个失败的SQL...")
  328. # 批量修复
  329. batch_size = self.config.get('repair_batch_size', 5)
  330. updated_results = validation_results.copy()
  331. for i in range(0, len(failed_indices), batch_size):
  332. batch_indices = failed_indices[i:i + batch_size]
  333. self.logger.info(f"📦 修复批次 {i//batch_size + 1}/{(len(failed_indices) + batch_size - 1)//batch_size} ({len(batch_indices)} 个SQL)")
  334. # 准备批次数据
  335. batch_data = []
  336. for idx in batch_indices:
  337. batch_data.append({
  338. 'index': idx,
  339. 'question': questions_sqls[idx]['question'],
  340. 'sql': validation_results[idx].sql,
  341. 'error': validation_results[idx].error_message
  342. })
  343. # 调用LLM修复
  344. repaired_sqls = await self._repair_sqls_with_llm(batch_data)
  345. # 验证修复后的SQL
  346. for j, idx in enumerate(batch_indices):
  347. original_result = updated_results[idx]
  348. original_result.repair_attempted = True
  349. if j < len(repaired_sqls) and repaired_sqls[j]:
  350. repaired_sql = repaired_sqls[j]
  351. # 验证修复后的SQL
  352. repair_result = await self.validator.validate_sql(repaired_sql)
  353. if repair_result.valid:
  354. # 修复成功
  355. original_result.repair_successful = True
  356. original_result.repaired_sql = repaired_sql
  357. original_result.valid = True # 更新为有效
  358. self.logger.info(f"✅ SQL修复成功 (索引 {idx})")
  359. else:
  360. # 修复失败
  361. original_result.repair_successful = False
  362. original_result.repair_error = repair_result.error_message
  363. self.logger.warning(f"❌ SQL修复失败 (索引 {idx}): {repair_result.error_message}")
  364. else:
  365. # LLM修复失败
  366. original_result.repair_successful = False
  367. original_result.repair_error = "LLM修复失败或返回空结果"
  368. self.logger.warning(f"❌ LLM修复失败 (索引 {idx})")
  369. # 统计修复结果
  370. repair_attempted = sum(1 for r in updated_results if r.repair_attempted)
  371. repair_successful = sum(1 for r in updated_results if r.repair_successful)
  372. self.logger.info(f"🔧 修复完成: {repair_successful}/{repair_attempted} 成功")
  373. return updated_results
  374. async def _modify_original_json_file(self, questions_sqls: List[Dict], validation_results: List[SQLValidationResult]) -> Dict[str, int]:
  375. """
  376. 修改原始JSON文件:
  377. 1. 对于修复成功的SQL,更新原始文件中的SQL内容
  378. 2. 对于无法修复的SQL,从原始文件中删除对应的键值对
  379. Returns:
  380. 修改统计信息
  381. """
  382. stats = {'modified': 0, 'deleted': 0, 'failed_modifications': 0}
  383. try:
  384. # 读取原始JSON文件
  385. with open(self.input_file, 'r', encoding='utf-8') as f:
  386. original_data = json.load(f)
  387. if not isinstance(original_data, list):
  388. self.logger.error("原始JSON文件格式不正确,无法修改")
  389. stats['failed_modifications'] = 1
  390. return stats
  391. # 创建备份文件
  392. backup_file = Path(str(self.input_file) + '.backup')
  393. with open(backup_file, 'w', encoding='utf-8') as f:
  394. json.dump(original_data, f, ensure_ascii=False, indent=2)
  395. self.logger.info(f"已创建备份文件: {backup_file}")
  396. # 构建修改计划
  397. modifications = []
  398. deletions = []
  399. for i, (qs, result) in enumerate(zip(questions_sqls, validation_results)):
  400. if result.repair_successful and result.repaired_sql:
  401. # 修复成功的SQL
  402. modifications.append({
  403. 'index': i,
  404. 'original_sql': result.sql,
  405. 'repaired_sql': result.repaired_sql,
  406. 'question': qs['question']
  407. })
  408. elif not result.valid and not result.repair_successful:
  409. # 无法修复的SQL,标记删除
  410. deletions.append({
  411. 'index': i,
  412. 'question': qs['question'],
  413. 'sql': result.sql,
  414. 'error': result.error_message
  415. })
  416. # 执行修改(从后往前,避免索引变化)
  417. new_data = original_data.copy()
  418. # 先删除无效项(从后往前删除)
  419. for deletion in sorted(deletions, key=lambda x: x['index'], reverse=True):
  420. if deletion['index'] < len(new_data):
  421. removed_item = new_data.pop(deletion['index'])
  422. stats['deleted'] += 1
  423. self.logger.info(f"删除无效项 {deletion['index']}: {deletion['question'][:50]}...")
  424. # 再修改SQL(需要重新计算索引)
  425. index_offset = 0
  426. for modification in sorted(modifications, key=lambda x: x['index']):
  427. # 计算删除后的新索引
  428. new_index = modification['index']
  429. for deletion in deletions:
  430. if deletion['index'] < modification['index']:
  431. new_index -= 1
  432. if new_index < len(new_data):
  433. new_data[new_index]['sql'] = modification['repaired_sql']
  434. stats['modified'] += 1
  435. self.logger.info(f"修改SQL {new_index}: {modification['question'][:50]}...")
  436. # 写入修改后的文件
  437. with open(self.input_file, 'w', encoding='utf-8') as f:
  438. json.dump(new_data, f, ensure_ascii=False, indent=2)
  439. self.logger.info(f"✅ 原始文件修改完成: 修改{stats['modified']}个SQL,删除{stats['deleted']}个无效项")
  440. # 记录详细修改信息到日志文件
  441. await self._write_modification_log(modifications, deletions)
  442. except Exception as e:
  443. self.logger.error(f"修改原始JSON文件失败: {e}")
  444. stats['failed_modifications'] = 1
  445. return stats
  446. async def _write_modification_log(self, modifications: List[Dict], deletions: List[Dict]):
  447. """写入详细的修改日志"""
  448. try:
  449. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  450. log_file = self.output_dir / f"file_modifications_{timestamp}.log"
  451. with open(log_file, 'w', encoding='utf-8') as f:
  452. f.write(f"原始JSON文件修改日志\n")
  453. f.write(f"=" * 50 + "\n")
  454. f.write(f"修改时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
  455. f.write(f"原始文件: {self.input_file}\n")
  456. f.write(f"备份文件: {str(self.input_file)}.backup\n")
  457. f.write(f"\n")
  458. if modifications:
  459. f.write(f"修改的SQL ({len(modifications)}个):\n")
  460. f.write(f"-" * 40 + "\n")
  461. for i, mod in enumerate(modifications, 1):
  462. f.write(f"{i}. 索引: {mod['index']}\n")
  463. f.write(f" 问题: {mod['question']}\n")
  464. f.write(f" 原SQL: {mod['original_sql']}\n")
  465. f.write(f" 新SQL: {mod['repaired_sql']}\n\n")
  466. if deletions:
  467. f.write(f"删除的无效项 ({len(deletions)}个):\n")
  468. f.write(f"-" * 40 + "\n")
  469. for i, del_item in enumerate(deletions, 1):
  470. f.write(f"{i}. 索引: {del_item['index']}\n")
  471. f.write(f" 问题: {del_item['question']}\n")
  472. f.write(f" SQL: {del_item['sql']}\n")
  473. f.write(f" 错误: {del_item['error']}\n\n")
  474. self.logger.info(f"详细修改日志已保存: {log_file}")
  475. except Exception as e:
  476. self.logger.warning(f"写入修改日志失败: {e}")
  477. async def _repair_sqls_with_llm(self, batch_data: List[Dict]) -> List[str]:
  478. """
  479. 使用LLM修复SQL批次
  480. Args:
  481. batch_data: 批次数据,包含question, sql, error
  482. Returns:
  483. 修复后的SQL列表
  484. """
  485. try:
  486. # 构建修复提示词
  487. prompt = self._build_repair_prompt(batch_data)
  488. # 调用LLM
  489. response = await self._call_llm_for_repair(prompt)
  490. # 解析响应
  491. repaired_sqls = self._parse_repair_response(response, len(batch_data))
  492. return repaired_sqls
  493. except Exception as e:
  494. self.logger.error(f"LLM修复批次失败: {e}")
  495. return [""] * len(batch_data) # 返回空字符串列表
  496. def _build_repair_prompt(self, batch_data: List[Dict]) -> str:
  497. """构建SQL修复提示词"""
  498. # 提取数据库类型
  499. db_type = "PostgreSQL" # 从连接字符串可以确定是PostgreSQL
  500. prompt = f"""你是一个SQL专家,专门修复PostgreSQL数据库的SQL语句错误。
  501. 数据库类型: {db_type}
  502. 请修复以下SQL语句中的错误。对于每个SQL,我会提供问题描述、错误信息和完整的SQL语句。
  503. 修复要求:
  504. 1. 只修复语法错误和表结构错误
  505. 2. 保持SQL的原始业务逻辑不变
  506. 3. 使用PostgreSQL标准语法
  507. 4. 确保修复后的SQL语法正确
  508. 需要修复的SQL:
  509. """
  510. # 添加每个SQL的详细信息
  511. for i, data in enumerate(batch_data, 1):
  512. prompt += f"""
  513. {i}. 问题: {data['question']}
  514. 错误: {data['error']}
  515. 完整SQL:
  516. {data['sql']}
  517. """
  518. prompt += f"""
  519. 请按以下JSON格式输出修复后的SQL:
  520. ```json
  521. {{
  522. "repaired_sqls": [
  523. "修复后的SQL1",
  524. "修复后的SQL2",
  525. "修复后的SQL3"
  526. ]
  527. }}
  528. ```
  529. 注意:
  530. - 必须输出 {len(batch_data)} 个修复后的SQL
  531. - 如果某个SQL无法修复,请输出原始SQL
  532. - SQL语句必须以分号结束
  533. - 保持原始的中文别名和业务逻辑"""
  534. return prompt
  535. async def _call_llm_for_repair(self, prompt: str) -> str:
  536. """调用LLM进行修复"""
  537. import asyncio
  538. try:
  539. timeout = self.config.get('llm_repair_timeout', 60)
  540. response = await asyncio.wait_for(
  541. asyncio.to_thread(
  542. self.vn.chat_with_llm,
  543. question=prompt,
  544. system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误和表结构错误。请严格按照JSON格式输出修复结果。"
  545. ),
  546. timeout=timeout
  547. )
  548. if not response or not response.strip():
  549. raise ValueError("LLM返回空响应")
  550. return response.strip()
  551. except asyncio.TimeoutError:
  552. raise Exception(f"LLM调用超时({timeout}秒)")
  553. except Exception as e:
  554. raise Exception(f"LLM调用失败: {e}")
  555. def _parse_repair_response(self, response: str, expected_count: int) -> List[str]:
  556. """解析LLM修复响应"""
  557. import json
  558. import re
  559. try:
  560. # 尝试提取JSON部分
  561. json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
  562. if json_match:
  563. json_str = json_match.group(1)
  564. else:
  565. # 如果没有代码块,尝试直接解析
  566. json_str = response
  567. # 解析JSON
  568. parsed_data = json.loads(json_str)
  569. repaired_sqls = parsed_data.get('repaired_sqls', [])
  570. # 验证数量
  571. if len(repaired_sqls) != expected_count:
  572. self.logger.warning(f"LLM返回的SQL数量不匹配:期望{expected_count},实际{len(repaired_sqls)}")
  573. # 补齐或截断
  574. while len(repaired_sqls) < expected_count:
  575. repaired_sqls.append("")
  576. repaired_sqls = repaired_sqls[:expected_count]
  577. # 清理SQL语句
  578. cleaned_sqls = []
  579. for sql in repaired_sqls:
  580. if sql and isinstance(sql, str):
  581. cleaned_sql = sql.strip()
  582. # 确保以分号结束
  583. if cleaned_sql and not cleaned_sql.endswith(';'):
  584. cleaned_sql += ';'
  585. cleaned_sqls.append(cleaned_sql)
  586. else:
  587. cleaned_sqls.append("")
  588. return cleaned_sqls
  589. except json.JSONDecodeError as e:
  590. self.logger.error(f"解析LLM修复响应失败: {e}")
  591. self.logger.debug(f"原始响应: {response}")
  592. return [""] * expected_count
  593. except Exception as e:
  594. self.logger.error(f"处理修复响应失败: {e}")
  595. return [""] * expected_count