  1. """
  2. DDL和MD文档生成器命令行入口
  3. 用于从PostgreSQL数据库生成DDL和MD训练数据
  4. """
  5. import argparse
  6. import asyncio
  7. import sys
  8. import os
  9. import logging
  10. from pathlib import Path


def setup_argument_parser():
    """Set up the command-line argument parser."""
    parser = argparse.ArgumentParser(
        description='DDL/MD document generator - generates training data from a PostgreSQL database',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage
  python -m data_pipeline.ddl_md_generator --db-connection "postgresql://user:pass@host:5432/db" --table-list tables.txt --business-context "e-commerce system"

  # Specify an output directory
  python -m data_pipeline.ddl_md_generator --db-connection "..." --table-list tables.txt --business-context "e-commerce system" --output-dir ./data_pipeline/training_data/

  # Generate DDL files only
  python -m data_pipeline.ddl_md_generator --db-connection "..." --table-list tables.txt --business-context "e-commerce system" --pipeline ddl_only

  # Permission-check mode
  python -m data_pipeline.ddl_md_generator --db-connection "..." --check-permissions-only
"""
    )

    # Required arguments
    parser.add_argument(
        '--db-connection',
        required=True,
        help='Database connection string (e.g. postgresql://user:pass@localhost:5432/dbname)'
    )

    # Optional arguments
    parser.add_argument(
        '--table-list',
        help='Path to the table list file'
    )
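    # Assumed table-list format (not enforced here): one table name per line,
    # optionally schema-qualified, e.g. "public.orders"; the authoritative
    # parsing lives in SchemaTrainingDataAgent.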
    parser.add_argument(
        '--business-context',
        help='Business context description'
    )
    parser.add_argument(
        '--business-context-file',
        help='Path to a business context file'
    )
    parser.add_argument(
        '--output-dir',
        help='Output directory path'
    )
    parser.add_argument(
        '--pipeline',
        choices=['full', 'ddl_only', 'analysis_only'],
        help='Pipeline type'
    )
    parser.add_argument(
        '--max-concurrent',
        type=int,
        help='Maximum number of tables processed concurrently'
    )

    # Feature switches
    parser.add_argument(
        '--no-filter-system-tables',
        action='store_true',
        help='Disable system table filtering'
    )
    parser.add_argument(
        '--check-permissions-only',
        action='store_true',
        help='Only check database permissions; do not process any tables'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )
    parser.add_argument(
        '--log-file',
        help='Log file path'
    )

    return parser
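
# --table-list is intentionally optional at parse time so that
# --check-permissions-only can run without it; main() enforces its presence
# for every other mode.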


def load_config_with_overrides(args):
    """Load the configuration and apply command-line overrides."""
    from data_pipeline.config import SCHEMA_TOOLS_CONFIG

    config = SCHEMA_TOOLS_CONFIG.copy()

    # Command-line arguments override the config defaults
    if args.output_dir:
        config["output_directory"] = args.output_dir
    if args.pipeline:
        config["default_pipeline"] = args.pipeline
    if args.max_concurrent:
        config["max_concurrent_tables"] = args.max_concurrent
    if args.no_filter_system_tables:
        config["filter_system_tables"] = False
    if args.log_file:
        config["log_file"] = args.log_file

    return config
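
# For example, `--output-dir ./out --pipeline ddl_only` yields a config with
# output_directory='./out' and default_pipeline='ddl_only', while every other
# key keeps its SCHEMA_TOOLS_CONFIG default.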


def load_business_context(args):
    """Load the business context."""
    if args.business_context_file:
        try:
            with open(args.business_context_file, 'r', encoding='utf-8') as f:
                return f.read().strip()
        except Exception as e:
            print(f"Warning: could not read business context file {args.business_context_file}: {e}")

    if args.business_context:
        return args.business_context

    from data_pipeline.config import SCHEMA_TOOLS_CONFIG
    return SCHEMA_TOOLS_CONFIG.get("default_business_context", "Database management system")
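
# Precedence: --business-context-file (when readable) wins over
# --business-context, which in turn wins over the default_business_context
# entry in SCHEMA_TOOLS_CONFIG.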


async def check_permissions_only(db_connection: str):
    """Check database permissions only."""
    from .training_data_agent import SchemaTrainingDataAgent

    print("🔍 Checking database permissions...")
    try:
        agent = SchemaTrainingDataAgent(
            db_connection=db_connection,
            table_list_file="",   # no table list needed
            business_context=""   # no business context needed
        )

        # Initialize the agent to establish the database connection
        await agent._initialize()

        # Check permissions
        permissions = await agent.check_database_permissions()

        print("\n📋 Permission check results:")
        print(f"  ✅ Database connection: {'available' if permissions['connect'] else 'unavailable'}")
        print(f"  ✅ Metadata queries: {'available' if permissions['select_metadata'] else 'unavailable'}")
        print(f"  ✅ Data queries: {'available' if permissions['select_data'] else 'unavailable'}")
        print(f"  ℹ️ Database access: {'read-only' if permissions['is_readonly'] else 'read-write'}")

        # Note: is_readonly=False means the database is writable, which is fine;
        # it is reported for information only and does not gate the result.
        required_permissions = ['connect', 'select_metadata', 'select_data']
        has_required_permissions = all(permissions.get(perm, False) for perm in required_permissions)

        if has_required_permissions:
            print("\n✅ Permission check passed; processing can start")
            return True
        else:
            print("\n❌ Insufficient database permissions; please check the configuration")
            return False
    except Exception as e:
        print(f"\n❌ Permission check failed: {e}")
        return False
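
# Standalone usage sketch (assuming data_pipeline is importable as a package):
#   asyncio.run(check_permissions_only("postgresql://user:pass@localhost:5432/db"))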


async def main():
    """Main entry point."""
    parser = setup_argument_parser()
    args = parser.parse_args()

    # Set up logging
    from data_pipeline.utils.logger import setup_logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file
    )

    # Permission-check-only mode
    if args.check_permissions_only:
        success = await check_permissions_only(args.db_connection)
        sys.exit(0 if success else 1)

    # Validate required arguments
    if not args.table_list:
        print("Error: the --table-list argument is required")
        parser.print_help()
        sys.exit(1)

    if not os.path.exists(args.table_list):
        print(f"Error: table list file not found: {args.table_list}")
        sys.exit(1)

    try:
        # Load the configuration and business context
        config = load_config_with_overrides(args)
        business_context = load_business_context(args)

        # Create the agent
        from .training_data_agent import SchemaTrainingDataAgent
        agent = SchemaTrainingDataAgent(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=business_context,
            output_dir=config["output_directory"],
            pipeline=config["default_pipeline"]
        )

        # Run the generation
        print("🚀 Generating DDL and MD documents...")
        report = await agent.generate_training_data()

        # Report the results
        if report['summary']['failed'] == 0:
            print("\n🎉 All tables processed successfully!")
        else:
            print(f"\n⚠️ Processing finished, but {report['summary']['failed']} table(s) failed")

        print(f"📁 Output directory: {config['output_directory']}")

        # Return a non-zero exit code if any table failed
        sys.exit(1 if report['summary']['failed'] > 0 else 0)

    except KeyboardInterrupt:
        print("\n\n⏹️ Interrupted by user; exiting")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ Program execution failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())
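
# Exit codes: 0 on success, 1 when any table fails or arguments are invalid,
# 130 on Ctrl-C (the conventional SIGINT exit status).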