schema_workflow.py

"""
Schema workflow orchestrator.

Centrally manages the complete database schema processing pipeline.
"""
import asyncio
import time
import logging
from typing import Dict, Any, List, Optional
from pathlib import Path
from datetime import datetime

from data_pipeline.ddl_generation.training_data_agent import SchemaTrainingDataAgent
from data_pipeline.qa_generation.qs_agent import QuestionSQLGenerationAgent
from data_pipeline.validators.sql_validation_agent import SQLValidationAgent
from data_pipeline.config import SCHEMA_TOOLS_CONFIG
from data_pipeline.dp_logging import get_logger
from data_pipeline.utils.logger import setup_logging


class SchemaWorkflowOrchestrator:
    """End-to-end schema processing orchestrator covering the complete workflow."""

    def __init__(self,
                 db_connection: str,
                 table_list_file: str,
                 business_context: str,
                 output_dir: str = None,
                 task_id: str = None,
                 enable_sql_validation: bool = True,
                 enable_llm_repair: bool = True,
                 modify_original_file: bool = True,
                 enable_training_data_load: bool = True):
        """
        Initialize the schema workflow orchestrator.

        Args:
            db_connection: database connection string (postgresql://user:pass@host:port/dbname)
            table_list_file: path to the table list file
            business_context: business context description
            output_dir: output directory
            task_id: task ID (passed in API mode, auto-generated in script mode)
            enable_sql_validation: whether to enable SQL validation
            enable_llm_repair: whether to enable LLM-based repair
            modify_original_file: whether to modify the original JSON file
            enable_training_data_load: whether to enable training data loading
        """
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context
        self.db_name = self._extract_db_name_from_connection(db_connection)
        self.enable_sql_validation = enable_sql_validation
        self.enable_llm_repair = enable_llm_repair
        self.modify_original_file = modify_original_file
        self.enable_training_data_load = enable_training_data_load

        # Resolve the task_id
        if task_id is None:
            # Script mode: auto-generate a task_id prefixed with "manual",
            # e.g. "manual_20250101_120000"
            self.task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        else:
            # API mode: use the task_id passed in
            self.task_id = task_id

        # Set up the output directory
        if output_dir is None:
            # Script mode, or no output directory given: use the default base
            # directory, resolved from the absolute path of the project root
            project_root = Path(__file__).parent.parent
            base_dir = project_root / "data_pipeline" / "training_data"
        else:
            # An output directory was given: use it as the base directory
            base_dir = Path(output_dir)

        # In either case, create a per-task subdirectory under the base directory
        self.output_dir = base_dir / self.task_id
        # Make sure the output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize the dedicated logging system
        self.logger = get_logger("SchemaWorkflowOrchestrator", self.task_id)

        # Workflow state
        self.workflow_state = {
            "start_time": None,
            "end_time": None,
            "current_step": None,
            "completed_steps": [],
            "failed_steps": [],
            "artifacts": {},  # files produced by each step
            "statistics": {}
        }
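
    # Illustrative on-disk layout after a full run (a sketch; exact file names
    # depend on the agents, but the glob patterns used below assume):
    #   <base_dir>/manual_20250101_120000/
    #       *.ddl          (step 1: DDL files)
    #       *.md           (step 1: MD files, plus metadata.md)
    #       *_pair.json    (step 2: Question-SQL pairs)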

    def _extract_db_name_from_connection(self, connection_string: str) -> str:
        """
        Extract the database name from a connection string.

        Args:
            connection_string: PostgreSQL connection string

        Returns:
            str: database name
        """
        try:
            # Handle a standard PostgreSQL connection string: postgresql://user:pass@host:port/dbname
            if '/' in connection_string:
                # Take the part after the last '/' as the database name
                db_name = connection_string.split('/')[-1]
                # Strip any query parameters
                if '?' in db_name:
                    db_name = db_name.split('?')[0]
                return db_name if db_name else "database"
            else:
                return "database"
        except Exception:
            return "database"
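
    # Illustrative behavior of the helper above (examples, not executed):
    #   _extract_db_name_from_connection("postgresql://u:p@localhost:5432/highway_db")
    #       -> "highway_db"
    #   _extract_db_name_from_connection("postgresql://u:p@localhost:5432/highway_db?sslmode=require")
    #       -> "highway_db"
    #   _extract_db_name_from_connection("not-a-connection-string")
    #       -> "database"   (fallback when no '/' is present)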

    async def execute_complete_workflow(self) -> Dict[str, Any]:
        """
        Execute the complete schema processing workflow.

        Returns:
            The complete workflow report.
        """
        self.workflow_state["start_time"] = time.time()
        self.logger.info("🚀 Starting schema workflow orchestration")
        self.logger.info(f"📁 Output directory: {self.output_dir}")
        self.logger.info(f"🏢 Business context: {self.business_context}")
        self.logger.info(f"💾 Database: {self.db_name}")

        try:
            # Step 1: generate DDL and MD files
            await self._execute_step_1_ddl_md_generation()

            # Step 2: generate Question-SQL pairs
            await self._execute_step_2_question_sql_generation()

            # Step 3: validate and repair SQL (optional)
            if self.enable_sql_validation:
                await self._execute_step_3_sql_validation()
            else:
                self.logger.info("⏭️ Skipping the SQL validation step")

            # Step 4: load training data (optional)
            if self.enable_training_data_load:
                await self._execute_step_4_training_data_load()
            else:
                self.logger.info("⏭️ Skipping the training data loading step")

            # Record the end time
            self.workflow_state["end_time"] = time.time()

            # Generate the final report
            final_report = await self._generate_final_report()
            self.logger.info("✅ Schema workflow orchestration completed")
            return final_report

        except Exception as e:
            self.workflow_state["end_time"] = time.time()
            self.logger.exception(f"❌ Workflow execution failed: {str(e)}")
            error_report = await self._generate_error_report(e)
            return error_report

    async def _execute_step_1_ddl_md_generation(self):
        """Step 1: generate DDL and MD files."""
        self.workflow_state["current_step"] = "ddl_md_generation"
        self.logger.info("=" * 60)
        self.logger.info("📝 Step 1: generating DDL and MD files")
        self.logger.info("=" * 60)
        step_start_time = time.time()

        try:
            # Create the DDL/MD generation agent
            ddl_md_agent = SchemaTrainingDataAgent(
                db_connection=self.db_connection,
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                output_dir=str(self.output_dir),
                task_id=self.task_id,  # propagate the task_id
                pipeline="full"
            )

            # Run DDL/MD generation
            ddl_md_result = await ddl_md_agent.generate_training_data()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("ddl_md_generation")
            self.workflow_state["artifacts"]["ddl_md_generation"] = {
                "total_tables": ddl_md_result.get("summary", {}).get("total_tables", 0),
                "processed_successfully": ddl_md_result.get("summary", {}).get("processed_successfully", 0),
                "failed": ddl_md_result.get("summary", {}).get("failed", 0),
                "files_generated": ddl_md_result.get("statistics", {}).get("total_files_generated", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step1_duration"] = step_duration

            processed_tables = ddl_md_result.get("summary", {}).get("processed_successfully", 0)

            # Collect file statistics
            statistics = ddl_md_result.get("statistics", {})
            md_files = statistics.get("md_files_generated", 0)
            ddl_files = statistics.get("ddl_files_generated", 0)
            if md_files > 0 and ddl_files > 0:
                file_info = f"generated {md_files} MD files and {ddl_files} DDL files"
            elif md_files > 0:
                file_info = f"generated {md_files} MD files"
            elif ddl_files > 0:
                file_info = f"generated {ddl_files} DDL files"
            else:
                file_info = "no files generated"

            self.logger.info(f"✅ Step 1 completed: processed {processed_tables} tables successfully, {file_info}, took {step_duration:.2f}s")

        except Exception as e:
            self.workflow_state["failed_steps"].append("ddl_md_generation")
            self.logger.error(f"❌ Step 1 failed: {str(e)}")
            raise

    async def _execute_step_2_question_sql_generation(self):
        """Step 2: generate Question-SQL pairs."""
        self.workflow_state["current_step"] = "question_sql_generation"
        self.logger.info("=" * 60)
        self.logger.info("🤖 Step 2: generating Question-SQL pairs")
        self.logger.info("=" * 60)
        step_start_time = time.time()

        try:
            # Create the Question-SQL generation agent
            qs_agent = QuestionSQLGenerationAgent(
                output_dir=str(self.output_dir),
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                db_name=self.db_name,
                task_id=self.task_id  # propagate the task_id
            )

            # Run Question-SQL generation
            qs_result = await qs_agent.generate()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("question_sql_generation")
            self.workflow_state["artifacts"]["question_sql_generation"] = {
                "output_file": str(qs_result.get("output_file", "")),
                "total_questions": qs_result.get("total_questions", 0),
                "total_themes": qs_result.get("total_themes", 0),
                "successful_themes": qs_result.get("successful_themes", 0),
                "failed_themes": qs_result.get("failed_themes", []),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step2_duration"] = step_duration

            total_questions = qs_result.get("total_questions", 0)
            self.logger.info(f"✅ Step 2 completed: generated {total_questions} question-SQL pairs, took {step_duration:.2f}s")

        except Exception as e:
            self.workflow_state["failed_steps"].append("question_sql_generation")
            self.logger.error(f"❌ Step 2 failed: {str(e)}")
            raise

    async def _execute_step_3_sql_validation(self):
        """Step 3: validate and repair SQL."""
        self.workflow_state["current_step"] = "sql_validation"
        self.logger.info("=" * 60)
        self.logger.info("🔍 Step 3: validating and repairing SQL")
        self.logger.info("=" * 60)
        step_start_time = time.time()

        try:
            # First try to get the file from workflow_state (full-workflow mode)
            qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
            qs_file = qs_artifacts.get("output_file")

            # If workflow_state has no file info, search the task directory (step-by-step mode)
            if not qs_file or not Path(qs_file).exists():
                self.logger.info("🔍 No file recorded in workflow_state; searching the task directory for the Question-SQL file...")

                # Look for matching files in the output directory
                possible_files = list(self.output_dir.glob("*_pair.json"))
                if not possible_files:
                    raise FileNotFoundError(
                        f"No Question-SQL file (*_pair.json) found in task directory {self.output_dir}. "
                        f"Make sure the qa_generation step has run and produced a Question-SQL pair file."
                    )

                # Pick the most recent file (by modification time)
                qs_file = str(max(possible_files, key=lambda f: f.stat().st_mtime))
                self.logger.info(f"🎯 Found Question-SQL file: {qs_file}")

            self.logger.info(f"📄 Validating file: {qs_file}")

            # Create the SQL validation agent; configuration is passed as
            # parameters rather than by mutating global config
            sql_validator = SQLValidationAgent(
                db_connection=self.db_connection,
                input_file=str(qs_file),
                output_dir=str(self.output_dir),
                task_id=self.task_id,  # propagate the task_id
                enable_sql_repair=self.enable_llm_repair,
                modify_original_file=self.modify_original_file
            )

            # Run SQL validation and repair
            validation_result = await sql_validator.validate()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("sql_validation")
            summary = validation_result.get("summary", {})
            self.workflow_state["artifacts"]["sql_validation"] = {
                "original_sql_count": summary.get("total_questions", 0),
                "valid_sql_count": summary.get("valid_sqls", 0),
                "invalid_sql_count": summary.get("invalid_sqls", 0),
                "success_rate": summary.get("success_rate", 0),
                "repair_stats": summary.get("repair_stats", {}),
                "file_modification_stats": summary.get("file_modification_stats", {}),
                "average_execution_time": summary.get("average_execution_time", 0),
                "total_retries": summary.get("total_retries", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step3_duration"] = step_duration

            success_rate = summary.get("success_rate", 0)
            valid_count = summary.get("valid_sqls", 0)
            total_count = summary.get("total_questions", 0)
            self.logger.info(f"✅ Step 3 completed: SQL validation success rate {success_rate:.1%} ({valid_count}/{total_count}), took {step_duration:.2f}s")

            # Report repair statistics
            repair_stats = summary.get("repair_stats", {})
            if repair_stats.get("attempted", 0) > 0:
                self.logger.info(f"🔧 Repair stats: attempted {repair_stats['attempted']}, succeeded {repair_stats['successful']}, failed {repair_stats['failed']}")

            # Report file modification statistics
            file_stats = summary.get("file_modification_stats", {})
            if file_stats.get("modified", 0) > 0 or file_stats.get("deleted", 0) > 0:
                self.logger.info(f"📝 File modifications: updated {file_stats.get('modified', 0)} SQL statements, removed {file_stats.get('deleted', 0)} invalid entries")

        except Exception as e:
            self.workflow_state["failed_steps"].append("sql_validation")
            self.logger.error(f"❌ Step 3 failed: {str(e)}")
            raise

    async def _execute_step_4_training_data_load(self):
        """Step 4: load training data."""
        self.workflow_state["current_step"] = "training_data_load"
        self.logger.info("=" * 60)
        self.logger.info("🎯 Step 4: loading training data")
        self.logger.info("=" * 60)
        step_start_time = time.time()

        try:
            # The output directory holds the training data to load
            training_data_dir = str(self.output_dir)
            self.logger.info(f"📁 Training data directory: {training_data_dir}")

            # Import the trainer module
            import sys
            import os
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from data_pipeline.trainer.run_training import process_training_files

            # Run the training data load
            self.logger.info("🔄 Processing training files...")
            load_successful = process_training_files(training_data_dir, self.task_id)
            step_duration = time.time() - step_start_time

            if load_successful:
                from data_pipeline.trainer.vanna_trainer import flush_training, shutdown_trainer

                # Flush the batch processor and shut down the trainer
                self.logger.info("🔄 Flushing the batch processor...")
                flush_training()
                shutdown_trainer()

                # Verify the load results
                try:
                    from core.vanna_llm_factory import create_vanna_instance
                    vn = create_vanna_instance()
                    training_data = vn.get_training_data()
                    if training_data is not None and not training_data.empty:
                        total_records = len(training_data)
                        self.logger.info(f"✅ Loaded {total_records} training data records")

                        # Tally the record types
                        if 'training_data_type' in training_data.columns:
                            type_counts = training_data['training_data_type'].value_counts().to_dict()
                        else:
                            type_counts = {}
                    else:
                        total_records = 0
                        type_counts = {}
                        self.logger.warning("⚠️ Could not verify the training data load results")
                except Exception as e:
                    self.logger.warning(f"⚠️ Error while verifying training data: {e}")
                    total_records = 0
                    type_counts = {}

                # Record the results
                self.workflow_state["completed_steps"].append("training_data_load")
                self.workflow_state["artifacts"]["training_data_load"] = {
                    "training_data_dir": training_data_dir,
                    "load_successful": True,
                    "total_records": total_records,
                    "data_type_counts": type_counts,
                    "duration": step_duration
                }
                self.workflow_state["statistics"]["step4_duration"] = step_duration
                self.logger.info(f"✅ Step 4 completed: training data loaded successfully, took {step_duration:.2f}s")
            else:
                raise Exception("Training data load failed: no processable training files found")

        except Exception as e:
            self.workflow_state["failed_steps"].append("training_data_load")
            self.logger.error(f"❌ Step 4 failed: {str(e)}")
            raise

    async def _generate_final_report(self) -> Dict[str, Any]:
        """Generate the final workflow report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]

        # Determine the primary output file
        qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
        final_output_file = qs_artifacts.get("output_file", "")

        # Determine the final question count
        if "sql_validation" in self.workflow_state["artifacts"]:
            # If the validation step ran, use the post-validation count
            validation_artifacts = self.workflow_state["artifacts"]["sql_validation"]
            final_question_count = validation_artifacts.get("valid_sql_count", 0)
        else:
            # Otherwise use the generated count
            final_question_count = qs_artifacts.get("total_questions", 0)

        report = {
            "success": True,
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "total_steps": len(self.workflow_state["completed_steps"]),
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat(),
                "workflow_completed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat()
            },
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir),
                "enable_sql_validation": self.enable_sql_validation,
                "enable_llm_repair": self.enable_llm_repair,
                "modify_original_file": self.modify_original_file,
                "enable_training_data_load": self.enable_training_data_load
            },
            "processing_results": {
                "ddl_md_generation": self.workflow_state["artifacts"].get("ddl_md_generation", {}),
                "question_sql_generation": self.workflow_state["artifacts"].get("question_sql_generation", {}),
                "sql_validation": self.workflow_state["artifacts"].get("sql_validation", {}),
                "training_data_load": self.workflow_state["artifacts"].get("training_data_load", {})
            },
            "final_outputs": {
                "primary_output_file": final_output_file,
                "output_directory": str(self.output_dir),
                "final_question_count": final_question_count,
                "backup_files_created": self.modify_original_file
            },
            "performance_metrics": {
                "step1_duration": round(self.workflow_state["statistics"].get("step1_duration", 0), 2),
                "step2_duration": round(self.workflow_state["statistics"].get("step2_duration", 0), 2),
                "step3_duration": round(self.workflow_state["statistics"].get("step3_duration", 0), 2),
                "step4_duration": round(self.workflow_state["statistics"].get("step4_duration", 0), 2),
                "total_duration": round(total_duration, 2)
            }
        }

        return report

    async def _generate_error_report(self, error: Exception) -> Dict[str, Any]:
        """Generate an error report."""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]

        return {
            "success": False,
            "error": {
                "message": str(error),
                "type": type(error).__name__,
                "failed_step": self.workflow_state["current_step"]
            },
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat() if self.workflow_state["start_time"] else None,
                "workflow_failed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat() if self.workflow_state["end_time"] else None
            },
            "partial_results": self.workflow_state["artifacts"],
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir)
            }
        }

    def _mask_connection_string(self, conn_str: str) -> str:
        """Mask sensitive information in the connection string."""
        import re
        return re.sub(r':[^:@]+@', ':***@', conn_str)
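
    # Illustrative masking behavior (example, not executed):
    #   _mask_connection_string("postgresql://user:secret@localhost:5432/db")
    #       -> "postgresql://user:***@localhost:5432/db"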

    def print_final_summary(self, report: Dict[str, Any]):
        """Print the final summary."""
        self.logger.info("=" * 80)
        self.logger.info("📊 Workflow execution summary")
        self.logger.info("=" * 80)

        if report["success"]:
            summary = report["workflow_summary"]
            results = report["processing_results"]
            outputs = report["final_outputs"]
            metrics = report["performance_metrics"]

            self.logger.info("✅ Workflow completed successfully")
            self.logger.info(f"⏱️ Total duration: {summary['total_duration']} seconds")
            self.logger.info(f"📝 Completed steps: {len(summary['completed_steps'])}/{summary['total_steps']}")

            # Fetch and report the embedding model in use
            try:
                from common.utils import get_current_model_info
                model_info = get_current_model_info()
                self.logger.info(f"🤖 Embedding model: {model_info['embedding_model']} ({model_info['embedding_type']})")
            except Exception as e:
                self.logger.info(f"🤖 Embedding model: unknown (failed to fetch info: {e})")

            # Parse and report the source database info
            try:
                db_info = self._parse_db_connection(self.db_connection)
                self.logger.info(f"🗄️ Source database: {db_info['dbname']}")
                self.logger.info(f"🏠 Source host: {db_info['host']}:{db_info['port']}")
            except Exception as e:
                self.logger.info(f"🗄️ Source database: {self.db_name}")
                self.logger.info(f"🏠 Source host: unknown (parsing failed: {e})")

            # DDL/MD generation results, with detailed file statistics
            if "ddl_md_generation" in results:
                ddl_md = results["ddl_md_generation"]
                self.logger.info(f"📋 DDL/MD generation: {ddl_md.get('processed_successfully', 0)} tables processed successfully")

                # Try to gather detailed file statistics
                try:
                    # Count the files actually generated in the output directory
                    output_path = Path(self.output_dir)
                    if output_path.exists():
                        md_files = list(output_path.glob("*.md"))
                        ddl_files = list(output_path.glob("*.ddl"))
                        md_count = len([f for f in md_files if not f.name.startswith('metadata')])  # exclude metadata.md
                        ddl_count = len(ddl_files)
                        self.logger.info(f"📁 Generated files: {md_count} MD files, {ddl_count} DDL files")
                    else:
                        self.logger.info("📁 Generated files: statistics unavailable")
                except Exception as e:
                    self.logger.info(f"📁 Generated files: failed to collect statistics ({e})")

            # Question-SQL generation results
            if "question_sql_generation" in results:
                qs = results["question_sql_generation"]
                self.logger.info(f"🤖 Question-SQL generation: {qs.get('total_questions', 0)} question-SQL pairs")

            # SQL validation results
            if "sql_validation" in results:
                validation = results["sql_validation"]
                success_rate = validation.get('success_rate', 0)
                self.logger.info(f"🔍 SQL validation: {success_rate:.1%} success rate ({validation.get('valid_sql_count', 0)}/{validation.get('original_sql_count', 0)})")

            self.logger.info(f"📁 Output directory: {outputs['output_directory']}")
            self.logger.info(f"📄 Question-SQL pair file: {outputs['primary_output_file']}")
            self.logger.info(f"❓ Final question count: {outputs['final_question_count']}")

            # Echo the configuration used
            self.logger.info("⚙️ Execution configuration:")
            self.logger.info(f"   🔍 SQL validation: {'enabled' if self.enable_sql_validation else 'disabled'}")
            self.logger.info(f"   🔧 LLM repair: {'enabled' if self.enable_llm_repair else 'disabled'}")
            self.logger.info(f"   📝 File modification: {'enabled' if self.modify_original_file else 'disabled'}")
            if not self.enable_training_data_load:
                self.logger.info("   ⏭️ Training data loading: skipped")
            else:
                self.logger.info("   📚 Training data loading: enabled")
        else:
            error = report["error"]
            summary = report["workflow_summary"]

            self.logger.error("❌ Workflow execution failed")
            self.logger.error(f"💥 Failure reason: {error['message']}")
            self.logger.error(f"💥 Failed step: {error['failed_step']}")
            self.logger.error(f"⏱️ Elapsed time: {summary['total_duration']} seconds")
            self.logger.error(f"✅ Completed steps: {', '.join(summary['completed_steps']) if summary['completed_steps'] else 'none'}")

        self.logger.info("=" * 80)

    def _parse_db_connection(self, db_connection: str) -> Dict[str, str]:
        """
        Parse a PostgreSQL connection string.

        Args:
            db_connection: PostgreSQL connection string in the form postgresql://user:password@host:port/dbname

        Returns:
            A dict of database connection parameters.
        """
        import re

        # Regular expression for parsing the connection string
        pattern = r'postgresql://([^:]+):([^@]+)@([^:]+):(\d+)/(.+)'
        match = re.match(pattern, db_connection)

        if not match:
            raise ValueError(f"Invalid PostgreSQL connection string format: {db_connection}")

        user, password, host, port, dbname = match.groups()

        return {
            'user': user,
            'password': password,
            'host': host,
            'port': port,
            'dbname': dbname
        }
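

# Illustrative programmatic usage (a minimal sketch; the connection string,
# table list path, and business context below are placeholders):
#
#   import asyncio
#
#   async def run():
#       orchestrator = SchemaWorkflowOrchestrator(
#           db_connection="postgresql://user:pass@localhost:5432/example_db",
#           table_list_file="tables.txt",
#           business_context="example business domain",
#       )
#       report = await orchestrator.execute_complete_workflow()
#       orchestrator.print_final_summary(report)
#
#   asyncio.run(run())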


# Convenience command-line interface
def setup_argument_parser():
    """Set up the command-line argument parser."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Schema workflow orchestrator - end-to-end schema processing pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full workflow (creates a task subdirectory under the given directory)
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/highway_db" \\
    --table-list tables.txt \\
    --business-context "highway service area management system" \\
    --output-dir ./data_pipeline/training_data/

  # Skip SQL validation
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/ecommerce_db" \\
    --table-list tables.txt \\
    --business-context "e-commerce system" \\
    --skip-validation

  # Disable LLM repair
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
    --table-list tables.txt \\
    --business-context "management system" \\
    --disable-llm-repair

  # Skip training data loading
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
    --table-list tables.txt \\
    --business-context "management system" \\
    --skip-training-load
        """
    )

    # Required arguments
    parser.add_argument(
        "--db-connection",
        required=True,
        help="database connection string (postgresql://user:pass@host:port/dbname)"
    )
    parser.add_argument(
        "--table-list",
        required=True,
        help="path to the table list file"
    )
    parser.add_argument(
        "--business-context",
        required=True,
        help="business context description"
    )

    # Optional arguments
    parser.add_argument(
        "--output-dir",
        default=None,
        help="base output directory; a task subdirectory is created under it (default: ./data_pipeline/training_data/)"
    )
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="skip the SQL validation step"
    )
    parser.add_argument(
        "--disable-llm-repair",
        action="store_true",
        help="disable LLM-based repair"
    )
    parser.add_argument(
        "--no-modify-file",
        action="store_true",
        help="do not modify the original JSON file (only generate a report)"
    )
    parser.add_argument(
        "--skip-training-load",
        action="store_true",
        help="skip the training data loading step"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="enable verbose logging"
    )
    parser.add_argument(
        "--log-file",
        help="path to the log file"
    )

    return parser


async def main():
    """Command-line entry point."""
    import sys
    import os

    parser = setup_argument_parser()
    args = parser.parse_args()

    # Set up logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file
    )

    # Script mode: generate the task_id and create the logger up front, so the
    # logger is also available in the exception handlers below (the original
    # code created it inside the try block, leaving it undefined on early failures)
    script_task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logger = get_logger("SchemaWorkflow", script_task_id)

    # Validate the input file
    if not os.path.exists(args.table_list):
        logger.error(f"Error: table list file does not exist: {args.table_list}")
        sys.exit(1)

    try:
        # Create and run the workflow orchestrator; reuse the script task_id so
        # the orchestrator and the script logs share one task directory
        orchestrator = SchemaWorkflowOrchestrator(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=args.business_context,
            output_dir=args.output_dir,
            task_id=script_task_id,
            enable_sql_validation=not args.skip_validation,
            enable_llm_repair=not args.disable_llm_repair,
            modify_original_file=not args.no_modify_file,
            enable_training_data_load=not args.skip_training_load
        )

        logger.info("🚀 Starting schema workflow orchestration...")
        logger.info(f"📁 Output directory: {orchestrator.output_dir}")
        logger.info(f"📋 Table list: {args.table_list}")
        logger.info(f"🏢 Business context: {args.business_context}")
        logger.info(f"💾 Database: {orchestrator.db_name}")
        logger.info(f"🔍 SQL validation: {'enabled' if not args.skip_validation else 'disabled'}")
        logger.info(f"🔧 LLM repair: {'enabled' if not args.disable_llm_repair else 'disabled'}")
        logger.info(f"🎯 Training data loading: {'enabled' if not args.skip_training_load else 'disabled'}")

        # Run the complete workflow
        report = await orchestrator.execute_complete_workflow()

        # Print the detailed summary
        orchestrator.print_final_summary(report)

        # Report the outcome and set the exit code
        if report["success"]:
            if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
                logger.info("\n🎉 Workflow completed successfully!")
                exit_code = 0  # full success
            else:
                logger.warning("\n⚠️ Workflow completed, but the SQL validation success rate is low")
                exit_code = 1  # partial success
        else:
            logger.error("\n❌ Workflow execution failed")
            exit_code = 2  # failure

        sys.exit(exit_code)

    except KeyboardInterrupt:
        logger.info("\n\n⏹️ Interrupted by user, exiting")
        sys.exit(130)
    except Exception as e:
        logger.error(f"\n❌ Program execution failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())