# schema_workflow.py

  1. """
  2. Schema工作流编排器
  3. 统一管理完整的数据库Schema处理流程
  4. """
  5. import asyncio
  6. import time
  7. import logging
  8. from typing import Dict, Any, List, Optional
  9. from pathlib import Path
  10. from datetime import datetime
  11. from data_pipeline.ddl_generation.training_data_agent import SchemaTrainingDataAgent
  12. from data_pipeline.qa_generation.qs_agent import QuestionSQLGenerationAgent
  13. from data_pipeline.validators.sql_validation_agent import SQLValidationAgent
  14. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  15. from data_pipeline.dp_logging import get_logger
  16. class SchemaWorkflowOrchestrator:
  17. """端到端的Schema处理编排器 - 完整工作流程"""
  18. def __init__(self,
  19. db_connection: str,
  20. table_list_file: str,
  21. business_context: str,
  22. output_dir: str = None,
  23. task_id: str = None,
  24. enable_sql_validation: bool = True,
  25. enable_llm_repair: bool = True,
  26. modify_original_file: bool = True,
  27. enable_training_data_load: bool = True):
  28. """
  29. 初始化Schema工作流编排器
  30. Args:
  31. db_connection: 数据库连接字符串 (postgresql://user:pass@host:port/dbname)
  32. table_list_file: 表清单文件路径
  33. business_context: 业务上下文描述
  34. output_dir: 输出目录
  35. task_id: 任务ID (API模式传递,脚本模式自动生成)
  36. enable_sql_validation: 是否启用SQL验证
  37. enable_llm_repair: 是否启用LLM修复功能
  38. modify_original_file: 是否修改原始JSON文件
  39. enable_training_data_load: 是否启用训练数据加载
  40. """
  41. self.db_connection = db_connection
  42. self.table_list_file = table_list_file
  43. self.business_context = business_context
  44. self.db_name = self._extract_db_name_from_connection(db_connection)
  45. self.enable_sql_validation = enable_sql_validation
  46. self.enable_llm_repair = enable_llm_repair
  47. self.modify_original_file = modify_original_file
  48. self.enable_training_data_load = enable_training_data_load
  49. # 处理task_id
  50. if task_id is None:
  51. # 脚本模式:自动生成manual开头的task_id
  52. self.task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
  53. else:
  54. # API模式:使用传递的task_id
  55. self.task_id = task_id
  56. # 设置输出目录
  57. if output_dir is None:
  58. # 脚本模式或未指定输出目录时,使用任务目录
  59. # 获取项目根目录的绝对路径
  60. project_root = Path(__file__).parent.parent
  61. self.output_dir = project_root / "data_pipeline" / "training_data" / self.task_id
  62. else:
  63. # API模式或明确指定输出目录时,使用指定的目录
  64. self.output_dir = Path(output_dir)
  65. # 确保输出目录存在
  66. self.output_dir.mkdir(parents=True, exist_ok=True)
  67. # 初始化独立日志系统
  68. self.logger = get_logger("SchemaWorkflowOrchestrator", self.task_id)
  69. # 工作流程状态
  70. self.workflow_state = {
  71. "start_time": None,
  72. "end_time": None,
  73. "current_step": None,
  74. "completed_steps": [],
  75. "failed_steps": [],
  76. "artifacts": {}, # 存储各步骤产生的文件
  77. "statistics": {}
  78. }
  79. def _extract_db_name_from_connection(self, connection_string: str) -> str:
  80. """
  81. 从数据库连接字符串中提取数据库名称
  82. Args:
  83. connection_string: PostgreSQL连接字符串
  84. Returns:
  85. str: 数据库名称
  86. """
  87. try:
  88. # 处理标准的PostgreSQL连接字符串: postgresql://user:pass@host:port/dbname
  89. if '/' in connection_string:
  90. # 取最后一个 '/' 后面的部分作为数据库名
  91. db_name = connection_string.split('/')[-1]
  92. # 移除可能的查询参数
  93. if '?' in db_name:
  94. db_name = db_name.split('?')[0]
  95. return db_name if db_name else "database"
  96. else:
  97. return "database"
  98. except Exception:
  99. return "database"
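
    # For illustration, the extraction above behaves like this (hypothetical
    # values, shown doctest-style):
    #
    #   _extract_db_name_from_connection("postgresql://u:p@localhost:5432/highway_db?sslmode=require")
    #       -> "highway_db"
    #   _extract_db_name_from_connection("not-a-connection-string")
    #       -> "database"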

    async def execute_complete_workflow(self) -> Dict[str, Any]:
        """
        Execute the complete schema processing workflow.

        Returns:
            The full workflow report
        """
        self.workflow_state["start_time"] = time.time()
        self.logger.info("🚀 Starting schema workflow orchestration")
        self.logger.info(f"📁 Output directory: {self.output_dir}")
        self.logger.info(f"🏢 Business context: {self.business_context}")
        self.logger.info(f"💾 Database: {self.db_name}")

        try:
            # Step 1: generate DDL and MD files
            await self._execute_step_1_ddl_md_generation()

            # Step 2: generate Question-SQL pairs
            await self._execute_step_2_question_sql_generation()

            # Step 3: validate and repair SQL (optional)
            if self.enable_sql_validation:
                await self._execute_step_3_sql_validation()
            else:
                self.logger.info("⏭️ Skipping SQL validation step")

            # Step 4: load training data (optional)
            if self.enable_training_data_load:
                await self._execute_step_4_training_data_load()
            else:
                self.logger.info("⏭️ Skipping training data load step")

            # Record the end time
            self.workflow_state["end_time"] = time.time()

            # Generate the final report
            final_report = await self._generate_final_report()
            self.logger.info("✅ Schema workflow orchestration completed")
            return final_report

        except Exception as e:
            self.workflow_state["end_time"] = time.time()
            self.logger.exception(f"❌ Workflow execution failed: {str(e)}")
            error_report = await self._generate_error_report(e)
            return error_report

    async def _execute_step_1_ddl_md_generation(self):
        """Step 1: generate DDL and MD files."""
        self.workflow_state["current_step"] = "ddl_md_generation"
        self.logger.info("=" * 60)
        self.logger.info("📝 Step 1: Generating DDL and MD files")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # Create the DDL/MD generation agent
            ddl_md_agent = SchemaTrainingDataAgent(
                db_connection=self.db_connection,
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                output_dir=str(self.output_dir),
                task_id=self.task_id,  # propagate task_id
                pipeline="full"
            )

            # Run DDL/MD generation
            ddl_md_result = await ddl_md_agent.generate_training_data()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("ddl_md_generation")
            self.workflow_state["artifacts"]["ddl_md_generation"] = {
                "total_tables": ddl_md_result.get("summary", {}).get("total_tables", 0),
                "processed_successfully": ddl_md_result.get("summary", {}).get("processed_successfully", 0),
                "failed": ddl_md_result.get("summary", {}).get("failed", 0),
                "files_generated": ddl_md_result.get("statistics", {}).get("files_generated", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step1_duration"] = step_duration

            processed_tables = ddl_md_result.get("summary", {}).get("processed_successfully", 0)
            self.logger.info(f"✅ Step 1 complete: {processed_tables} tables processed successfully in {step_duration:.2f}s")

        except Exception as e:
            self.workflow_state["failed_steps"].append("ddl_md_generation")
            self.logger.error(f"❌ Step 1 failed: {str(e)}")
            raise

    async def _execute_step_2_question_sql_generation(self):
        """Step 2: generate Question-SQL pairs."""
        self.workflow_state["current_step"] = "question_sql_generation"
        self.logger.info("=" * 60)
        self.logger.info("🤖 Step 2: Generating Question-SQL pairs")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # Create the Question-SQL generation agent
            qs_agent = QuestionSQLGenerationAgent(
                output_dir=str(self.output_dir),
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                db_name=self.db_name,
                task_id=self.task_id  # propagate task_id
            )

            # Run Question-SQL generation
            qs_result = await qs_agent.generate()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("question_sql_generation")
            self.workflow_state["artifacts"]["question_sql_generation"] = {
                "output_file": str(qs_result.get("output_file", "")),
                "total_questions": qs_result.get("total_questions", 0),
                "total_themes": qs_result.get("total_themes", 0),
                "successful_themes": qs_result.get("successful_themes", 0),
                "failed_themes": qs_result.get("failed_themes", []),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step2_duration"] = step_duration

            total_questions = qs_result.get("total_questions", 0)
            self.logger.info(f"✅ Step 2 complete: generated {total_questions} question-SQL pairs in {step_duration:.2f}s")

        except Exception as e:
            self.workflow_state["failed_steps"].append("question_sql_generation")
            self.logger.error(f"❌ Step 2 failed: {str(e)}")
            raise

    async def _execute_step_3_sql_validation(self):
        """Step 3: validate and repair SQL."""
        self.workflow_state["current_step"] = "sql_validation"
        self.logger.info("=" * 60)
        self.logger.info("🔍 Step 3: Validating and repairing SQL")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # Locate the file produced by step 2
            qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
            qs_file = qs_artifacts.get("output_file")
            if not qs_file or not Path(qs_file).exists():
                raise FileNotFoundError(f"Question-SQL file not found: {qs_file}")

            self.logger.info(f"📄 Validating file: {qs_file}")

            # Create the SQL validation agent; configuration is passed via
            # parameters rather than by mutating global config
            sql_validator = SQLValidationAgent(
                db_connection=self.db_connection,
                input_file=str(qs_file),
                output_dir=str(self.output_dir),
                task_id=self.task_id,  # propagate task_id
                enable_sql_repair=self.enable_llm_repair,
                modify_original_file=self.modify_original_file
            )

            # Run SQL validation and repair
            validation_result = await sql_validator.validate()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("sql_validation")
            summary = validation_result.get("summary", {})
            self.workflow_state["artifacts"]["sql_validation"] = {
                "original_sql_count": summary.get("total_questions", 0),
                "valid_sql_count": summary.get("valid_sqls", 0),
                "invalid_sql_count": summary.get("invalid_sqls", 0),
                "success_rate": summary.get("success_rate", 0),
                "repair_stats": summary.get("repair_stats", {}),
                "file_modification_stats": summary.get("file_modification_stats", {}),
                "average_execution_time": summary.get("average_execution_time", 0),
                "total_retries": summary.get("total_retries", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step3_duration"] = step_duration

            success_rate = summary.get("success_rate", 0)
            valid_count = summary.get("valid_sqls", 0)
            total_count = summary.get("total_questions", 0)
            self.logger.info(f"✅ Step 3 complete: SQL validation success rate {success_rate:.1%} ({valid_count}/{total_count}) in {step_duration:.2f}s")

            # Report repair statistics
            repair_stats = summary.get("repair_stats", {})
            if repair_stats.get("attempted", 0) > 0:
                self.logger.info(f"🔧 Repair stats: attempted {repair_stats['attempted']}, succeeded {repair_stats['successful']}, failed {repair_stats['failed']}")

            # Report file modification statistics
            file_stats = summary.get("file_modification_stats", {})
            if file_stats.get("modified", 0) > 0 or file_stats.get("deleted", 0) > 0:
                self.logger.info(f"📝 File modifications: updated {file_stats.get('modified', 0)} SQL statements, deleted {file_stats.get('deleted', 0)} invalid entries")

        except Exception as e:
            self.workflow_state["failed_steps"].append("sql_validation")
            self.logger.error(f"❌ Step 3 failed: {str(e)}")
            raise

    async def _execute_step_4_training_data_load(self):
        """Step 4: load training data."""
        self.workflow_state["current_step"] = "training_data_load"
        self.logger.info("=" * 60)
        self.logger.info("🎯 Step 4: Loading training data")
        self.logger.info("=" * 60)

        step_start_time = time.time()
        try:
            # The output directory holds the training data produced so far
            training_data_dir = str(self.output_dir)
            self.logger.info(f"📁 Training data directory: {training_data_dir}")

            # Import the trainer module
            import sys
            import os
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from data_pipeline.trainer.run_training import process_training_files

            # Load the training data
            self.logger.info("🔄 Processing training files...")
            load_successful = process_training_files(training_data_dir, self.task_id)
            step_duration = time.time() - step_start_time

            if load_successful:
                from data_pipeline.trainer.vanna_trainer import flush_training, shutdown_trainer

                # Flush the batch processor
                self.logger.info("🔄 Flushing the batch processor...")
                flush_training()
                shutdown_trainer()

                # Verify the load result
                try:
                    from core.vanna_llm_factory import create_vanna_instance
                    vn = create_vanna_instance()
                    training_data = vn.get_training_data()
                    if training_data is not None and not training_data.empty:
                        total_records = len(training_data)
                        self.logger.info(f"✅ Loaded {total_records} training records")
                        # Count records by data type
                        if 'training_data_type' in training_data.columns:
                            type_counts = training_data['training_data_type'].value_counts().to_dict()
                        else:
                            type_counts = {}
                    else:
                        total_records = 0
                        type_counts = {}
                        self.logger.warning("⚠️ Could not verify the training data load result")
                except Exception as e:
                    self.logger.warning(f"⚠️ Error while verifying training data: {e}")
                    total_records = 0
                    type_counts = {}

                # Record the results
                self.workflow_state["completed_steps"].append("training_data_load")
                self.workflow_state["artifacts"]["training_data_load"] = {
                    "training_data_dir": training_data_dir,
                    "load_successful": True,
                    "total_records": total_records,
                    "data_type_counts": type_counts,
                    "duration": step_duration
                }
                self.workflow_state["statistics"]["step4_duration"] = step_duration
                self.logger.info(f"✅ Step 4 complete: training data loaded in {step_duration:.2f}s")
            else:
                raise RuntimeError("Training data load failed: no processable training files found")

        except Exception as e:
            self.workflow_state["failed_steps"].append("training_data_load")
            self.logger.error(f"❌ Step 4 failed: {str(e)}")
            raise

    async def _generate_final_report(self) -> Dict[str, Any]:
        """Generate the final workflow report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]

        # Determine the final output file
        qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
        final_output_file = qs_artifacts.get("output_file", "")

        # Determine the final question count
        if "sql_validation" in self.workflow_state["artifacts"]:
            # If validation ran, use the post-validation count
            validation_artifacts = self.workflow_state["artifacts"]["sql_validation"]
            final_question_count = validation_artifacts.get("valid_sql_count", 0)
        else:
            # Otherwise use the generated count
            final_question_count = qs_artifacts.get("total_questions", 0)

        report = {
            "success": True,
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "total_steps": len(self.workflow_state["completed_steps"]),
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat(),
                "workflow_completed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat()
            },
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir),
                "enable_sql_validation": self.enable_sql_validation,
                "enable_llm_repair": self.enable_llm_repair,
                "modify_original_file": self.modify_original_file,
                "enable_training_data_load": self.enable_training_data_load
            },
            "processing_results": {
                "ddl_md_generation": self.workflow_state["artifacts"].get("ddl_md_generation", {}),
                "question_sql_generation": self.workflow_state["artifacts"].get("question_sql_generation", {}),
                "sql_validation": self.workflow_state["artifacts"].get("sql_validation", {}),
                "training_data_load": self.workflow_state["artifacts"].get("training_data_load", {})
            },
            "final_outputs": {
                "primary_output_file": final_output_file,
                "output_directory": str(self.output_dir),
                "final_question_count": final_question_count,
                "backup_files_created": self.modify_original_file
            },
            "performance_metrics": {
                "step1_duration": round(self.workflow_state["statistics"].get("step1_duration", 0), 2),
                "step2_duration": round(self.workflow_state["statistics"].get("step2_duration", 0), 2),
                "step3_duration": round(self.workflow_state["statistics"].get("step3_duration", 0), 2),
                "step4_duration": round(self.workflow_state["statistics"].get("step4_duration", 0), 2),
                "total_duration": round(total_duration, 2)
            }
        }

        return report

    async def _generate_error_report(self, error: Exception) -> Dict[str, Any]:
        """Generate an error report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]

        return {
            "success": False,
            "error": {
                "message": str(error),
                "type": type(error).__name__,
                "failed_step": self.workflow_state["current_step"]
            },
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat() if self.workflow_state["start_time"] else None,
                "workflow_failed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat() if self.workflow_state["end_time"] else None
            },
            "partial_results": self.workflow_state["artifacts"],
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir)
            }
        }

    def _mask_connection_string(self, conn_str: str) -> str:
        """Mask sensitive credentials in a connection string."""
        import re
        return re.sub(r':[^:@]+@', ':***@', conn_str)
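
    # For illustration (hypothetical value): the regex above replaces the
    # password segment between ':' and '@':
    #
    #   _mask_connection_string("postgresql://user:secret@localhost:5432/db")
    #       -> "postgresql://user:***@localhost:5432/db"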

    def print_final_summary(self, report: Dict[str, Any]):
        """Print the final summary."""
        self.logger.info("=" * 80)
        self.logger.info("📊 Workflow execution summary")
        self.logger.info("=" * 80)

        if report["success"]:
            summary = report["workflow_summary"]
            results = report["processing_results"]
            outputs = report["final_outputs"]
            metrics = report["performance_metrics"]

            self.logger.info("✅ Workflow completed successfully")
            self.logger.info(f"⏱️ Total time: {summary['total_duration']} seconds")
            self.logger.info(f"📝 Steps completed: {len(summary['completed_steps'])}/{summary['total_steps']}")

            # DDL/MD generation results
            if "ddl_md_generation" in results:
                ddl_md = results["ddl_md_generation"]
                self.logger.info(f"📋 DDL/MD generation: {ddl_md.get('processed_successfully', 0)} tables processed successfully")

            # Question-SQL generation results
            if "question_sql_generation" in results:
                qs = results["question_sql_generation"]
                self.logger.info(f"🤖 Question-SQL generation: {qs.get('total_questions', 0)} question-SQL pairs")

            # SQL validation results
            if "sql_validation" in results:
                validation = results["sql_validation"]
                success_rate = validation.get('success_rate', 0)
                self.logger.info(f"🔍 SQL validation: {success_rate:.1%} success rate ({validation.get('valid_sql_count', 0)}/{validation.get('original_sql_count', 0)})")

            self.logger.info(f"📁 Output directory: {outputs['output_directory']}")
            self.logger.info(f"📄 Primary output file: {outputs['primary_output_file']}")
            self.logger.info(f"❓ Final question count: {outputs['final_question_count']}")
        else:
            error = report["error"]
            summary = report["workflow_summary"]

            self.logger.error("❌ Workflow execution failed")
            self.logger.error(f"💥 Failure reason: {error['message']}")
            self.logger.error(f"💥 Failed step: {error['failed_step']}")
            self.logger.error(f"⏱️ Time elapsed: {summary['total_duration']} seconds")
            self.logger.error(f"✅ Completed steps: {', '.join(summary['completed_steps']) if summary['completed_steps'] else 'none'}")

        self.logger.info("=" * 80)
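

# Programmatic usage sketch (the connection string, table list path, and
# business context below are illustrative assumptions, not project defaults):
#
#   import asyncio
#
#   orchestrator = SchemaWorkflowOrchestrator(
#       db_connection="postgresql://user:pass@localhost:5432/highway_db",
#       table_list_file="tables.txt",
#       business_context="Highway service area management system",
#   )
#   report = asyncio.run(orchestrator.execute_complete_workflow())
#   orchestrator.print_final_summary(report)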


# Convenience command-line interface
def setup_argument_parser():
    """Build the command-line argument parser."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Schema workflow orchestrator - end-to-end schema processing pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full workflow
  python -m data_pipeline.schema_workflow \\
      --db-connection "postgresql://user:pass@localhost:5432/highway_db" \\
      --table-list tables.txt \\
      --business-context "Highway service area management system" \\
      --output-dir ./data_pipeline/training_data/

  # Skip SQL validation
  python -m data_pipeline.schema_workflow \\
      --db-connection "postgresql://user:pass@localhost:5432/ecommerce_db" \\
      --table-list tables.txt \\
      --business-context "E-commerce system" \\
      --skip-validation

  # Disable LLM repair
  python -m data_pipeline.schema_workflow \\
      --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
      --table-list tables.txt \\
      --business-context "Management system" \\
      --disable-llm-repair

  # Skip training data load
  python -m data_pipeline.schema_workflow \\
      --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
      --table-list tables.txt \\
      --business-context "Management system" \\
      --skip-training-load
        """
    )

    # Required arguments
    parser.add_argument(
        "--db-connection",
        required=True,
        help="Database connection string (postgresql://user:pass@host:port/dbname)"
    )
    parser.add_argument(
        "--table-list",
        required=True,
        help="Path to the table list file"
    )
    parser.add_argument(
        "--business-context",
        required=True,
        help="Business context description"
    )

    # Optional arguments
    parser.add_argument(
        "--output-dir",
        default="./data_pipeline/training_data/",
        help="Output directory (default: ./data_pipeline/training_data/)"
    )
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip the SQL validation step"
    )
    parser.add_argument(
        "--disable-llm-repair",
        action="store_true",
        help="Disable LLM-based repair"
    )
    parser.add_argument(
        "--no-modify-file",
        action="store_true",
        help="Do not modify the original JSON file (report only)"
    )
    parser.add_argument(
        "--skip-training-load",
        action="store_true",
        help="Skip the training data load step"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )
    parser.add_argument(
        "--log-file",
        help="Log file path"
    )

    return parser
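

# NOTE: main() below calls setup_logging(), which is not defined or imported in
# this file. The following minimal sketch (an assumption; the project may ship
# its own implementation elsewhere) makes the CLI entry point run as written.
def setup_logging(verbose: bool = False, log_file: Optional[str] = None):
    """Configure root logging for CLI runs (minimal assumed implementation)."""
    handlers = [logging.StreamHandler()]
    if log_file:
        # Also write logs to the given file
        handlers.append(logging.FileHandler(log_file, encoding="utf-8"))
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=handlers
    )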


async def main():
    """Command-line entry point."""
    import sys
    import os

    parser = setup_argument_parser()
    args = parser.parse_args()

    # Set up logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file
    )

    # Generate a task_id for script mode and create the standalone logger up
    # front, so it is available even if an error occurs before the orchestrator runs
    script_task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logger = get_logger("SchemaWorkflow", script_task_id)

    # Validate input files
    if not os.path.exists(args.table_list):
        logger.error(f"Error: table list file does not exist: {args.table_list}")
        sys.exit(1)

    try:
        # Create and run the workflow orchestrator
        orchestrator = SchemaWorkflowOrchestrator(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=args.business_context,
            output_dir=args.output_dir,
            enable_sql_validation=not args.skip_validation,
            enable_llm_repair=not args.disable_llm_repair,
            modify_original_file=not args.no_modify_file,
            enable_training_data_load=not args.skip_training_load
        )

        # Log startup information
        logger.info("🚀 Starting schema workflow orchestration...")
        logger.info(f"📁 Output directory: {args.output_dir}")
        logger.info(f"📋 Table list: {args.table_list}")
        logger.info(f"🏢 Business context: {args.business_context}")
        logger.info(f"💾 Database: {orchestrator.db_name}")
        logger.info(f"🔍 SQL validation: {'enabled' if not args.skip_validation else 'disabled'}")
        logger.info(f"🔧 LLM repair: {'enabled' if not args.disable_llm_repair else 'disabled'}")
        logger.info(f"🎯 Training data load: {'enabled' if not args.skip_training_load else 'disabled'}")

        # Run the complete workflow
        report = await orchestrator.execute_complete_workflow()

        # Print the detailed summary
        orchestrator.print_final_summary(report)

        # Report the outcome and set the exit code
        if report["success"]:
            if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
                logger.info("\n🎉 Workflow completed successfully!")
                exit_code = 0  # full success
            else:
                logger.warning("\n⚠️ Workflow completed, but the SQL validation success rate is low")
                exit_code = 1  # partial success
        else:
            logger.error("\n❌ Workflow execution failed")
            exit_code = 2  # failure

        logger.info(f"📄 Primary output file: {report['final_outputs']['primary_output_file']}")
        sys.exit(exit_code)

    except KeyboardInterrupt:
        logger.info("\n\n⏹️ Interrupted by user, exiting")
        sys.exit(130)
    except Exception as e:
        logger.error(f"\n❌ Program execution failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())