sql_validator.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. """
  2. SQL验证器命令行入口
  3. 用于验证Question-SQL对中的SQL语句是否有效
  4. """
  5. import argparse
  6. import asyncio
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from schema_tools.sql_validation_agent import SQLValidationAgent
  11. from schema_tools.utils.logger import setup_logging
  12. def setup_argument_parser():
  13. """设置命令行参数解析器"""
  14. parser = argparse.ArgumentParser(
  15. description='SQL Validator - 验证Question-SQL对中的SQL语句',
  16. formatter_class=argparse.RawDescriptionHelpFormatter,
  17. epilog="""
  18. 示例用法:
  19. # 基本使用(仅验证,不修改文件)
  20. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json
  21. # 仅删除无效SQL,不进行LLM修复
  22. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --modify-original-file
  23. # 启用LLM修复功能(需要同时指定文件修改参数)
  24. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --enable-llm-repair --modify-original-file
  25. # 指定输出目录
  26. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --output-dir ./reports
  27. # 启用详细日志
  28. python -m schema_tools.sql_validator --db-connection "postgresql://user:pass@localhost:5432/dbname" --input-file ./data.json --verbose
  29. """
  30. )
  31. # 必需参数
  32. parser.add_argument(
  33. '--db-connection',
  34. required=True,
  35. help='数据库连接字符串 (postgresql://user:pass@host:port/dbname)'
  36. )
  37. parser.add_argument(
  38. '--input-file',
  39. required=True,
  40. help='输入的JSON文件路径(包含Question-SQL对)'
  41. )
  42. # 可选参数
  43. parser.add_argument(
  44. '--output-dir',
  45. help='验证报告输出目录(默认为输入文件同目录)'
  46. )
  47. parser.add_argument(
  48. '--max-concurrent',
  49. type=int,
  50. help='最大并发验证数(覆盖配置文件设置)'
  51. )
  52. parser.add_argument(
  53. '--batch-size',
  54. type=int,
  55. help='批处理大小(覆盖配置文件设置)'
  56. )
  57. parser.add_argument(
  58. '--timeout',
  59. type=int,
  60. help='单个SQL验证超时时间(秒)'
  61. )
  62. parser.add_argument(
  63. '--verbose', '-v',
  64. action='store_true',
  65. help='启用详细日志输出'
  66. )
  67. parser.add_argument(
  68. '--log-file',
  69. help='日志文件路径'
  70. )
  71. parser.add_argument(
  72. '--dry-run',
  73. action='store_true',
  74. help='仅读取和解析文件,不执行验证'
  75. )
  76. parser.add_argument(
  77. '--save-json',
  78. action='store_true',
  79. help='同时保存详细的JSON报告'
  80. )
  81. parser.add_argument(
  82. '--enable-llm-repair',
  83. action='store_true',
  84. help='启用LLM自动修复功能'
  85. )
  86. parser.add_argument(
  87. '--modify-original-file',
  88. action='store_true',
  89. help='修改原始JSON文件(删除无效SQL,如果启用LLM修复则同时更新修复后的SQL)'
  90. )
  91. return parser
  92. def apply_config_overrides(args):
  93. """应用命令行参数覆盖配置"""
  94. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  95. sql_config = SCHEMA_TOOLS_CONFIG['sql_validation']
  96. if args.max_concurrent:
  97. sql_config['max_concurrent_validations'] = args.max_concurrent
  98. print(f"覆盖并发数配置: {args.max_concurrent}")
  99. if args.batch_size:
  100. sql_config['batch_size'] = args.batch_size
  101. print(f"覆盖批处理大小: {args.batch_size}")
  102. if args.timeout:
  103. sql_config['validation_timeout'] = args.timeout
  104. print(f"覆盖超时配置: {args.timeout}秒")
  105. if args.save_json:
  106. sql_config['save_detailed_json_report'] = True
  107. print(f"启用详细JSON报告保存")
  108. if args.enable_llm_repair:
  109. sql_config['enable_sql_repair'] = True
  110. print(f"启用LLM自动修复功能")
  111. else:
  112. sql_config['enable_sql_repair'] = False
  113. print(f"LLM修复功能已禁用")
  114. if args.modify_original_file:
  115. sql_config['modify_original_file'] = True
  116. print(f"启用原文件修改功能")
  117. else:
  118. sql_config['modify_original_file'] = False
  119. print(f"不修改原文件")
  120. async def main():
  121. """主入口函数"""
  122. parser = setup_argument_parser()
  123. args = parser.parse_args()
  124. # 设置日志
  125. setup_logging(
  126. verbose=args.verbose,
  127. log_file=args.log_file,
  128. log_dir=os.path.join(args.output_dir, 'logs') if args.output_dir else None
  129. )
  130. # 验证参数
  131. if not os.path.exists(args.input_file):
  132. print(f"错误: 输入文件不存在: {args.input_file}")
  133. sys.exit(1)
  134. input_path = Path(args.input_file)
  135. if not input_path.suffix.lower() == '.json':
  136. print(f"警告: 输入文件可能不是JSON格式: {args.input_file}")
  137. # 应用配置覆盖
  138. apply_config_overrides(args)
  139. try:
  140. # 创建SQL验证Agent
  141. agent = SQLValidationAgent(
  142. db_connection=args.db_connection,
  143. input_file=args.input_file,
  144. output_dir=args.output_dir
  145. )
  146. # 显示运行信息
  147. print(f"🚀 开始SQL验证...")
  148. print(f"📁 输入文件: {args.input_file}")
  149. if args.output_dir:
  150. print(f"📁 输出目录: {args.output_dir}")
  151. print(f"🔗 数据库: {_mask_db_connection(args.db_connection)}")
  152. if args.dry_run:
  153. print("\n🔍 执行预检查模式...")
  154. # 仅读取和验证文件格式
  155. questions_sqls = await agent._load_questions_sqls()
  156. print(f"✅ 成功读取 {len(questions_sqls)} 个Question-SQL对")
  157. print("📊 SQL样例:")
  158. for i, qs in enumerate(questions_sqls[:3], 1):
  159. print(f" {i}. {qs['question']}")
  160. print(f" SQL: {qs['sql'][:100]}{'...' if len(qs['sql']) > 100 else ''}")
  161. print()
  162. sys.exit(0)
  163. # 执行验证
  164. report = await agent.validate()
  165. # 输出结果
  166. success_rate = report['summary']['success_rate']
  167. if success_rate >= 0.9: # 90%以上成功率
  168. print(f"\n🎉 验证完成,成功率: {success_rate:.1%}")
  169. exit_code = 0
  170. elif success_rate >= 0.7: # 70%-90%成功率
  171. print(f"\n⚠️ 验证完成,成功率较低: {success_rate:.1%}")
  172. exit_code = 1
  173. else: # 70%以下成功率
  174. print(f"\n❌ 验证完成,成功率过低: {success_rate:.1%}")
  175. exit_code = 2
  176. print(f"📊 详细结果: {report['summary']['valid_sqls']}/{report['summary']['total_questions']} SQL有效")
  177. sys.exit(exit_code)
  178. except KeyboardInterrupt:
  179. print("\n\n⏹️ 用户中断,程序退出")
  180. sys.exit(130)
  181. except Exception as e:
  182. print(f"\n❌ 程序执行失败: {e}")
  183. if args.verbose:
  184. import traceback
  185. traceback.print_exc()
  186. sys.exit(1)
  187. def _mask_db_connection(conn_str: str) -> str:
  188. """隐藏数据库连接字符串中的敏感信息"""
  189. import re
  190. return re.sub(r'://[^:]+:[^@]+@', '://***:***@', conn_str)
  191. if __name__ == "__main__":
  192. asyncio.run(main())