schema_workflow.py

  1. """
  2. Schema工作流编排器
  3. 统一管理完整的数据库Schema处理流程
  4. """
  5. import asyncio
  6. import time
  7. import logging
  8. from typing import Dict, Any, List, Optional
  9. from pathlib import Path
  10. from datetime import datetime
  11. from data_pipeline.ddl_generation.training_data_agent import SchemaTrainingDataAgent
  12. from data_pipeline.qa_generation.qs_agent import QuestionSQLGenerationAgent
  13. from data_pipeline.validators.sql_validation_agent import SQLValidationAgent
  14. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  15. from data_pipeline.utils.logger import setup_logging
  16. class SchemaWorkflowOrchestrator:
  17. """端到端的Schema处理编排器 - 完整工作流程"""
  18. def __init__(self,
  19. db_connection: str,
  20. table_list_file: str,
  21. business_context: str,
  22. output_dir: str = None,
  23. enable_sql_validation: bool = True,
  24. enable_llm_repair: bool = True,
  25. modify_original_file: bool = True,
  26. enable_training_data_load: bool = True):
  27. """
  28. 初始化Schema工作流编排器
  29. Args:
  30. db_connection: 数据库连接字符串 (postgresql://user:pass@host:port/dbname)
  31. table_list_file: 表清单文件路径
  32. business_context: 业务上下文描述
  33. output_dir: 输出目录
  34. enable_sql_validation: 是否启用SQL验证
  35. enable_llm_repair: 是否启用LLM修复功能
  36. modify_original_file: 是否修改原始JSON文件
  37. enable_training_data_load: 是否启用训练数据加载
  38. """
  39. self.db_connection = db_connection
  40. self.table_list_file = table_list_file
  41. self.business_context = business_context
  42. self.db_name = self._extract_db_name_from_connection(db_connection)
  43. self.output_dir = Path(output_dir) if output_dir else Path("./output")
  44. self.enable_sql_validation = enable_sql_validation
  45. self.enable_llm_repair = enable_llm_repair
  46. self.modify_original_file = modify_original_file
  47. self.enable_training_data_load = enable_training_data_load
  48. # 确保输出目录存在
  49. self.output_dir.mkdir(parents=True, exist_ok=True)
  50. # 初始化日志
  51. self.logger = logging.getLogger("schema_tools.SchemaWorkflowOrchestrator")
  52. # 工作流程状态
  53. self.workflow_state = {
  54. "start_time": None,
  55. "end_time": None,
  56. "current_step": None,
  57. "completed_steps": [],
  58. "failed_steps": [],
  59. "artifacts": {}, # 存储各步骤产生的文件
  60. "statistics": {}
  61. }
  62. def _extract_db_name_from_connection(self, connection_string: str) -> str:
  63. """
  64. 从数据库连接字符串中提取数据库名称
  65. Args:
  66. connection_string: PostgreSQL连接字符串
  67. Returns:
  68. str: 数据库名称
  69. """
  70. try:
  71. # 处理标准的PostgreSQL连接字符串: postgresql://user:pass@host:port/dbname
  72. if '/' in connection_string:
  73. # 取最后一个 '/' 后面的部分作为数据库名
  74. db_name = connection_string.split('/')[-1]
  75. # 移除可能的查询参数
  76. if '?' in db_name:
  77. db_name = db_name.split('?')[0]
  78. return db_name if db_name else "database"
  79. else:
  80. return "database"
  81. except Exception:
  82. return "database"
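
    # A minimal alternative sketch (not used by the workflow): the standard
    # library's urllib.parse handles userinfo, ports, and query strings more
    # robustly than manual splitting. The method name is illustrative.
    @staticmethod
    def _extract_db_name_via_urlparse(connection_string: str) -> str:
        """Sketch: extract the database name using urllib.parse.urlparse."""
        from urllib.parse import urlparse
        # For "postgresql://user:pass@host:5432/dbname?sslmode=require",
        # urlparse(...).path is "/dbname"; strip the leading slash.
        path = urlparse(connection_string).path
        return path.lstrip("/") or "database"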

    async def execute_complete_workflow(self) -> Dict[str, Any]:
        """
        Execute the complete schema processing workflow.

        Returns:
            the full workflow report
        """
        self.workflow_state["start_time"] = time.time()
        self.logger.info("🚀 Starting schema workflow orchestration")
        self.logger.info(f"📁 Output directory: {self.output_dir}")
        self.logger.info(f"🏢 Business context: {self.business_context}")
        self.logger.info(f"💾 Database: {self.db_name}")

        try:
            # Step 1: generate DDL and MD files
            await self._execute_step_1_ddl_md_generation()

            # Step 2: generate Question-SQL pairs
            await self._execute_step_2_question_sql_generation()

            # Step 3: validate and repair SQL (optional)
            if self.enable_sql_validation:
                await self._execute_step_3_sql_validation()
            else:
                self.logger.info("⏭️ Skipping SQL validation step")

            # Step 4: load training data (optional)
            if self.enable_training_data_load:
                await self._execute_step_4_training_data_load()
            else:
                self.logger.info("⏭️ Skipping training data load step")

            # Record the end time
            self.workflow_state["end_time"] = time.time()

            # Generate the final report
            final_report = await self._generate_final_report()
            self.logger.info("✅ Schema workflow orchestration completed")
            return final_report
        except Exception as e:
            self.workflow_state["end_time"] = time.time()
            self.logger.exception(f"❌ Workflow execution failed: {str(e)}")
            error_report = await self._generate_error_report(e)
            return error_report

    async def _execute_step_1_ddl_md_generation(self):
        """Step 1: generate DDL and MD files."""
        self.workflow_state["current_step"] = "ddl_md_generation"
        self.logger.info("=" * 60)
        self.logger.info("📝 Step 1: Generating DDL and MD files")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # Create the DDL/MD generation agent
            ddl_md_agent = SchemaTrainingDataAgent(
                db_connection=self.db_connection,
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                output_dir=str(self.output_dir),
                pipeline="full"
            )

            # Run DDL/MD generation
            ddl_md_result = await ddl_md_agent.generate_training_data()
            step_duration = time.time() - step_start_time

            # Record results
            self.workflow_state["completed_steps"].append("ddl_md_generation")
            self.workflow_state["artifacts"]["ddl_md_generation"] = {
                "total_tables": ddl_md_result.get("summary", {}).get("total_tables", 0),
                "processed_successfully": ddl_md_result.get("summary", {}).get("processed_successfully", 0),
                "failed": ddl_md_result.get("summary", {}).get("failed", 0),
                "files_generated": ddl_md_result.get("statistics", {}).get("files_generated", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step1_duration"] = step_duration

            processed_tables = ddl_md_result.get("summary", {}).get("processed_successfully", 0)
            self.logger.info(f"✅ Step 1 completed: {processed_tables} tables processed successfully in {step_duration:.2f}s")
        except Exception as e:
            self.workflow_state["failed_steps"].append("ddl_md_generation")
            self.logger.error(f"❌ Step 1 failed: {str(e)}")
            raise

    async def _execute_step_2_question_sql_generation(self):
        """Step 2: generate Question-SQL pairs."""
        self.workflow_state["current_step"] = "question_sql_generation"
        self.logger.info("=" * 60)
        self.logger.info("🤖 Step 2: Generating Question-SQL pairs")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # Create the Question-SQL generation agent
            qs_agent = QuestionSQLGenerationAgent(
                output_dir=str(self.output_dir),
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                db_name=self.db_name
            )

            # Run Question-SQL generation
            qs_result = await qs_agent.generate()
            step_duration = time.time() - step_start_time

            # Record results
            self.workflow_state["completed_steps"].append("question_sql_generation")
            self.workflow_state["artifacts"]["question_sql_generation"] = {
                "output_file": str(qs_result.get("output_file", "")),
                "total_questions": qs_result.get("total_questions", 0),
                "total_themes": qs_result.get("total_themes", 0),
                "successful_themes": qs_result.get("successful_themes", 0),
                "failed_themes": qs_result.get("failed_themes", []),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step2_duration"] = step_duration

            total_questions = qs_result.get("total_questions", 0)
            self.logger.info(f"✅ Step 2 completed: generated {total_questions} question-SQL pairs in {step_duration:.2f}s")
        except Exception as e:
            self.workflow_state["failed_steps"].append("question_sql_generation")
            self.logger.error(f"❌ Step 2 failed: {str(e)}")
            raise

    async def _execute_step_3_sql_validation(self):
        """Step 3: validate and repair SQL."""
        self.workflow_state["current_step"] = "sql_validation"
        self.logger.info("=" * 60)
        self.logger.info("🔍 Step 3: Validating and repairing SQL")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # Locate the file produced by step 2
            qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
            qs_file = qs_artifacts.get("output_file")
            if not qs_file or not Path(qs_file).exists():
                raise FileNotFoundError(f"Question-SQL file not found: {qs_file}")

            self.logger.info(f"📄 Validating file: {qs_file}")

            # Create the SQL validation agent; configuration is passed as
            # parameters rather than by mutating global config
            sql_validator = SQLValidationAgent(
                db_connection=self.db_connection,
                input_file=str(qs_file),
                output_dir=str(self.output_dir),
                enable_sql_repair=self.enable_llm_repair,
                modify_original_file=self.modify_original_file
            )

            # Run SQL validation and repair
            validation_result = await sql_validator.validate()
            step_duration = time.time() - step_start_time

            # Record results
            self.workflow_state["completed_steps"].append("sql_validation")
            summary = validation_result.get("summary", {})
            self.workflow_state["artifacts"]["sql_validation"] = {
                "original_sql_count": summary.get("total_questions", 0),
                "valid_sql_count": summary.get("valid_sqls", 0),
                "invalid_sql_count": summary.get("invalid_sqls", 0),
                "success_rate": summary.get("success_rate", 0),
                "repair_stats": summary.get("repair_stats", {}),
                "file_modification_stats": summary.get("file_modification_stats", {}),
                "average_execution_time": summary.get("average_execution_time", 0),
                "total_retries": summary.get("total_retries", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step3_duration"] = step_duration

            success_rate = summary.get("success_rate", 0)
            valid_count = summary.get("valid_sqls", 0)
            total_count = summary.get("total_questions", 0)
            self.logger.info(
                f"✅ Step 3 completed: SQL validation success rate {success_rate:.1%} "
                f"({valid_count}/{total_count}) in {step_duration:.2f}s"
            )

            # Report repair statistics
            repair_stats = summary.get("repair_stats", {})
            if repair_stats.get("attempted", 0) > 0:
                self.logger.info(
                    f"🔧 Repair stats: {repair_stats['attempted']} attempted, "
                    f"{repair_stats['successful']} successful, {repair_stats['failed']} failed"
                )

            # Report file modification statistics
            file_stats = summary.get("file_modification_stats", {})
            if file_stats.get("modified", 0) > 0 or file_stats.get("deleted", 0) > 0:
                self.logger.info(
                    f"📝 File modifications: {file_stats.get('modified', 0)} SQL statements updated, "
                    f"{file_stats.get('deleted', 0)} invalid entries deleted"
                )
        except Exception as e:
            self.workflow_state["failed_steps"].append("sql_validation")
            self.logger.error(f"❌ Step 3 failed: {str(e)}")
            raise

    async def _execute_step_4_training_data_load(self):
        """Step 4: load training data."""
        self.workflow_state["current_step"] = "training_data_load"
        self.logger.info("=" * 60)
        self.logger.info("🎯 Step 4: Loading training data")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # The output directory is expected to contain the training data
            training_data_dir = str(self.output_dir)
            self.logger.info(f"📁 Training data directory: {training_data_dir}")

            # Import the trainer module
            import sys
            import os
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from data_pipeline.trainer.run_training import process_training_files

            # Load the training data
            self.logger.info("🔄 Processing training files...")
            load_successful = process_training_files(training_data_dir)
            step_duration = time.time() - step_start_time

            if load_successful:
                # Flush and shut down the batch processor
                from data_pipeline.trainer.vanna_trainer import flush_training, shutdown_trainer
                self.logger.info("🔄 Flushing batch processor...")
                flush_training()
                shutdown_trainer()

                # Verify the load result
                try:
                    from core.vanna_llm_factory import create_vanna_instance
                    vn = create_vanna_instance()
                    training_data = vn.get_training_data()
                    if training_data is not None and not training_data.empty:
                        total_records = len(training_data)
                        self.logger.info(f"✅ Loaded {total_records} training records")
                        # Count records per data type
                        if 'training_data_type' in training_data.columns:
                            type_counts = training_data['training_data_type'].value_counts().to_dict()
                        else:
                            type_counts = {}
                    else:
                        total_records = 0
                        type_counts = {}
                        self.logger.warning("⚠️ Could not verify the training data load result")
                except Exception as e:
                    self.logger.warning(f"⚠️ Error while verifying training data: {e}")
                    total_records = 0
                    type_counts = {}

                # Record results
                self.workflow_state["completed_steps"].append("training_data_load")
                self.workflow_state["artifacts"]["training_data_load"] = {
                    "training_data_dir": training_data_dir,
                    "load_successful": True,
                    "total_records": total_records,
                    "data_type_counts": type_counts,
                    "duration": step_duration
                }
                self.workflow_state["statistics"]["step4_duration"] = step_duration
                self.logger.info(f"✅ Step 4 completed: training data loaded in {step_duration:.2f}s")
            else:
                raise RuntimeError("Training data load failed: no processable training files found")
        except Exception as e:
            self.workflow_state["failed_steps"].append("training_data_load")
            self.logger.error(f"❌ Step 4 failed: {str(e)}")
            raise

    async def _generate_final_report(self) -> Dict[str, Any]:
        """Generate the final workflow report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]

        # Determine the primary output file
        qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
        final_output_file = qs_artifacts.get("output_file", "")

        # Determine the final question count
        if "sql_validation" in self.workflow_state["artifacts"]:
            # If the validation step ran, use the post-validation count
            validation_artifacts = self.workflow_state["artifacts"]["sql_validation"]
            final_question_count = validation_artifacts.get("valid_sql_count", 0)
        else:
            # Otherwise fall back to the generated count
            final_question_count = qs_artifacts.get("total_questions", 0)

        report = {
            "success": True,
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "total_steps": len(self.workflow_state["completed_steps"]),
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat(),
                "workflow_completed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat()
            },
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir),
                "enable_sql_validation": self.enable_sql_validation,
                "enable_llm_repair": self.enable_llm_repair,
                "modify_original_file": self.modify_original_file,
                "enable_training_data_load": self.enable_training_data_load
            },
            "processing_results": {
                "ddl_md_generation": self.workflow_state["artifacts"].get("ddl_md_generation", {}),
                "question_sql_generation": self.workflow_state["artifacts"].get("question_sql_generation", {}),
                "sql_validation": self.workflow_state["artifacts"].get("sql_validation", {}),
                "training_data_load": self.workflow_state["artifacts"].get("training_data_load", {})
            },
            "final_outputs": {
                "primary_output_file": final_output_file,
                "output_directory": str(self.output_dir),
                "final_question_count": final_question_count,
                "backup_files_created": self.modify_original_file
            },
            "performance_metrics": {
                "step1_duration": round(self.workflow_state["statistics"].get("step1_duration", 0), 2),
                "step2_duration": round(self.workflow_state["statistics"].get("step2_duration", 0), 2),
                "step3_duration": round(self.workflow_state["statistics"].get("step3_duration", 0), 2),
                "step4_duration": round(self.workflow_state["statistics"].get("step4_duration", 0), 2),
                "total_duration": round(total_duration, 2)
            }
        }
        return report

    async def _generate_error_report(self, error: Exception) -> Dict[str, Any]:
        """Generate an error report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]
        return {
            "success": False,
            "error": {
                "message": str(error),
                "type": type(error).__name__,
                "failed_step": self.workflow_state["current_step"]
            },
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat() if self.workflow_state["start_time"] else None,
                "workflow_failed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat() if self.workflow_state["end_time"] else None
            },
            "partial_results": self.workflow_state["artifacts"],
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir)
            }
        }

    def _mask_connection_string(self, conn_str: str) -> str:
        """Mask sensitive information in the connection string."""
        import re
        return re.sub(r':[^:@]+@', ':***@', conn_str)
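
    # For example (illustrative values):
    #   _mask_connection_string("postgresql://user:secret@localhost:5432/db")
    #   returns "postgresql://user:***@localhost:5432/db"
    # The pattern masks only the password segment between ':' and '@'.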

    def print_final_summary(self, report: Dict[str, Any]):
        """Print the final summary."""
        self.logger.info("=" * 80)
        self.logger.info("📊 Workflow execution summary")
        self.logger.info("=" * 80)

        if report["success"]:
            summary = report["workflow_summary"]
            results = report["processing_results"]
            outputs = report["final_outputs"]
            metrics = report["performance_metrics"]

            self.logger.info("✅ Workflow completed successfully")
            self.logger.info(f"⏱️ Total duration: {summary['total_duration']}s")
            self.logger.info(f"📝 Completed steps: {len(summary['completed_steps'])}/{summary['total_steps']}")

            # DDL/MD generation results
            if "ddl_md_generation" in results:
                ddl_md = results["ddl_md_generation"]
                self.logger.info(f"📋 DDL/MD generation: {ddl_md.get('processed_successfully', 0)} tables processed successfully")

            # Question-SQL generation results
            if "question_sql_generation" in results:
                qs = results["question_sql_generation"]
                self.logger.info(f"🤖 Question-SQL generation: {qs.get('total_questions', 0)} question-SQL pairs")

            # SQL validation results
            if "sql_validation" in results:
                validation = results["sql_validation"]
                success_rate = validation.get('success_rate', 0)
                self.logger.info(
                    f"🔍 SQL validation: {success_rate:.1%} success rate "
                    f"({validation.get('valid_sql_count', 0)}/{validation.get('original_sql_count', 0)})"
                )

            self.logger.info(f"📁 Output directory: {outputs['output_directory']}")
            self.logger.info(f"📄 Primary output file: {outputs['primary_output_file']}")
            self.logger.info(f"❓ Final question count: {outputs['final_question_count']}")
        else:
            error = report["error"]
            summary = report["workflow_summary"]

            self.logger.error("❌ Workflow execution failed")
            self.logger.error(f"💥 Failure reason: {error['message']}")
            self.logger.error(f"💥 Failed step: {error['failed_step']}")
            self.logger.error(f"⏱️ Elapsed time: {summary['total_duration']}s")
            self.logger.error(f"✅ Completed steps: {', '.join(summary['completed_steps']) if summary['completed_steps'] else 'none'}")

        self.logger.info("=" * 80)
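
# A minimal sketch of driving the orchestrator programmatically instead of via
# the CLI below; the connection string, table list file, and business context
# are placeholder values.
async def example_run() -> Dict[str, Any]:
    """Sketch: run the full workflow from code with placeholder inputs."""
    orchestrator = SchemaWorkflowOrchestrator(
        db_connection="postgresql://user:pass@localhost:5432/example_db",
        table_list_file="tables.txt",
        business_context="example business domain",
        output_dir="./output"
    )
    report = await orchestrator.execute_complete_workflow()
    orchestrator.print_final_summary(report)
    return report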

# Convenience command-line interface
def setup_argument_parser():
    """Set up the command-line argument parser."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Schema workflow orchestrator - end-to-end schema processing pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full workflow
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/highway_db" \\
    --table-list tables.txt \\
    --business-context "Highway service area management system" \\
    --output-dir ./data_pipeline/training_data/

  # Skip SQL validation
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/ecommerce_db" \\
    --table-list tables.txt \\
    --business-context "E-commerce system" \\
    --skip-validation

  # Disable LLM repair
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
    --table-list tables.txt \\
    --business-context "Management system" \\
    --disable-llm-repair

  # Skip training data load
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
    --table-list tables.txt \\
    --business-context "Management system" \\
    --skip-training-load
        """
    )

    # Required arguments
    parser.add_argument(
        "--db-connection",
        required=True,
        help="Database connection string (postgresql://user:pass@host:port/dbname)"
    )
    parser.add_argument(
        "--table-list",
        required=True,
        help="Path to the table list file"
    )
    parser.add_argument(
        "--business-context",
        required=True,
        help="Business context description"
    )

    # Optional arguments
    parser.add_argument(
        "--output-dir",
        default="./data_pipeline/training_data/",
        help="Output directory (default: ./data_pipeline/training_data/)"
    )
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip the SQL validation step"
    )
    parser.add_argument(
        "--disable-llm-repair",
        action="store_true",
        help="Disable LLM-based SQL repair"
    )
    parser.add_argument(
        "--no-modify-file",
        action="store_true",
        help="Do not modify the original JSON file (report only)"
    )
    parser.add_argument(
        "--skip-training-load",
        action="store_true",
        help="Skip the training data load step"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )
    parser.add_argument(
        "--log-file",
        help="Log file path"
    )

    return parser

async def main():
    """Command-line entry point."""
    import sys
    import os

    parser = setup_argument_parser()
    args = parser.parse_args()

    # Configure logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file
    )

    # Validate input files
    if not os.path.exists(args.table_list):
        print(f"Error: table list file not found: {args.table_list}")
        sys.exit(1)

    try:
        # Create and run the workflow orchestrator
        orchestrator = SchemaWorkflowOrchestrator(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=args.business_context,
            output_dir=args.output_dir,
            enable_sql_validation=not args.skip_validation,
            enable_llm_repair=not args.disable_llm_repair,
            modify_original_file=not args.no_modify_file,
            enable_training_data_load=not args.skip_training_load
        )

        # Print startup information
        print("🚀 Starting schema workflow orchestration...")
        print(f"📁 Output directory: {args.output_dir}")
        print(f"📋 Table list: {args.table_list}")
        print(f"🏢 Business context: {args.business_context}")
        print(f"💾 Database: {orchestrator.db_name}")
        print(f"🔍 SQL validation: {'enabled' if not args.skip_validation else 'disabled'}")
        print(f"🔧 LLM repair: {'enabled' if not args.disable_llm_repair else 'disabled'}")
        print(f"🎯 Training data load: {'enabled' if not args.skip_training_load else 'disabled'}")

        # Run the complete workflow
        report = await orchestrator.execute_complete_workflow()

        # Print the detailed summary
        orchestrator.print_final_summary(report)

        # Report the outcome and set the exit code
        if report["success"]:
            if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
                print("\n🎉 Workflow completed successfully!")
                exit_code = 0  # full success
            else:
                print("\n⚠️ Workflow completed, but the SQL validation success rate is low")
                exit_code = 1  # partial success
        else:
            print("\n❌ Workflow execution failed")
            exit_code = 2  # failure

        # Guard the lookup: error reports do not include "final_outputs"
        final_outputs = report.get("final_outputs", {})
        if final_outputs.get("primary_output_file"):
            print(f"📄 Primary output file: {final_outputs['primary_output_file']}")
        sys.exit(exit_code)
    except KeyboardInterrupt:
        print("\n\n⏹️ Interrupted by user, exiting")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ Program execution failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    asyncio.run(main())