sql_validator.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """
  2. SQL验证器命令行入口
  3. 用于验证Question-SQL对中的SQL语句是否有效
  4. """
  5. import argparse
  6. import asyncio
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from schema_tools.sql_validation_agent import SQLValidationAgent
  11. from schema_tools.utils.logger import setup_logging
  12. def setup_argument_parser():
  13. """设置命令行参数解析器"""
  14. parser = argparse.ArgumentParser(
  15. description='SQL Validator - 验证Question-SQL对中的SQL语句',
  16. formatter_class=argparse.RawDescriptionHelpFormatter,
  17. epilog="""
  18. 示例用法:
  19. # 基本使用(仅验证,不修改文件)
  20. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json
  21. # 启用文件修改,但禁用LLM修复(仅删除无效SQL)
  22. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --modify-original-file --disable-llm-repair
  23. # 启用文件修改和LLM修复功能
  24. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --modify-original-file
  25. # 指定输出目录
  26. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --output-dir ./reports
  27. # 启用详细日志
  28. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --verbose
  29. """
  30. )
  31. # 必需参数
  32. parser.add_argument(
  33. '--db-connection',
  34. required=True,
  35. help='数据库连接字符串 (postgresql://user:pass@host:port/dbname)'
  36. )
  37. parser.add_argument(
  38. '--input-file',
  39. required=True,
  40. help='输入的JSON文件路径(包含Question-SQL对)'
  41. )
  42. # 可选参数
  43. parser.add_argument(
  44. '--output-dir',
  45. help='验证报告输出目录(默认为输入文件同目录)'
  46. )
  47. parser.add_argument(
  48. '--max-concurrent',
  49. type=int,
  50. help='最大并发验证数(覆盖配置文件设置)'
  51. )
  52. parser.add_argument(
  53. '--batch-size',
  54. type=int,
  55. help='批处理大小(覆盖配置文件设置)'
  56. )
  57. parser.add_argument(
  58. '--timeout',
  59. type=int,
  60. help='单个SQL验证超时时间(秒)'
  61. )
  62. parser.add_argument(
  63. '--verbose', '-v',
  64. action='store_true',
  65. help='启用详细日志输出'
  66. )
  67. parser.add_argument(
  68. '--log-file',
  69. help='日志文件路径'
  70. )
  71. parser.add_argument(
  72. '--dry-run',
  73. action='store_true',
  74. help='仅读取和解析文件,不执行验证'
  75. )
  76. parser.add_argument(
  77. '--save-json',
  78. action='store_true',
  79. help='同时保存详细的JSON报告'
  80. )
  81. parser.add_argument(
  82. '--disable-llm-repair',
  83. action='store_true',
  84. help='禁用LLM自动修复功能'
  85. )
  86. # 向后兼容的别名参数
  87. parser.add_argument(
  88. '--enable-llm-repair',
  89. action='store_true',
  90. help='启用LLM自动修复功能(与--disable-llm-repair相反,保持向后兼容性)'
  91. )
  92. parser.add_argument(
  93. '--no-modify-file',
  94. action='store_true',
  95. help='不修改原始JSON文件(仅生成验证报告)'
  96. )
  97. # 向后兼容的别名参数
  98. parser.add_argument(
  99. '--modify-original-file',
  100. action='store_true',
  101. help='修改原始JSON文件(与--no-modify-file相反,保持向后兼容性)'
  102. )
  103. return parser
  104. def apply_config_overrides(args):
  105. """应用命令行参数覆盖配置"""
  106. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  107. sql_config = SCHEMA_TOOLS_CONFIG['sql_validation']
  108. if args.max_concurrent:
  109. sql_config['max_concurrent_validations'] = args.max_concurrent
  110. print(f"覆盖并发数配置: {args.max_concurrent}")
  111. if args.batch_size:
  112. sql_config['batch_size'] = args.batch_size
  113. print(f"覆盖批处理大小: {args.batch_size}")
  114. if args.timeout:
  115. sql_config['validation_timeout'] = args.timeout
  116. print(f"覆盖超时配置: {args.timeout}秒")
  117. if args.save_json:
  118. sql_config['save_detailed_json_report'] = True
  119. print(f"启用详细JSON报告保存")
  120. # 注意:现在是disable_llm_repair,逻辑反转,同时支持向后兼容的enable_llm_repair
  121. if args.disable_llm_repair and args.enable_llm_repair:
  122. print("警告: --disable-llm-repair 和 --enable-llm-repair 不能同时使用,优先使用 --disable-llm-repair")
  123. sql_config['enable_sql_repair'] = False
  124. print(f"LLM修复功能已禁用")
  125. elif args.disable_llm_repair:
  126. sql_config['enable_sql_repair'] = False
  127. print(f"LLM修复功能已禁用")
  128. elif args.enable_llm_repair:
  129. sql_config['enable_sql_repair'] = True
  130. print(f"启用LLM自动修复功能(向后兼容参数)")
  131. else:
  132. # 默认启用LLM修复功能
  133. sql_config['enable_sql_repair'] = True
  134. print(f"启用LLM自动修复功能(默认行为)")
  135. # 注意:现在是no_modify_file,逻辑反转,同时支持向后兼容的modify_original_file
  136. if args.no_modify_file and args.modify_original_file:
  137. print("警告: --no-modify-file 和 --modify-original-file 不能同时使用,优先使用 --no-modify-file")
  138. sql_config['modify_original_file'] = False
  139. print(f"不修改原文件")
  140. elif args.no_modify_file:
  141. sql_config['modify_original_file'] = False
  142. print(f"不修改原文件")
  143. elif args.modify_original_file:
  144. sql_config['modify_original_file'] = True
  145. print(f"启用原文件修改功能(向后兼容参数)")
  146. else:
  147. # 默认启用文件修改功能
  148. sql_config['modify_original_file'] = True
  149. print(f"启用原文件修改功能(默认行为)")
  150. async def main():
  151. """主入口函数"""
  152. parser = setup_argument_parser()
  153. args = parser.parse_args()
  154. # 设置日志
  155. setup_logging(
  156. verbose=args.verbose,
  157. log_file=args.log_file,
  158. log_dir=os.path.join(args.output_dir, 'logs') if args.output_dir else None
  159. )
  160. # 验证参数
  161. if not os.path.exists(args.input_file):
  162. print(f"错误: 输入文件不存在: {args.input_file}")
  163. sys.exit(1)
  164. input_path = Path(args.input_file)
  165. if not input_path.suffix.lower() == '.json':
  166. print(f"警告: 输入文件可能不是JSON格式: {args.input_file}")
  167. # 应用配置覆盖
  168. apply_config_overrides(args)
  169. try:
  170. # 创建SQL验证Agent
  171. agent = SQLValidationAgent(
  172. db_connection=args.db_connection,
  173. input_file=args.input_file,
  174. output_dir=args.output_dir
  175. )
  176. # 显示运行信息
  177. print(f"🚀 开始SQL验证...")
  178. print(f"📁 输入文件: {args.input_file}")
  179. if args.output_dir:
  180. print(f"📁 输出目录: {args.output_dir}")
  181. print(f"🔗 数据库: {_mask_db_connection(args.db_connection)}")
  182. if args.dry_run:
  183. print("\n🔍 执行预检查模式...")
  184. # 仅读取和验证文件格式
  185. questions_sqls = await agent._load_questions_sqls()
  186. print(f"✅ 成功读取 {len(questions_sqls)} 个Question-SQL对")
  187. print("📊 SQL样例:")
  188. for i, qs in enumerate(questions_sqls[:3], 1):
  189. print(f" {i}. {qs['question']}")
  190. print(f" SQL: {qs['sql'][:100]}{'...' if len(qs['sql']) > 100 else ''}")
  191. print()
  192. sys.exit(0)
  193. # 执行验证
  194. report = await agent.validate()
  195. # 输出结果
  196. success_rate = report['summary']['success_rate']
  197. if success_rate >= 0.9: # 90%以上成功率
  198. print(f"\n🎉 验证完成,成功率: {success_rate:.1%}")
  199. exit_code = 0
  200. elif success_rate >= 0.7: # 70%-90%成功率
  201. print(f"\n⚠️ 验证完成,成功率较低: {success_rate:.1%}")
  202. exit_code = 1
  203. else: # 70%以下成功率
  204. print(f"\n❌ 验证完成,成功率过低: {success_rate:.1%}")
  205. exit_code = 2
  206. print(f"📊 详细结果: {report['summary']['valid_sqls']}/{report['summary']['total_questions']} SQL有效")
  207. sys.exit(exit_code)
  208. except KeyboardInterrupt:
  209. print("\n\n⏹️ 用户中断,程序退出")
  210. sys.exit(130)
  211. except Exception as e:
  212. print(f"\n❌ 程序执行失败: {e}")
  213. if args.verbose:
  214. import traceback
  215. traceback.print_exc()
  216. sys.exit(1)
  217. def _mask_db_connection(conn_str: str) -> str:
  218. """隐藏数据库连接字符串中的敏感信息"""
  219. import re
  220. return re.sub(r'://[^:]+:[^@]+@', '://***:***@', conn_str)
  221. if __name__ == "__main__":
  222. asyncio.run(main())