- """
- Schema工作流编排器
- 统一管理完整的数据库Schema处理流程
- """
- import asyncio
- import time
- import logging
- from typing import Dict, Any, List, Optional
- from pathlib import Path
- from datetime import datetime
- from schema_tools.training_data_agent import SchemaTrainingDataAgent
- from schema_tools.qs_agent import QuestionSQLGenerationAgent
- from schema_tools.sql_validation_agent import SQLValidationAgent
- from schema_tools.config import SCHEMA_TOOLS_CONFIG
- from schema_tools.utils.logger import setup_logging
class SchemaWorkflowOrchestrator:
    """End-to-end schema processing orchestrator covering the complete workflow."""
    
    def __init__(self,
                 db_connection: str,
                 table_list_file: str,
                 business_context: str,
                 db_name: str,
                 output_dir: Optional[str] = None,
                 enable_sql_validation: bool = True,
                 enable_llm_repair: bool = True,
                 modify_original_file: bool = True):
        """
        Initialize the schema workflow orchestrator.
        
        Args:
            db_connection: Database connection string
            table_list_file: Path to the table list file
            business_context: Business context description
            db_name: Database name (used when generating file names)
            output_dir: Output directory
            enable_sql_validation: Whether to run the SQL validation step
            enable_llm_repair: Whether to enable LLM-based SQL repair
            modify_original_file: Whether to modify the original JSON file in place
        """
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context
        self.db_name = db_name
        self.output_dir = Path(output_dir) if output_dir else Path("./output")
        self.enable_sql_validation = enable_sql_validation
        self.enable_llm_repair = enable_llm_repair
        self.modify_original_file = modify_original_file
        
        # Make sure the output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Initialize logging
        self.logger = logging.getLogger("schema_tools.SchemaWorkflowOrchestrator")
        
        # Workflow state
        self.workflow_state = {
            "start_time": None,
            "end_time": None,
            "current_step": None,
            "completed_steps": [],
            "failed_steps": [],
            "artifacts": {},  # files and stats produced by each step
            "statistics": {}
        }
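        # Each step fills "artifacts" with its own summary; for example, step 1 stores
        # a dict of the following shape (values illustrative):
        #   "ddl_md_generation": {"total_tables": 12, "processed_successfully": 11,
        #                         "failed": 1, "files_generated": 22, "duration": 85.3}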
    
    async def execute_complete_workflow(self) -> Dict[str, Any]:
        """
        Execute the complete schema processing workflow.
        
        Returns:
            The full workflow report
        """
        self.workflow_state["start_time"] = time.time()
        self.logger.info("🚀 Starting schema workflow orchestration")
        self.logger.info(f"📁 Output directory: {self.output_dir}")
        self.logger.info(f"🏢 Business context: {self.business_context}")
        self.logger.info(f"💾 Database: {self.db_name}")
        
        try:
            # Step 1: generate DDL and MD files
            await self._execute_step_1_ddl_md_generation()
            
            # Step 2: generate Question-SQL pairs
            await self._execute_step_2_question_sql_generation()
            
            # Step 3: validate and repair SQL (optional)
            if self.enable_sql_validation:
                await self._execute_step_3_sql_validation()
            else:
                self.logger.info("⏭️ Skipping SQL validation step")
            
            # Build the final report
            final_report = await self._generate_final_report()
            
            self.workflow_state["end_time"] = time.time()
            self.logger.info("✅ Schema workflow orchestration completed")
            
            return final_report
            
        except Exception as e:
            self.workflow_state["end_time"] = time.time()
            self.logger.exception(f"❌ Workflow execution failed: {str(e)}")
            
            error_report = await self._generate_error_report(e)
            return error_report
    
    async def _execute_step_1_ddl_md_generation(self):
        """Step 1: generate DDL and MD files."""
        self.workflow_state["current_step"] = "ddl_md_generation"
        self.logger.info("=" * 60)
        self.logger.info("📝 Step 1: generating DDL and MD files")
        self.logger.info("=" * 60)
        
        step_start_time = time.time()
        
        try:
            # Create the DDL/MD generation agent
            ddl_md_agent = SchemaTrainingDataAgent(
                db_connection=self.db_connection,
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                output_dir=str(self.output_dir),
                pipeline="full"
            )
            
            # Run DDL/MD generation
            ddl_md_result = await ddl_md_agent.generate_training_data()
            
            step_duration = time.time() - step_start_time
            
            # Record the results
            self.workflow_state["completed_steps"].append("ddl_md_generation")
            self.workflow_state["artifacts"]["ddl_md_generation"] = {
                "total_tables": ddl_md_result.get("summary", {}).get("total_tables", 0),
                "processed_successfully": ddl_md_result.get("summary", {}).get("processed_successfully", 0),
                "failed": ddl_md_result.get("summary", {}).get("failed", 0),
                "files_generated": ddl_md_result.get("statistics", {}).get("files_generated", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step1_duration"] = step_duration
            
            processed_tables = ddl_md_result.get("summary", {}).get("processed_successfully", 0)
            self.logger.info(f"✅ Step 1 completed: {processed_tables} tables processed successfully in {step_duration:.2f}s")
            
        except Exception as e:
            self.workflow_state["failed_steps"].append("ddl_md_generation")
            self.logger.error(f"❌ Step 1 failed: {str(e)}")
            raise
    
    async def _execute_step_2_question_sql_generation(self):
        """Step 2: generate Question-SQL pairs."""
        self.workflow_state["current_step"] = "question_sql_generation"
        self.logger.info("=" * 60)
        self.logger.info("🤖 Step 2: generating Question-SQL pairs")
        self.logger.info("=" * 60)
        
        step_start_time = time.time()
        
        try:
            # Create the Question-SQL generation agent
            qs_agent = QuestionSQLGenerationAgent(
                output_dir=str(self.output_dir),
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                db_name=self.db_name
            )
            
            # Run Question-SQL generation
            qs_result = await qs_agent.generate()
            
            step_duration = time.time() - step_start_time
            
            # Record the results
            self.workflow_state["completed_steps"].append("question_sql_generation")
            self.workflow_state["artifacts"]["question_sql_generation"] = {
                "output_file": str(qs_result.get("output_file", "")),
                "total_questions": qs_result.get("total_questions", 0),
                "total_themes": qs_result.get("total_themes", 0),
                "successful_themes": qs_result.get("successful_themes", 0),
                "failed_themes": qs_result.get("failed_themes", []),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step2_duration"] = step_duration
            
            total_questions = qs_result.get("total_questions", 0)
            self.logger.info(f"✅ Step 2 completed: generated {total_questions} question-SQL pairs in {step_duration:.2f}s")
            
        except Exception as e:
            self.workflow_state["failed_steps"].append("question_sql_generation")
            self.logger.error(f"❌ Step 2 failed: {str(e)}")
            raise
    
    async def _execute_step_3_sql_validation(self):
        """Step 3: validate and repair SQL."""
        self.workflow_state["current_step"] = "sql_validation"
        self.logger.info("=" * 60)
        self.logger.info("🔍 Step 3: validating and repairing SQL")
        self.logger.info("=" * 60)
        
        step_start_time = time.time()
        
        try:
            # Locate the file produced by step 2
            qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
            qs_file = qs_artifacts.get("output_file")
            
            if not qs_file or not Path(qs_file).exists():
                raise FileNotFoundError(f"Question-SQL file not found: {qs_file}")
            
            self.logger.info(f"📄 Validating file: {qs_file}")
            
            # Apply the validation configuration dynamically
            SCHEMA_TOOLS_CONFIG['sql_validation']['enable_sql_repair'] = self.enable_llm_repair
            SCHEMA_TOOLS_CONFIG['sql_validation']['modify_original_file'] = self.modify_original_file
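            # Note: SCHEMA_TOOLS_CONFIG is a shared module-level dict, so these two
            # assignments also affect any other code that reads it in the same process.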
            
            # Create the SQL validation agent
            sql_validator = SQLValidationAgent(
                db_connection=self.db_connection,
                input_file=str(qs_file),
                output_dir=str(self.output_dir)
            )
            
            # Run SQL validation and repair
            validation_result = await sql_validator.validate()
            
            step_duration = time.time() - step_start_time
            
            # Record the results
            self.workflow_state["completed_steps"].append("sql_validation")
            
            summary = validation_result.get("summary", {})
            self.workflow_state["artifacts"]["sql_validation"] = {
                "original_sql_count": summary.get("total_questions", 0),
                "valid_sql_count": summary.get("valid_sqls", 0),
                "invalid_sql_count": summary.get("invalid_sqls", 0),
                "success_rate": summary.get("success_rate", 0),
                "repair_stats": summary.get("repair_stats", {}),
                "file_modification_stats": summary.get("file_modification_stats", {}),
                "average_execution_time": summary.get("average_execution_time", 0),
                "total_retries": summary.get("total_retries", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step3_duration"] = step_duration
            
            success_rate = summary.get("success_rate", 0)
            valid_count = summary.get("valid_sqls", 0)
            total_count = summary.get("total_questions", 0)
            
            self.logger.info(f"✅ Step 3 completed: SQL validation success rate {success_rate:.1%} ({valid_count}/{total_count}) in {step_duration:.2f}s")
            
            # Report repair statistics
            repair_stats = summary.get("repair_stats", {})
            if repair_stats.get("attempted", 0) > 0:
                self.logger.info(f"🔧 Repair stats: attempted {repair_stats['attempted']}, successful {repair_stats['successful']}, failed {repair_stats['failed']}")
            
            # Report file modification statistics
            file_stats = summary.get("file_modification_stats", {})
            if file_stats.get("modified", 0) > 0 or file_stats.get("deleted", 0) > 0:
                self.logger.info(f"📝 File modifications: updated {file_stats.get('modified', 0)} SQLs, deleted {file_stats.get('deleted', 0)} invalid entries")
            
        except Exception as e:
            self.workflow_state["failed_steps"].append("sql_validation")
            self.logger.error(f"❌ Step 3 failed: {str(e)}")
            raise
    
    async def _generate_final_report(self) -> Dict[str, Any]:
        """Build the final workflow report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]
        
        # Determine the final output file
        qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
        final_output_file = qs_artifacts.get("output_file", "")
        
        # Determine the final question count
        if "sql_validation" in self.workflow_state["artifacts"]:
            # If the validation step ran, use the post-validation count
            validation_artifacts = self.workflow_state["artifacts"]["sql_validation"]
            final_question_count = validation_artifacts.get("valid_sql_count", 0)
        else:
            # Otherwise use the generated count
            final_question_count = qs_artifacts.get("total_questions", 0)
        
        report = {
            "success": True,
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "total_steps": len(self.workflow_state["completed_steps"]),
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat(),
                "workflow_completed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat()
            },
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir),
                "enable_sql_validation": self.enable_sql_validation,
                "enable_llm_repair": self.enable_llm_repair,
                "modify_original_file": self.modify_original_file
            },
            "processing_results": {
                "ddl_md_generation": self.workflow_state["artifacts"].get("ddl_md_generation", {}),
                "question_sql_generation": self.workflow_state["artifacts"].get("question_sql_generation", {}),
                "sql_validation": self.workflow_state["artifacts"].get("sql_validation", {})
            },
            "final_outputs": {
                "primary_output_file": final_output_file,
                "output_directory": str(self.output_dir),
                "final_question_count": final_question_count,
                "backup_files_created": self.modify_original_file
            },
            "performance_metrics": {
                "step1_duration": round(self.workflow_state["statistics"].get("step1_duration", 0), 2),
                "step2_duration": round(self.workflow_state["statistics"].get("step2_duration", 0), 2),
                "step3_duration": round(self.workflow_state["statistics"].get("step3_duration", 0), 2),
                "total_duration": round(total_duration, 2)
            }
        }
        
        return report
    
    async def _generate_error_report(self, error: Exception) -> Dict[str, Any]:
        """Build the error report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]
        
        return {
            "success": False,
            "error": {
                "message": str(error),
                "type": type(error).__name__,
                "failed_step": self.workflow_state["current_step"]
            },
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat() if self.workflow_state["start_time"] else None,
                "workflow_failed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat() if self.workflow_state["end_time"] else None
            },
            "partial_results": self.workflow_state["artifacts"],
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir)
            }
        }
    
    def _mask_connection_string(self, conn_str: str) -> str:
        """Mask sensitive credentials in the connection string."""
        import re
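        # e.g. "postgresql://user:secret@host:5432/db" -> "postgresql://user:***@host:5432/db"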
        return re.sub(r':[^:@]+@', ':***@', conn_str)
    
    def print_final_summary(self, report: Dict[str, Any]):
        """Print the final workflow summary."""
        self.logger.info("=" * 80)
        self.logger.info("📊 Workflow execution summary")
        self.logger.info("=" * 80)
        
        if report["success"]:
            summary = report["workflow_summary"]
            results = report["processing_results"]
            outputs = report["final_outputs"]
            metrics = report["performance_metrics"]
            
            self.logger.info("✅ Workflow executed successfully")
            self.logger.info(f"⏱️ Total duration: {summary['total_duration']} seconds")
            self.logger.info(f"📝 Completed steps: {len(summary['completed_steps'])}/{summary['total_steps']}")
            
            # DDL/MD generation results
            if "ddl_md_generation" in results:
                ddl_md = results["ddl_md_generation"]
                self.logger.info(f"📋 DDL/MD generation: {ddl_md.get('processed_successfully', 0)} tables processed successfully")
            
            # Question-SQL generation results
            if "question_sql_generation" in results:
                qs = results["question_sql_generation"]
                self.logger.info(f"🤖 Question-SQL generation: {qs.get('total_questions', 0)} question-SQL pairs")
            
            # SQL validation results
            if "sql_validation" in results:
                validation = results["sql_validation"]
                success_rate = validation.get('success_rate', 0)
                self.logger.info(f"🔍 SQL validation: {success_rate:.1%} success rate ({validation.get('valid_sql_count', 0)}/{validation.get('original_sql_count', 0)})")
            
            self.logger.info(f"📁 Output directory: {outputs['output_directory']}")
            self.logger.info(f"📄 Primary output file: {outputs['primary_output_file']}")
            self.logger.info(f"❓ Final question count: {outputs['final_question_count']}")
            
        else:
            error = report["error"]
            summary = report["workflow_summary"]
            
            self.logger.error("❌ Workflow execution failed")
            self.logger.error(f"💥 Failure reason: {error['message']}")
            self.logger.error(f"💥 Failed step: {error['failed_step']}")
            self.logger.error(f"⏱️ Elapsed time: {summary['total_duration']} seconds")
            self.logger.error(f"✅ Completed steps: {', '.join(summary['completed_steps']) if summary['completed_steps'] else 'none'}")
        
        self.logger.info("=" * 80)
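

# Programmatic usage sketch (values below are illustrative, not a working configuration):
#
#     orchestrator = SchemaWorkflowOrchestrator(
#         db_connection="postgresql://user:pass@localhost:5432/dbname",
#         table_list_file="tables.txt",
#         business_context="Highway service area management system",
#         db_name="highway_db",
#         output_dir="./output",
#     )
#     report = asyncio.run(orchestrator.execute_complete_workflow())
#     orchestrator.print_final_summary(report)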

# Convenience command-line interface
def setup_argument_parser():
    """Build the command-line argument parser."""
    import argparse
    
    parser = argparse.ArgumentParser(
        description="Schema workflow orchestrator - end-to-end schema processing pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full workflow
  python -m schema_tools.schema_workflow_orchestrator \\
    --db-connection "postgresql://user:pass@localhost:5432/dbname" \\
    --table-list tables.txt \\
    --business-context "Highway service area management system" \\
    --db-name highway_db \\
    --output-dir ./output
  
  # Skip SQL validation
  python -m schema_tools.schema_workflow_orchestrator \\
    --db-connection "postgresql://user:pass@localhost:5432/dbname" \\
    --table-list tables.txt \\
    --business-context "E-commerce system" \\
    --db-name ecommerce_db \\
    --skip-validation
  
  # Disable LLM repair
  python -m schema_tools.schema_workflow_orchestrator \\
    --db-connection "postgresql://user:pass@localhost:5432/dbname" \\
    --table-list tables.txt \\
    --business-context "Management system" \\
    --db-name management_db \\
    --disable-llm-repair
        """
    )
    
    # Required arguments
    parser.add_argument(
        "--db-connection",
        required=True,
        help="Database connection string (postgresql://user:pass@host:port/dbname)"
    )
    
    parser.add_argument(
        "--table-list",
        required=True,
        help="Path to the table list file"
    )
    
    parser.add_argument(
        "--business-context",
        required=True,
        help="Business context description"
    )
    
    parser.add_argument(
        "--db-name",
        required=True,
        help="Database name (used when generating file names)"
    )
    
    # Optional arguments
    parser.add_argument(
        "--output-dir",
        default="./output",
        help="Output directory (default: ./output)"
    )
    
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip the SQL validation step"
    )
    
    parser.add_argument(
        "--disable-llm-repair",
        action="store_true",
        help="Disable LLM-based SQL repair"
    )
    
    parser.add_argument(
        "--no-modify-file",
        action="store_true",
        help="Do not modify the original JSON file (only generate reports)"
    )
    
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )
    
    parser.add_argument(
        "--log-file",
        help="Log file path"
    )
    
    return parser
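

# Exit codes used by main(): 0 = full success, 1 = partial success or startup/unexpected
# error, 2 = workflow failure, 130 = interrupted by the user.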
async def main():
    """Command-line entry point."""
    import sys
    import os
    
    parser = setup_argument_parser()
    args = parser.parse_args()
    
    # Set up logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file,
        log_dir=os.path.join(args.output_dir, 'logs') if args.output_dir else None
    )
    
    # Validate input files
    if not os.path.exists(args.table_list):
        print(f"Error: table list file does not exist: {args.table_list}")
        sys.exit(1)
    
    try:
        # Create and run the workflow orchestrator
        orchestrator = SchemaWorkflowOrchestrator(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=args.business_context,
            db_name=args.db_name,
            output_dir=args.output_dir,
            enable_sql_validation=not args.skip_validation,
            enable_llm_repair=not args.disable_llm_repair,
            modify_original_file=not args.no_modify_file
        )
        
        # Print startup information
        print("🚀 Starting schema workflow orchestration...")
        print(f"📁 Output directory: {args.output_dir}")
        print(f"📋 Table list: {args.table_list}")
        print(f"🏢 Business context: {args.business_context}")
        print(f"💾 Database: {args.db_name}")
        print(f"🔍 SQL validation: {'enabled' if not args.skip_validation else 'disabled'}")
        print(f"🔧 LLM repair: {'enabled' if not args.disable_llm_repair else 'disabled'}")
        
        # Run the complete workflow
        report = await orchestrator.execute_complete_workflow()
        
        # Print the detailed summary
        orchestrator.print_final_summary(report)
        
        # Report the outcome and set the exit code
        if report["success"]:
            if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
                print("\n🎉 Workflow completed successfully!")
                exit_code = 0  # full success
            else:
                print("\n⚠️ Workflow completed, but the SQL validation success rate is low")
                exit_code = 1  # partial success
        else:
            print("\n❌ Workflow execution failed")
            exit_code = 2  # failure
        
        # The error report has no "final_outputs" section, so guard the lookup
        if report.get("final_outputs"):
            print(f"📄 Primary output file: {report['final_outputs']['primary_output_file']}")
        sys.exit(exit_code)
        
    except KeyboardInterrupt:
        print("\n\n⏹️ Interrupted by user, exiting")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ Program execution failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())