qs_generator.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """
  2. Question-SQL生成器命令行入口
  3. 用于从已生成的DDL和MD文件生成Question-SQL训练数据
  4. """
  5. import argparse
  6. import asyncio
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from .qs_agent import QuestionSQLGenerationAgent
  11. from data_pipeline.utils.logger import setup_logging
  12. def setup_argument_parser():
  13. """设置命令行参数解析器"""
  14. parser = argparse.ArgumentParser(
  15. description='Question-SQL Generator - 从MD文件生成Question-SQL训练数据',
  16. formatter_class=argparse.RawDescriptionHelpFormatter,
  17. epilog="""
  18. 示例用法:
  19. # 基本使用
  20. python -m data_pipeline.qa_generation.qs_generator --output-dir ./output --table-list ./tables.txt --business-context "高速公路服务区管理系统"
  21. # 使用task_id自动解析路径
  22. python -m data_pipeline.qa_generation.qs_generator --task-id manual_20250720_130541 --table-list ./tables.txt --business-context "高速公路服务区管理系统"
  23. # 指定数据库名称
  24. python -m data_pipeline.qa_generation.qs_generator --output-dir ./output --table-list ./tables.txt --business-context "电商系统" --db-name ecommerce_db
  25. # 启用详细日志
  26. python -m data_pipeline.qa_generation.qs_generator --output-dir ./output --table-list ./tables.txt --business-context "管理系统" --verbose
  27. """
  28. )
  29. # 可选参数(当使用task-id时,output-dir变为可选)
  30. parser.add_argument(
  31. '--task-id',
  32. help='任务ID,指定后将自动构建输出目录路径 (基础目录/task_id)'
  33. )
  34. parser.add_argument(
  35. '--output-dir',
  36. help='包含DDL和MD文件的输出目录'
  37. )
  38. parser.add_argument(
  39. '--table-list',
  40. required=True,
  41. help='表清单文件路径(用于验证文件数量)'
  42. )
  43. parser.add_argument(
  44. '--business-context',
  45. required=True,
  46. help='业务上下文描述'
  47. )
  48. # 可选参数
  49. parser.add_argument(
  50. '--db-name',
  51. help='数据库名称(用于输出文件命名)'
  52. )
  53. parser.add_argument(
  54. '--verbose', '-v',
  55. action='store_true',
  56. help='启用详细日志输出'
  57. )
  58. parser.add_argument(
  59. '--log-file',
  60. help='日志文件路径'
  61. )
  62. return parser
  63. def resolve_output_directory(args):
  64. """解析输出目录路径"""
  65. if args.output_dir:
  66. # 用户明确指定了输出目录
  67. return args.output_dir
  68. elif args.task_id:
  69. # 使用task_id构建输出目录
  70. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  71. base_dir = SCHEMA_TOOLS_CONFIG.get("output_directory", "./data_pipeline/training_data/")
  72. # 处理相对路径
  73. from pathlib import Path
  74. if not Path(base_dir).is_absolute():
  75. # 相对于项目根目录解析
  76. project_root = Path(__file__).parent.parent.parent
  77. base_dir = project_root / base_dir
  78. return str(Path(base_dir) / args.task_id)
  79. else:
  80. # 没有指定输出目录或task_id
  81. return None
  82. async def main():
  83. """主入口函数"""
  84. parser = setup_argument_parser()
  85. args = parser.parse_args()
  86. # 设置日志
  87. setup_logging(
  88. verbose=args.verbose,
  89. log_file=args.log_file
  90. )
  91. # 解析输出目录
  92. output_dir = resolve_output_directory(args)
  93. # 验证参数
  94. if not output_dir:
  95. print("错误: 需要指定 --output-dir 或 --task-id 参数")
  96. parser.print_help()
  97. sys.exit(1)
  98. output_path = Path(output_dir)
  99. if not output_path.exists():
  100. print(f"错误: 输出目录不存在: {output_dir}")
  101. sys.exit(1)
  102. if not os.path.exists(args.table_list):
  103. print(f"错误: 表清单文件不存在: {args.table_list}")
  104. sys.exit(1)
  105. try:
  106. # 创建Agent
  107. agent = QuestionSQLGenerationAgent(
  108. output_dir=output_dir,
  109. table_list_file=args.table_list,
  110. business_context=args.business_context,
  111. db_name=args.db_name,
  112. task_id=args.task_id # 传递task_id
  113. )
  114. # 执行生成
  115. print(f"🚀 开始生成Question-SQL训练数据...")
  116. print(f"📁 输出目录: {output_dir}")
  117. print(f"📋 表清单: {args.table_list}")
  118. print(f"🏢 业务背景: {args.business_context}")
  119. report = await agent.generate()
  120. # 输出结果
  121. if report['success']:
  122. if report['failed_themes']:
  123. print(f"\n⚠️ 生成完成,但有 {len(report['failed_themes'])} 个主题失败")
  124. exit_code = 2 # 部分成功
  125. else:
  126. print("\n🎉 所有主题生成成功!")
  127. exit_code = 0 # 完全成功
  128. else:
  129. print("\n❌ 生成失败")
  130. exit_code = 1
  131. print(f"📁 输出文件: {report['output_file']}")
  132. sys.exit(exit_code)
  133. except KeyboardInterrupt:
  134. print("\n\n⏹️ 用户中断,程序退出")
  135. sys.exit(130)
  136. except Exception as e:
  137. print(f"\n❌ 程序执行失败: {e}")
  138. if args.verbose:
  139. import traceback
  140. traceback.print_exc()
  141. sys.exit(1)
  142. if __name__ == "__main__":
  143. asyncio.run(main())