sql_validate_cli.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. """
  2. SQL验证器命令行入口
  3. 用于验证Question-SQL对中的SQL语句是否有效
  4. """
  5. import argparse
  6. import asyncio
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from .sql_validation_agent import SQLValidationAgent
  11. from data_pipeline.utils.logger import setup_logging
  12. def setup_argument_parser():
  13. """设置命令行参数解析器"""
  14. parser = argparse.ArgumentParser(
  15. description='SQL Validator - 验证Question-SQL对中的SQL语句',
  16. formatter_class=argparse.RawDescriptionHelpFormatter,
  17. epilog="""
  18. 示例用法:
  19. # 基本使用(仅验证,不修改文件)
  20. python -m data_pipeline.validators.sql_validate_cli --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json
  21. # 使用task_id自动查找文件
  22. python -m data_pipeline.validators.sql_validate_cli --task-id manual_20250720_130541 --db-connection "postgresql://user:pass@localhost:5432/dbname"
  23. # 启用文件修改,但禁用LLM修复(仅删除无效SQL)
  24. python -m data_pipeline.validators.sql_validate_cli --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --modify-original-file --disable-llm-repair
  25. # 启用文件修改和LLM修复功能
  26. python -m data_pipeline.validators.sql_validate_cli --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --modify-original-file
  27. # 指定输出目录
  28. python -m data_pipeline.validators.sql_validate_cli --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --output-dir ./reports
  29. # 启用详细日志
  30. python -m data_pipeline.validators.sql_validate_cli --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --verbose
  31. """
  32. )
  33. # 必需参数
  34. parser.add_argument(
  35. '--db-connection',
  36. required=True,
  37. help='数据库连接字符串 (postgresql://user:pass@host:port/dbname)'
  38. )
  39. # 可选参数(当使用task-id时,input-file变为可选)
  40. parser.add_argument(
  41. '--task-id',
  42. help='任务ID,指定后将自动在任务目录中查找Question-SQL文件'
  43. )
  44. parser.add_argument(
  45. '--input-file',
  46. help='输入的JSON文件路径(包含Question-SQL对)'
  47. )
  48. # 可选参数
  49. parser.add_argument(
  50. '--output-dir',
  51. help='验证报告输出目录(默认为输入文件同目录)'
  52. )
  53. parser.add_argument(
  54. '--max-concurrent',
  55. type=int,
  56. help='最大并发验证数(覆盖配置文件设置)'
  57. )
  58. parser.add_argument(
  59. '--batch-size',
  60. type=int,
  61. help='批处理大小(覆盖配置文件设置)'
  62. )
  63. parser.add_argument(
  64. '--timeout',
  65. type=int,
  66. help='单个SQL验证超时时间(秒)'
  67. )
  68. parser.add_argument(
  69. '--verbose', '-v',
  70. action='store_true',
  71. help='启用详细日志输出'
  72. )
  73. parser.add_argument(
  74. '--log-file',
  75. help='日志文件路径'
  76. )
  77. parser.add_argument(
  78. '--dry-run',
  79. action='store_true',
  80. help='仅读取和解析文件,不执行验证'
  81. )
  82. parser.add_argument(
  83. '--save-json',
  84. action='store_true',
  85. help='同时保存详细的JSON报告'
  86. )
  87. parser.add_argument(
  88. '--disable-llm-repair',
  89. action='store_true',
  90. help='禁用LLM自动修复功能'
  91. )
  92. # 向后兼容的别名参数
  93. parser.add_argument(
  94. '--enable-llm-repair',
  95. action='store_true',
  96. help='启用LLM自动修复功能(与--disable-llm-repair相反,保持向后兼容性)'
  97. )
  98. parser.add_argument(
  99. '--no-modify-file',
  100. action='store_true',
  101. help='不修改原始JSON文件(仅生成验证报告)'
  102. )
  103. # 向后兼容的别名参数
  104. parser.add_argument(
  105. '--modify-original-file',
  106. action='store_true',
  107. help='修改原始JSON文件(与--no-modify-file相反,保持向后兼容性)'
  108. )
  109. return parser
  110. def apply_config_overrides(args):
  111. """应用命令行参数覆盖配置"""
  112. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  113. sql_config = SCHEMA_TOOLS_CONFIG['sql_validation']
  114. if args.max_concurrent:
  115. sql_config['max_concurrent_validations'] = args.max_concurrent
  116. print(f"覆盖并发数配置: {args.max_concurrent}")
  117. if args.batch_size:
  118. sql_config['batch_size'] = args.batch_size
  119. print(f"覆盖批处理大小: {args.batch_size}")
  120. if args.timeout:
  121. sql_config['validation_timeout'] = args.timeout
  122. print(f"覆盖超时配置: {args.timeout}秒")
  123. if args.save_json:
  124. sql_config['save_detailed_json_report'] = True
  125. print(f"启用详细JSON报告保存")
  126. # 注意:现在是disable_llm_repair,逻辑反转,同时支持向后兼容的enable_llm_repair
  127. if args.disable_llm_repair and args.enable_llm_repair:
  128. print("警告: --disable-llm-repair 和 --enable-llm-repair 不能同时使用,优先使用 --disable-llm-repair")
  129. sql_config['enable_sql_repair'] = False
  130. print(f"LLM修复功能已禁用")
  131. elif args.disable_llm_repair:
  132. sql_config['enable_sql_repair'] = False
  133. print(f"LLM修复功能已禁用")
  134. elif args.enable_llm_repair:
  135. sql_config['enable_sql_repair'] = True
  136. print(f"启用LLM自动修复功能(向后兼容参数)")
  137. else:
  138. # 默认启用LLM修复功能
  139. sql_config['enable_sql_repair'] = True
  140. print(f"启用LLM自动修复功能(默认行为)")
  141. # 注意:现在是no_modify_file,逻辑反转,同时支持向后兼容的modify_original_file
  142. if args.no_modify_file and args.modify_original_file:
  143. print("警告: --no-modify-file 和 --modify-original-file 不能同时使用,优先使用 --no-modify-file")
  144. sql_config['modify_original_file'] = False
  145. print(f"不修改原文件")
  146. elif args.no_modify_file:
  147. sql_config['modify_original_file'] = False
  148. print(f"不修改原文件")
  149. elif args.modify_original_file:
  150. sql_config['modify_original_file'] = True
  151. print(f"启用原文件修改功能(向后兼容参数)")
  152. else:
  153. # 默认启用文件修改功能
  154. sql_config['modify_original_file'] = True
  155. print(f"启用原文件修改功能(默认行为)")
  156. async def main():
  157. """主入口函数"""
  158. parser = setup_argument_parser()
  159. args = parser.parse_args()
  160. # 设置日志
  161. setup_logging(
  162. verbose=args.verbose,
  163. log_file=args.log_file
  164. )
  165. # 验证参数
  166. if not args.input_file and not args.task_id:
  167. print("错误: 必须指定 --input-file 或 --task-id 参数")
  168. parser.print_help()
  169. sys.exit(1)
  170. # 解析输入文件和输出目录
  171. input_file, output_dir = resolve_input_file_and_output_dir(args)
  172. if not input_file:
  173. if args.task_id:
  174. print(f"错误: 在任务目录中未找到Question-SQL文件 (*_pair.json)")
  175. print(f"任务ID: {args.task_id}")
  176. else:
  177. print(f"错误: 输入文件不存在: {args.input_file}")
  178. sys.exit(1)
  179. input_path = Path(input_file)
  180. if not input_path.suffix.lower() == '.json':
  181. print(f"警告: 输入文件可能不是JSON格式: {input_file}")
  182. # 应用配置覆盖
  183. apply_config_overrides(args)
  184. try:
  185. # 创建SQL验证Agent
  186. agent = SQLValidationAgent(
  187. db_connection=args.db_connection,
  188. input_file=input_file,
  189. output_dir=output_dir,
  190. task_id=args.task_id # 传递task_id
  191. )
  192. # 显示运行信息
  193. print(f"🚀 开始SQL验证...")
  194. print(f"📁 输入文件: {input_file}")
  195. if output_dir:
  196. print(f"📁 输出目录: {output_dir}")
  197. print(f"🔗 数据库: {_mask_db_connection(args.db_connection)}")
  198. if args.dry_run:
  199. print("\n🔍 执行预检查模式...")
  200. # 仅读取和验证文件格式
  201. questions_sqls = await agent._load_questions_sqls()
  202. print(f"✅ 成功读取 {len(questions_sqls)} 个Question-SQL对")
  203. print("📊 SQL样例:")
  204. for i, qs in enumerate(questions_sqls[:3], 1):
  205. print(f" {i}. {qs['question']}")
  206. print(f" SQL: {qs['sql'][:100]}{'...' if len(qs['sql']) > 100 else ''}")
  207. print()
  208. sys.exit(0)
  209. # 执行验证
  210. report = await agent.validate()
  211. # 输出结果
  212. success_rate = report['summary']['success_rate']
  213. if success_rate >= 0.9: # 90%以上成功率
  214. print(f"\n🎉 验证完成,成功率: {success_rate:.1%}")
  215. exit_code = 0
  216. elif success_rate >= 0.7: # 70%-90%成功率
  217. print(f"\n⚠️ 验证完成,成功率较低: {success_rate:.1%}")
  218. exit_code = 1
  219. else: # 70%以下成功率
  220. print(f"\n❌ 验证完成,成功率过低: {success_rate:.1%}")
  221. exit_code = 2
  222. print(f"📊 详细结果: {report['summary']['valid_sqls']}/{report['summary']['total_questions']} SQL有效")
  223. sys.exit(exit_code)
  224. except KeyboardInterrupt:
  225. print("\n\n⏹️ 用户中断,程序退出")
  226. sys.exit(130)
  227. except Exception as e:
  228. print(f"\n❌ 程序执行失败: {e}")
  229. if args.verbose:
  230. import traceback
  231. traceback.print_exc()
  232. sys.exit(1)
  233. def _mask_db_connection(conn_str: str) -> str:
  234. """隐藏数据库连接字符串中的敏感信息"""
  235. import re
  236. return re.sub(r'://[^:]+:[^@]+@', '://***:***@', conn_str)
  237. def resolve_input_file_and_output_dir(args):
  238. """解析输入文件和输出目录路径"""
  239. input_file = None
  240. output_dir = None
  241. if args.input_file:
  242. # 用户明确指定了输入文件
  243. input_file = args.input_file
  244. output_dir = args.output_dir or str(Path(input_file).parent)
  245. elif args.task_id:
  246. # 使用task_id自动查找输入文件
  247. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  248. base_dir = SCHEMA_TOOLS_CONFIG.get("output_directory", "./data_pipeline/training_data/")
  249. # 处理相对路径
  250. from pathlib import Path
  251. if not Path(base_dir).is_absolute():
  252. # 相对于项目根目录解析
  253. project_root = Path(__file__).parent.parent.parent
  254. base_dir = project_root / base_dir
  255. task_dir = Path(base_dir) / args.task_id
  256. output_dir = args.output_dir or str(task_dir)
  257. # 在任务目录中查找Question-SQL文件
  258. if task_dir.exists():
  259. # 只搜索标准命名的文件,排除 _old 后缀
  260. possible_files = [
  261. f for f in task_dir.glob("*_pair.json")
  262. if not f.name.endswith('_old') and '.backup' not in f.name
  263. ]
  264. if possible_files:
  265. # 选择最新的文件(按修改时间排序)
  266. input_file = str(max(possible_files, key=lambda f: f.stat().st_mtime))
  267. else:
  268. input_file = None
  269. else:
  270. input_file = None
  271. return input_file, output_dir
  272. if __name__ == "__main__":
  273. asyncio.run(main())