qs_generator.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. """
  2. Question-SQL生成器命令行入口
  3. 用于从已生成的DDL和MD文件生成Question-SQL训练数据
  4. """
  5. import argparse
  6. import asyncio
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from schema_tools.qs_agent import QuestionSQLGenerationAgent
  11. from schema_tools.utils.logger import setup_logging
  12. def setup_argument_parser():
  13. """设置命令行参数解析器"""
  14. parser = argparse.ArgumentParser(
  15. description='Question-SQL Generator - 从MD文件生成Question-SQL训练数据',
  16. formatter_class=argparse.RawDescriptionHelpFormatter,
  17. epilog="""
  18. 示例用法:
  19. # 基本使用
  20. python -m schema_tools.qs_generator --output-dir ./output --table-list ./tables.txt --business-context "高速公路服务区管理系统"
  21. # 指定数据库名称
  22. python -m schema_tools.qs_generator --output-dir ./output --table-list ./tables.txt --business-context "电商系统" --db-name ecommerce_db
  23. # 启用详细日志
  24. python -m schema_tools.qs_generator --output-dir ./output --table-list ./tables.txt --business-context "管理系统" --verbose
  25. """
  26. )
  27. # 必需参数
  28. parser.add_argument(
  29. '--output-dir',
  30. required=True,
  31. help='包含DDL和MD文件的输出目录'
  32. )
  33. parser.add_argument(
  34. '--table-list',
  35. required=True,
  36. help='表清单文件路径(用于验证文件数量)'
  37. )
  38. parser.add_argument(
  39. '--business-context',
  40. required=True,
  41. help='业务上下文描述'
  42. )
  43. # 可选参数
  44. parser.add_argument(
  45. '--db-name',
  46. help='数据库名称(用于输出文件命名)'
  47. )
  48. parser.add_argument(
  49. '--verbose', '-v',
  50. action='store_true',
  51. help='启用详细日志输出'
  52. )
  53. parser.add_argument(
  54. '--log-file',
  55. help='日志文件路径'
  56. )
  57. return parser
  58. async def main():
  59. """主入口函数"""
  60. parser = setup_argument_parser()
  61. args = parser.parse_args()
  62. # 设置日志
  63. setup_logging(
  64. verbose=args.verbose,
  65. log_file=args.log_file,
  66. log_dir=os.path.join(args.output_dir, 'logs') if args.output_dir else None
  67. )
  68. # 验证参数
  69. output_path = Path(args.output_dir)
  70. if not output_path.exists():
  71. print(f"错误: 输出目录不存在: {args.output_dir}")
  72. sys.exit(1)
  73. if not os.path.exists(args.table_list):
  74. print(f"错误: 表清单文件不存在: {args.table_list}")
  75. sys.exit(1)
  76. try:
  77. # 创建Agent
  78. agent = QuestionSQLGenerationAgent(
  79. output_dir=args.output_dir,
  80. table_list_file=args.table_list,
  81. business_context=args.business_context,
  82. db_name=args.db_name
  83. )
  84. # 执行生成
  85. print(f"🚀 开始生成Question-SQL训练数据...")
  86. print(f"📁 输出目录: {args.output_dir}")
  87. print(f"📋 表清单: {args.table_list}")
  88. print(f"🏢 业务背景: {args.business_context}")
  89. report = await agent.generate()
  90. # 输出结果
  91. if report['success']:
  92. if report['failed_themes']:
  93. print(f"\n⚠️ 生成完成,但有 {len(report['failed_themes'])} 个主题失败")
  94. exit_code = 2 # 部分成功
  95. else:
  96. print("\n🎉 所有主题生成成功!")
  97. exit_code = 0 # 完全成功
  98. else:
  99. print("\n❌ 生成失败")
  100. exit_code = 1
  101. print(f"📁 输出文件: {report['output_file']}")
  102. sys.exit(exit_code)
  103. except KeyboardInterrupt:
  104. print("\n\n⏹️ 用户中断,程序退出")
  105. sys.exit(130)
  106. except Exception as e:
  107. print(f"\n❌ 程序执行失败: {e}")
  108. if args.verbose:
  109. import traceback
  110. traceback.print_exc()
  111. sys.exit(1)
  112. if __name__ == "__main__":
  113. asyncio.run(main())