schema_workflow_orchestrator.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. """
  2. Schema工作流编排器
  3. 统一管理完整的数据库Schema处理流程
  4. """
  5. import asyncio
  6. import time
  7. import logging
  8. from typing import Dict, Any, List, Optional
  9. from pathlib import Path
  10. from datetime import datetime
  11. from schema_tools.training_data_agent import SchemaTrainingDataAgent
  12. from schema_tools.qs_agent import QuestionSQLGenerationAgent
  13. from schema_tools.sql_validation_agent import SQLValidationAgent
  14. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  15. from schema_tools.utils.logger import setup_logging
  16. class SchemaWorkflowOrchestrator:
  17. """端到端的Schema处理编排器 - 完整工作流程"""
  18. def __init__(self,
  19. db_connection: str,
  20. table_list_file: str,
  21. business_context: str,
  22. db_name: str,
  23. output_dir: str = None,
  24. enable_sql_validation: bool = True,
  25. enable_llm_repair: bool = True,
  26. modify_original_file: bool = True):
  27. """
  28. 初始化Schema工作流编排器
  29. Args:
  30. db_connection: 数据库连接字符串
  31. table_list_file: 表清单文件路径
  32. business_context: 业务上下文描述
  33. db_name: 数据库名称(用于生成文件名)
  34. output_dir: 输出目录
  35. enable_sql_validation: 是否启用SQL验证
  36. enable_llm_repair: 是否启用LLM修复功能
  37. modify_original_file: 是否修改原始JSON文件
  38. """
  39. self.db_connection = db_connection
  40. self.table_list_file = table_list_file
  41. self.business_context = business_context
  42. self.db_name = db_name
  43. self.output_dir = Path(output_dir) if output_dir else Path("./output")
  44. self.enable_sql_validation = enable_sql_validation
  45. self.enable_llm_repair = enable_llm_repair
  46. self.modify_original_file = modify_original_file
  47. # 确保输出目录存在
  48. self.output_dir.mkdir(parents=True, exist_ok=True)
  49. # 初始化日志
  50. self.logger = logging.getLogger("schema_tools.SchemaWorkflowOrchestrator")
  51. # 工作流程状态
  52. self.workflow_state = {
  53. "start_time": None,
  54. "end_time": None,
  55. "current_step": None,
  56. "completed_steps": [],
  57. "failed_steps": [],
  58. "artifacts": {}, # 存储各步骤产生的文件
  59. "statistics": {}
  60. }
  61. async def execute_complete_workflow(self) -> Dict[str, Any]:
  62. """
  63. 执行完整的Schema处理工作流程
  64. Returns:
  65. 完整的工作流程报告
  66. """
  67. self.workflow_state["start_time"] = time.time()
  68. self.logger.info("🚀 开始执行Schema工作流编排")
  69. self.logger.info(f"📁 输出目录: {self.output_dir}")
  70. self.logger.info(f"🏢 业务背景: {self.business_context}")
  71. self.logger.info(f"💾 数据库: {self.db_name}")
  72. try:
  73. # 步骤1: 生成DDL和MD文件
  74. await self._execute_step_1_ddl_md_generation()
  75. # 步骤2: 生成Question-SQL对
  76. await self._execute_step_2_question_sql_generation()
  77. # 步骤3: 验证和修正SQL(可选)
  78. if self.enable_sql_validation:
  79. await self._execute_step_3_sql_validation()
  80. else:
  81. self.logger.info("⏭️ 跳过SQL验证步骤")
  82. # 生成最终报告
  83. final_report = await self._generate_final_report()
  84. self.workflow_state["end_time"] = time.time()
  85. self.logger.info("✅ Schema工作流编排完成")
  86. return final_report
  87. except Exception as e:
  88. self.workflow_state["end_time"] = time.time()
  89. self.logger.exception(f"❌ 工作流程执行失败: {str(e)}")
  90. error_report = await self._generate_error_report(e)
  91. return error_report
  92. async def _execute_step_1_ddl_md_generation(self):
  93. """步骤1: 生成DDL和MD文件"""
  94. self.workflow_state["current_step"] = "ddl_md_generation"
  95. self.logger.info("=" * 60)
  96. self.logger.info("📝 步骤1: 开始生成DDL和MD文件")
  97. self.logger.info("=" * 60)
  98. step_start_time = time.time()
  99. try:
  100. # 创建DDL/MD生成Agent
  101. ddl_md_agent = SchemaTrainingDataAgent(
  102. db_connection=self.db_connection,
  103. table_list_file=self.table_list_file,
  104. business_context=self.business_context,
  105. output_dir=str(self.output_dir),
  106. pipeline="full"
  107. )
  108. # 执行DDL/MD生成
  109. ddl_md_result = await ddl_md_agent.generate_training_data()
  110. step_duration = time.time() - step_start_time
  111. # 记录结果
  112. self.workflow_state["completed_steps"].append("ddl_md_generation")
  113. self.workflow_state["artifacts"]["ddl_md_generation"] = {
  114. "total_tables": ddl_md_result.get("summary", {}).get("total_tables", 0),
  115. "processed_successfully": ddl_md_result.get("summary", {}).get("processed_successfully", 0),
  116. "failed": ddl_md_result.get("summary", {}).get("failed", 0),
  117. "files_generated": ddl_md_result.get("statistics", {}).get("files_generated", 0),
  118. "duration": step_duration
  119. }
  120. self.workflow_state["statistics"]["step1_duration"] = step_duration
  121. processed_tables = ddl_md_result.get("summary", {}).get("processed_successfully", 0)
  122. self.logger.info(f"✅ 步骤1完成: 成功处理 {processed_tables} 个表,耗时 {step_duration:.2f}秒")
  123. except Exception as e:
  124. self.workflow_state["failed_steps"].append("ddl_md_generation")
  125. self.logger.error(f"❌ 步骤1失败: {str(e)}")
  126. raise
  127. async def _execute_step_2_question_sql_generation(self):
  128. """步骤2: 生成Question-SQL对"""
  129. self.workflow_state["current_step"] = "question_sql_generation"
  130. self.logger.info("=" * 60)
  131. self.logger.info("🤖 步骤2: 开始生成Question-SQL对")
  132. self.logger.info("=" * 60)
  133. step_start_time = time.time()
  134. try:
  135. # 创建Question-SQL生成Agent
  136. qs_agent = QuestionSQLGenerationAgent(
  137. output_dir=str(self.output_dir),
  138. table_list_file=self.table_list_file,
  139. business_context=self.business_context,
  140. db_name=self.db_name
  141. )
  142. # 执行Question-SQL生成
  143. qs_result = await qs_agent.generate()
  144. step_duration = time.time() - step_start_time
  145. # 记录结果
  146. self.workflow_state["completed_steps"].append("question_sql_generation")
  147. self.workflow_state["artifacts"]["question_sql_generation"] = {
  148. "output_file": str(qs_result.get("output_file", "")),
  149. "total_questions": qs_result.get("total_questions", 0),
  150. "total_themes": qs_result.get("total_themes", 0),
  151. "successful_themes": qs_result.get("successful_themes", 0),
  152. "failed_themes": qs_result.get("failed_themes", []),
  153. "duration": step_duration
  154. }
  155. self.workflow_state["statistics"]["step2_duration"] = step_duration
  156. total_questions = qs_result.get("total_questions", 0)
  157. self.logger.info(f"✅ 步骤2完成: 生成了 {total_questions} 个问答对,耗时 {step_duration:.2f}秒")
  158. except Exception as e:
  159. self.workflow_state["failed_steps"].append("question_sql_generation")
  160. self.logger.error(f"❌ 步骤2失败: {str(e)}")
  161. raise
  162. async def _execute_step_3_sql_validation(self):
  163. """步骤3: 验证和修正SQL"""
  164. self.workflow_state["current_step"] = "sql_validation"
  165. self.logger.info("=" * 60)
  166. self.logger.info("🔍 步骤3: 开始验证和修正SQL")
  167. self.logger.info("=" * 60)
  168. step_start_time = time.time()
  169. try:
  170. # 获取步骤2生成的文件
  171. qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
  172. qs_file = qs_artifacts.get("output_file")
  173. if not qs_file or not Path(qs_file).exists():
  174. raise FileNotFoundError(f"找不到Question-SQL文件: {qs_file}")
  175. self.logger.info(f"📄 验证文件: {qs_file}")
  176. # 动态设置验证配置
  177. SCHEMA_TOOLS_CONFIG['sql_validation']['enable_sql_repair'] = self.enable_llm_repair
  178. SCHEMA_TOOLS_CONFIG['sql_validation']['modify_original_file'] = self.modify_original_file
  179. # 创建SQL验证Agent
  180. sql_validator = SQLValidationAgent(
  181. db_connection=self.db_connection,
  182. input_file=str(qs_file),
  183. output_dir=str(self.output_dir)
  184. )
  185. # 执行SQL验证和修正
  186. validation_result = await sql_validator.validate()
  187. step_duration = time.time() - step_start_time
  188. # 记录结果
  189. self.workflow_state["completed_steps"].append("sql_validation")
  190. summary = validation_result.get("summary", {})
  191. self.workflow_state["artifacts"]["sql_validation"] = {
  192. "original_sql_count": summary.get("total_questions", 0),
  193. "valid_sql_count": summary.get("valid_sqls", 0),
  194. "invalid_sql_count": summary.get("invalid_sqls", 0),
  195. "success_rate": summary.get("success_rate", 0),
  196. "repair_stats": summary.get("repair_stats", {}),
  197. "file_modification_stats": summary.get("file_modification_stats", {}),
  198. "average_execution_time": summary.get("average_execution_time", 0),
  199. "total_retries": summary.get("total_retries", 0),
  200. "duration": step_duration
  201. }
  202. self.workflow_state["statistics"]["step3_duration"] = step_duration
  203. success_rate = summary.get("success_rate", 0)
  204. valid_count = summary.get("valid_sqls", 0)
  205. total_count = summary.get("total_questions", 0)
  206. self.logger.info(f"✅ 步骤3完成: SQL验证成功率 {success_rate:.1%} ({valid_count}/{total_count}),耗时 {step_duration:.2f}秒")
  207. # 显示修复统计
  208. repair_stats = summary.get("repair_stats", {})
  209. if repair_stats.get("attempted", 0) > 0:
  210. self.logger.info(f"🔧 修复统计: 尝试 {repair_stats['attempted']},成功 {repair_stats['successful']},失败 {repair_stats['failed']}")
  211. # 显示文件修改统计
  212. file_stats = summary.get("file_modification_stats", {})
  213. if file_stats.get("modified", 0) > 0 or file_stats.get("deleted", 0) > 0:
  214. self.logger.info(f"📝 文件修改: 更新 {file_stats.get('modified', 0)} 个SQL,删除 {file_stats.get('deleted', 0)} 个无效项")
  215. except Exception as e:
  216. self.workflow_state["failed_steps"].append("sql_validation")
  217. self.logger.error(f"❌ 步骤3失败: {str(e)}")
  218. raise
  219. async def _generate_final_report(self) -> Dict[str, Any]:
  220. """生成最终工作流程报告"""
  221. total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]
  222. # 计算最终输出文件
  223. qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
  224. final_output_file = qs_artifacts.get("output_file", "")
  225. # 计算最终问题数量
  226. if "sql_validation" in self.workflow_state["artifacts"]:
  227. # 如果有验证步骤,使用验证后的数量
  228. validation_artifacts = self.workflow_state["artifacts"]["sql_validation"]
  229. final_question_count = validation_artifacts.get("valid_sql_count", 0)
  230. else:
  231. # 否则使用生成的数量
  232. final_question_count = qs_artifacts.get("total_questions", 0)
  233. report = {
  234. "success": True,
  235. "workflow_summary": {
  236. "total_duration": round(total_duration, 2),
  237. "completed_steps": self.workflow_state["completed_steps"],
  238. "failed_steps": self.workflow_state["failed_steps"],
  239. "total_steps": len(self.workflow_state["completed_steps"]),
  240. "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat(),
  241. "workflow_completed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat()
  242. },
  243. "input_parameters": {
  244. "db_connection": self._mask_connection_string(self.db_connection),
  245. "table_list_file": self.table_list_file,
  246. "business_context": self.business_context,
  247. "db_name": self.db_name,
  248. "output_directory": str(self.output_dir),
  249. "enable_sql_validation": self.enable_sql_validation,
  250. "enable_llm_repair": self.enable_llm_repair,
  251. "modify_original_file": self.modify_original_file
  252. },
  253. "processing_results": {
  254. "ddl_md_generation": self.workflow_state["artifacts"].get("ddl_md_generation", {}),
  255. "question_sql_generation": self.workflow_state["artifacts"].get("question_sql_generation", {}),
  256. "sql_validation": self.workflow_state["artifacts"].get("sql_validation", {})
  257. },
  258. "final_outputs": {
  259. "primary_output_file": final_output_file,
  260. "output_directory": str(self.output_dir),
  261. "final_question_count": final_question_count,
  262. "backup_files_created": self.modify_original_file
  263. },
  264. "performance_metrics": {
  265. "step1_duration": round(self.workflow_state["statistics"].get("step1_duration", 0), 2),
  266. "step2_duration": round(self.workflow_state["statistics"].get("step2_duration", 0), 2),
  267. "step3_duration": round(self.workflow_state["statistics"].get("step3_duration", 0), 2),
  268. "total_duration": round(total_duration, 2)
  269. }
  270. }
  271. return report
  272. async def _generate_error_report(self, error: Exception) -> Dict[str, Any]:
  273. """生成错误报告"""
  274. total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]
  275. return {
  276. "success": False,
  277. "error": {
  278. "message": str(error),
  279. "type": type(error).__name__,
  280. "failed_step": self.workflow_state["current_step"]
  281. },
  282. "workflow_summary": {
  283. "total_duration": round(total_duration, 2),
  284. "completed_steps": self.workflow_state["completed_steps"],
  285. "failed_steps": self.workflow_state["failed_steps"],
  286. "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat() if self.workflow_state["start_time"] else None,
  287. "workflow_failed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat() if self.workflow_state["end_time"] else None
  288. },
  289. "partial_results": self.workflow_state["artifacts"],
  290. "input_parameters": {
  291. "db_connection": self._mask_connection_string(self.db_connection),
  292. "table_list_file": self.table_list_file,
  293. "business_context": self.business_context,
  294. "db_name": self.db_name,
  295. "output_directory": str(self.output_dir)
  296. }
  297. }
  298. def _mask_connection_string(self, conn_str: str) -> str:
  299. """隐藏连接字符串中的敏感信息"""
  300. import re
  301. return re.sub(r':[^:@]+@', ':***@', conn_str)
  302. def print_final_summary(self, report: Dict[str, Any]):
  303. """打印最终摘要"""
  304. self.logger.info("=" * 80)
  305. self.logger.info("📊 工作流程执行摘要")
  306. self.logger.info("=" * 80)
  307. if report["success"]:
  308. summary = report["workflow_summary"]
  309. results = report["processing_results"]
  310. outputs = report["final_outputs"]
  311. metrics = report["performance_metrics"]
  312. self.logger.info(f"✅ 工作流程执行成功")
  313. self.logger.info(f"⏱️ 总耗时: {summary['total_duration']} 秒")
  314. self.logger.info(f"📝 完成步骤: {len(summary['completed_steps'])}/{summary['total_steps']}")
  315. # DDL/MD生成结果
  316. if "ddl_md_generation" in results:
  317. ddl_md = results["ddl_md_generation"]
  318. self.logger.info(f"📋 DDL/MD生成: {ddl_md.get('processed_successfully', 0)} 个表成功处理")
  319. # Question-SQL生成结果
  320. if "question_sql_generation" in results:
  321. qs = results["question_sql_generation"]
  322. self.logger.info(f"🤖 Question-SQL生成: {qs.get('total_questions', 0)} 个问答对")
  323. # SQL验证结果
  324. if "sql_validation" in results:
  325. validation = results["sql_validation"]
  326. success_rate = validation.get('success_rate', 0)
  327. self.logger.info(f"🔍 SQL验证: {success_rate:.1%} 成功率 ({validation.get('valid_sql_count', 0)}/{validation.get('original_sql_count', 0)})")
  328. self.logger.info(f"📁 输出目录: {outputs['output_directory']}")
  329. self.logger.info(f"📄 主要输出文件: {outputs['primary_output_file']}")
  330. self.logger.info(f"❓ 最终问题数量: {outputs['final_question_count']}")
  331. else:
  332. error = report["error"]
  333. summary = report["workflow_summary"]
  334. self.logger.error(f"❌ 工作流程执行失败")
  335. self.logger.error(f"💥 失败原因: {error['message']}")
  336. self.logger.error(f"💥 失败步骤: {error['failed_step']}")
  337. self.logger.error(f"⏱️ 执行耗时: {summary['total_duration']} 秒")
  338. self.logger.error(f"✅ 已完成步骤: {', '.join(summary['completed_steps']) if summary['completed_steps'] else '无'}")
  339. self.logger.info("=" * 80)
  340. # 便捷的命令行接口
  341. def setup_argument_parser():
  342. """设置命令行参数解析器"""
  343. import argparse
  344. parser = argparse.ArgumentParser(
  345. description="Schema工作流编排器 - 端到端的Schema处理流程",
  346. formatter_class=argparse.RawDescriptionHelpFormatter,
  347. epilog="""
  348. 示例用法:
  349. # 完整工作流程
  350. python -m schema_tools.schema_workflow_orchestrator \\
  351. --db-connection "postgresql://user:pass@localhost:5432/dbname" \\
  352. --table-list tables.txt \\
  353. --business-context "高速公路服务区管理系统" \\
  354. --db-name highway_db \\
  355. --output-dir ./output
  356. # 跳过SQL验证
  357. python -m schema_tools.schema_workflow_orchestrator \\
  358. --db-connection "postgresql://user:pass@localhost:5432/dbname" \\
  359. --table-list tables.txt \\
  360. --business-context "电商系统" \\
  361. --db-name ecommerce_db \\
  362. --skip-validation
  363. # 禁用LLM修复
  364. python -m schema_tools.schema_workflow_orchestrator \\
  365. --db-connection "postgresql://user:pass@localhost:5432/dbname" \\
  366. --table-list tables.txt \\
  367. --business-context "管理系统" \\
  368. --db-name management_db \\
  369. --disable-llm-repair
  370. """
  371. )
  372. # 必需参数
  373. parser.add_argument(
  374. "--db-connection",
  375. required=True,
  376. help="数据库连接字符串 (postgresql://user:pass@host:port/dbname)"
  377. )
  378. parser.add_argument(
  379. "--table-list",
  380. required=True,
  381. help="表清单文件路径"
  382. )
  383. parser.add_argument(
  384. "--business-context",
  385. required=True,
  386. help="业务上下文描述"
  387. )
  388. parser.add_argument(
  389. "--db-name",
  390. required=True,
  391. help="数据库名称(用于生成文件名)"
  392. )
  393. # 可选参数
  394. parser.add_argument(
  395. "--output-dir",
  396. default="./output",
  397. help="输出目录(默认:./output)"
  398. )
  399. parser.add_argument(
  400. "--skip-validation",
  401. action="store_true",
  402. help="跳过SQL验证步骤"
  403. )
  404. parser.add_argument(
  405. "--disable-llm-repair",
  406. action="store_true",
  407. help="禁用LLM修复功能"
  408. )
  409. parser.add_argument(
  410. "--no-modify-file",
  411. action="store_true",
  412. help="不修改原始JSON文件(仅生成报告)"
  413. )
  414. parser.add_argument(
  415. "--verbose", "-v",
  416. action="store_true",
  417. help="启用详细日志输出"
  418. )
  419. parser.add_argument(
  420. "--log-file",
  421. help="日志文件路径"
  422. )
  423. return parser
  424. async def main():
  425. """命令行入口点"""
  426. import sys
  427. import os
  428. parser = setup_argument_parser()
  429. args = parser.parse_args()
  430. # 设置日志
  431. setup_logging(
  432. verbose=args.verbose,
  433. log_file=args.log_file,
  434. log_dir=os.path.join(args.output_dir, 'logs') if args.output_dir else None
  435. )
  436. # 验证输入文件
  437. if not os.path.exists(args.table_list):
  438. print(f"错误: 表清单文件不存在: {args.table_list}")
  439. sys.exit(1)
  440. try:
  441. # 创建并执行工作流编排器
  442. orchestrator = SchemaWorkflowOrchestrator(
  443. db_connection=args.db_connection,
  444. table_list_file=args.table_list,
  445. business_context=args.business_context,
  446. db_name=args.db_name,
  447. output_dir=args.output_dir,
  448. enable_sql_validation=not args.skip_validation,
  449. enable_llm_repair=not args.disable_llm_repair,
  450. modify_original_file=not args.no_modify_file
  451. )
  452. # 显示启动信息
  453. print(f"🚀 开始执行Schema工作流编排...")
  454. print(f"📁 输出目录: {args.output_dir}")
  455. print(f"📋 表清单: {args.table_list}")
  456. print(f"🏢 业务背景: {args.business_context}")
  457. print(f"💾 数据库: {args.db_name}")
  458. print(f"🔍 SQL验证: {'启用' if not args.skip_validation else '禁用'}")
  459. print(f"🔧 LLM修复: {'启用' if not args.disable_llm_repair else '禁用'}")
  460. # 执行完整工作流程
  461. report = await orchestrator.execute_complete_workflow()
  462. # 打印详细摘要
  463. orchestrator.print_final_summary(report)
  464. # 输出结果并设置退出码
  465. if report["success"]:
  466. if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
  467. print(f"\n🎉 工作流程执行成功!")
  468. exit_code = 0 # 完全成功
  469. else:
  470. print(f"\n⚠️ 工作流程执行完成,但SQL验证成功率较低")
  471. exit_code = 1 # 部分成功
  472. else:
  473. print(f"\n❌ 工作流程执行失败")
  474. exit_code = 2 # 失败
  475. print(f"📄 主要输出文件: {report['final_outputs']['primary_output_file']}")
  476. sys.exit(exit_code)
  477. except KeyboardInterrupt:
  478. print("\n\n⏹️ 用户中断,程序退出")
  479. sys.exit(130)
  480. except Exception as e:
  481. print(f"\n❌ 程序执行失败: {e}")
  482. if args.verbose:
  483. import traceback
  484. traceback.print_exc()
  485. sys.exit(1)
  486. if __name__ == "__main__":
  487. asyncio.run(main())