schema_workflow.py
  1. """
  2. Schema工作流编排器
  3. 统一管理完整的数据库Schema处理流程
  4. """
  5. import asyncio
  6. import time
  7. import logging
  8. from typing import Dict, Any, List, Optional
  9. from pathlib import Path
  10. from datetime import datetime
  11. from data_pipeline.ddl_generation.training_data_agent import SchemaTrainingDataAgent
  12. from data_pipeline.qa_generation.qs_agent import QuestionSQLGenerationAgent
  13. from data_pipeline.validators.sql_validation_agent import SQLValidationAgent
  14. from data_pipeline.config import SCHEMA_TOOLS_CONFIG
  15. from data_pipeline.dp_logging import get_logger
  16. from data_pipeline.utils.logger import setup_logging
class SchemaWorkflowOrchestrator:
    """End-to-end schema processing orchestrator - the complete workflow"""

    def __init__(self,
                 db_connection: str,
                 table_list_file: str,
                 business_context: str,
                 output_dir: str = None,
                 task_id: str = None,
                 enable_sql_validation: bool = True,
                 enable_llm_repair: bool = True,
                 modify_original_file: bool = True,
                 enable_training_data_load: bool = True,
                 backup_vector_tables: bool = False,
                 truncate_vector_tables: bool = False,
                 skip_training: bool = False):
        """
        Initialize the schema workflow orchestrator.

        Args:
            db_connection: Database connection string (postgresql://user:pass@host:port/dbname)
            table_list_file: Path to the table list file
            business_context: Business context description
            output_dir: Output directory
            task_id: Task ID (passed in by API mode, auto-generated in script mode)
            enable_sql_validation: Whether to enable SQL validation
            enable_llm_repair: Whether to enable LLM-based repair
            modify_original_file: Whether to modify the original JSON file
            enable_training_data_load: Whether to enable training data loading
            backup_vector_tables: Whether to back up vector table data
            truncate_vector_tables: Whether to truncate vector table data (automatically enables backup)
            skip_training: Whether to skip training file processing and run only vector table management
        """
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context
        self.db_name = self._extract_db_name_from_connection(db_connection)
        self.enable_sql_validation = enable_sql_validation
        self.enable_llm_repair = enable_llm_repair
        self.modify_original_file = modify_original_file
        self.enable_training_data_load = enable_training_data_load

        # Vector table management parameters.
        # Validation and auto-enable rule: truncate implies backup. Remember
        # whether backup was auto-enabled here, so it can be logged once the
        # logger exists (checking the flags after the override would never fire).
        backup_auto_enabled = truncate_vector_tables and not backup_vector_tables
        if truncate_vector_tables:
            backup_vector_tables = True
        self.backup_vector_tables = backup_vector_tables
        self.truncate_vector_tables = truncate_vector_tables
        self.skip_training = skip_training

        # Resolve the task_id
        if task_id is None:
            # Script mode: auto-generate a task_id prefixed with "manual"
            self.task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        else:
            # API mode: use the task_id that was passed in
            self.task_id = task_id

        # Resolve the output directory
        if output_dir is None:
            # Script mode, or no output directory given: use the default base
            # directory under the project root (absolute path)
            project_root = Path(__file__).parent.parent
            base_dir = project_root / "data_pipeline" / "training_data"
            # Create the task subdirectory under the base directory
            self.output_dir = base_dir / self.task_id
        else:
            # An output directory was given: check whether this is API mode
            output_path = Path(output_dir)
            # API mode heuristic: if the path already contains the task_id, use it
            # directly instead of creating another subdirectory
            if self.task_id in str(output_path):
                # API mode: the given directory is already the task-specific directory
                self.output_dir = output_path
            else:
                # Script mode: create a task subdirectory under the given directory
                self.output_dir = output_path / self.task_id

        # Make sure the output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)
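        # Resulting layout (illustrative; the timestamp is a hypothetical example):
        #   output_dir=None                  -> <project_root>/data_pipeline/training_data/manual_20240101_120000/
        #   output_dir="./out" (script mode) -> ./out/manual_20240101_120000/
        #   output_dir="./out/task_123" with task_id="task_123" (API mode) -> ./out/task_123/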
        # Initialize the standalone logging system
        self.logger = get_logger("SchemaWorkflowOrchestrator", self.task_id)

        # Log the state of the vector table management parameters
        if backup_auto_enabled:
            self.logger.info("🔄 Backup auto-enabled because truncate was requested")
        if self.backup_vector_tables or self.truncate_vector_tables:
            self.logger.info(f"🗂️ Vector table management parameters: backup={self.backup_vector_tables}, truncate={self.truncate_vector_tables}")

        # Workflow state
        self.workflow_state = {
            "start_time": None,
            "end_time": None,
            "current_step": None,
            "completed_steps": [],
            "failed_steps": [],
            "artifacts": {},  # Files produced by each step
            "statistics": {}
        }
    def _extract_db_name_from_connection(self, connection_string: str) -> str:
        """
        Extract the database name from a database connection string.

        Args:
            connection_string: PostgreSQL connection string

        Returns:
            str: The database name
        """
        try:
            # Handle a standard PostgreSQL connection string: postgresql://user:pass@host:port/dbname
            if '/' in connection_string:
                # Take the part after the last '/' as the database name
                db_name = connection_string.split('/')[-1]
                # Strip any query parameters
                if '?' in db_name:
                    db_name = db_name.split('?')[0]
                return db_name if db_name else "database"
            else:
                return "database"
        except Exception:
            return "database"
    async def execute_complete_workflow(self) -> Dict[str, Any]:
        """
        Execute the complete schema processing workflow.

        Returns:
            The complete workflow report
        """
        self.workflow_state["start_time"] = time.time()
        self.logger.info("🚀 Starting schema workflow orchestration")
        self.logger.info(f"📁 Output directory: {self.output_dir}")
        self.logger.info(f"🏢 Business context: {self.business_context}")
        self.logger.info(f"💾 Database: {self.db_name}")

        try:
            # Step 1: generate DDL and MD files
            await self._execute_step_1_ddl_md_generation()

            # Step 2: generate Question-SQL pairs
            await self._execute_step_2_question_sql_generation()

            # Step 3: validate and repair SQL (optional)
            if self.enable_sql_validation:
                await self._execute_step_3_sql_validation()
            else:
                self.logger.info("⏭️ Skipping the SQL validation step")

            # Step 4: load training data (optional)
            if self.enable_training_data_load:
                await self._execute_step_4_training_data_load()
            else:
                self.logger.info("⏭️ Skipping the training data loading step")

            # Record the end time
            self.workflow_state["end_time"] = time.time()

            # Generate the final report
            final_report = await self._generate_final_report()
            self.logger.info("✅ Schema workflow orchestration completed")
            return final_report
        except Exception as e:
            self.workflow_state["end_time"] = time.time()
            self.logger.exception(f"❌ Workflow execution failed: {str(e)}")
            error_report = await self._generate_error_report(e)
            return error_report
    async def _execute_step_1_ddl_md_generation(self):
        """Step 1: generate DDL and MD files"""
        self.workflow_state["current_step"] = "ddl_md_generation"
        self.logger.info("=" * 60)
        self.logger.info("📝 Step 1: Generating DDL and MD files")
        self.logger.info("=" * 60)
        step_start_time = time.time()
        try:
            # Create the DDL/MD generation agent
            ddl_md_agent = SchemaTrainingDataAgent(
                db_connection=self.db_connection,
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                output_dir=str(self.output_dir),
                task_id=self.task_id,  # Propagate the task_id
                pipeline="full"
            )

            # Run the DDL/MD generation
            ddl_md_result = await ddl_md_agent.generate_training_data()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("ddl_md_generation")
            self.workflow_state["artifacts"]["ddl_md_generation"] = {
                "total_tables": ddl_md_result.get("summary", {}).get("total_tables", 0),
                "processed_successfully": ddl_md_result.get("summary", {}).get("processed_successfully", 0),
                "failed": ddl_md_result.get("summary", {}).get("failed", 0),
                "files_generated": ddl_md_result.get("statistics", {}).get("total_files_generated", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step1_duration"] = step_duration

            processed_tables = ddl_md_result.get("summary", {}).get("processed_successfully", 0)

            # Collect file statistics
            statistics = ddl_md_result.get("statistics", {})
            md_files = statistics.get("md_files_generated", 0)
            ddl_files = statistics.get("ddl_files_generated", 0)
            if md_files > 0 and ddl_files > 0:
                file_info = f"generated {md_files} MD files and {ddl_files} DDL files"
            elif md_files > 0:
                file_info = f"generated {md_files} MD files"
            elif ddl_files > 0:
                file_info = f"generated {ddl_files} DDL files"
            else:
                file_info = "no files generated"

            self.logger.info(f"✅ Step 1 completed: processed {processed_tables} tables successfully, {file_info}, took {step_duration:.2f}s")
        except Exception as e:
            self.workflow_state["failed_steps"].append("ddl_md_generation")
            self.logger.error(f"❌ Step 1 failed: {str(e)}")
            raise
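    # Shape of ddl_md_result as consumed above (assumed; values are illustrative):
    # {
    #     "summary": {"total_tables": 10, "processed_successfully": 9, "failed": 1},
    #     "statistics": {"total_files_generated": 18,
    #                    "md_files_generated": 9, "ddl_files_generated": 9},
    # }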
    async def _execute_step_2_question_sql_generation(self):
        """Step 2: generate Question-SQL pairs"""
        self.workflow_state["current_step"] = "question_sql_generation"
        self.logger.info("=" * 60)
        self.logger.info("🤖 Step 2: Generating Question-SQL pairs")
        self.logger.info("=" * 60)
        step_start_time = time.time()
        try:
            # Create the Question-SQL generation agent
            qs_agent = QuestionSQLGenerationAgent(
                output_dir=str(self.output_dir),
                table_list_file=self.table_list_file,
                business_context=self.business_context,
                db_name=self.db_name,
                task_id=self.task_id  # Propagate the task_id
            )

            # Run the Question-SQL generation
            qs_result = await qs_agent.generate()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("question_sql_generation")
            self.workflow_state["artifacts"]["question_sql_generation"] = {
                "output_file": str(qs_result.get("output_file", "")),
                "total_questions": qs_result.get("total_questions", 0),
                "total_themes": qs_result.get("total_themes", 0),
                "successful_themes": qs_result.get("successful_themes", 0),
                "failed_themes": qs_result.get("failed_themes", []),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step2_duration"] = step_duration

            total_questions = qs_result.get("total_questions", 0)
            self.logger.info(f"✅ Step 2 completed: generated {total_questions} question-SQL pairs, took {step_duration:.2f}s")
        except Exception as e:
            self.workflow_state["failed_steps"].append("question_sql_generation")
            self.logger.error(f"❌ Step 2 failed: {str(e)}")
            raise
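    # Shape of qs_result as consumed above (assumed; values are illustrative):
    # {
    #     "output_file": "<task_dir>/..._pair.json",  # matched later by the *_pair.json glob
    #     "total_questions": 50, "total_themes": 5,
    #     "successful_themes": 5, "failed_themes": [],
    # }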
    async def _execute_step_3_sql_validation(self):
        """Step 3: validate and repair SQL"""
        self.workflow_state["current_step"] = "sql_validation"
        self.logger.info("=" * 60)
        self.logger.info("🔍 Step 3: Validating and repairing SQL")
        self.logger.info("=" * 60)
        step_start_time = time.time()
        try:
            # First try to get the file from workflow_state (full-workflow mode)
            qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
            qs_file = qs_artifacts.get("output_file")

            # If workflow_state has no file info, search the task directory (step-by-step mode)
            if not qs_file or not Path(qs_file).exists():
                self.logger.info("🔍 No file found in workflow_state, searching the task directory for a Question-SQL file...")
                # Look for matching files in the output directory
                possible_files = list(self.output_dir.glob("*_pair.json"))
                if not possible_files:
                    raise FileNotFoundError(
                        f"No Question-SQL file (*_pair.json) found in task directory {self.output_dir}. "
                        "Make sure the qa_generation step has run and produced a Question-SQL pair file."
                    )
                # Pick the newest file (by modification time)
                qs_file = str(max(possible_files, key=lambda f: f.stat().st_mtime))
                self.logger.info(f"🎯 Found Question-SQL file: {qs_file}")

            self.logger.info(f"📄 Validating file: {qs_file}")

            # Create the SQL validation agent; configuration is passed via
            # parameters instead of mutating the global config
            sql_validator = SQLValidationAgent(
                db_connection=self.db_connection,
                input_file=str(qs_file),
                output_dir=str(self.output_dir),
                task_id=self.task_id,  # Propagate the task_id
                enable_sql_repair=self.enable_llm_repair,
                modify_original_file=self.modify_original_file
            )

            # Run the SQL validation and repair
            validation_result = await sql_validator.validate()
            step_duration = time.time() - step_start_time

            # Record the results
            self.workflow_state["completed_steps"].append("sql_validation")
            summary = validation_result.get("summary", {})
            self.workflow_state["artifacts"]["sql_validation"] = {
                "original_sql_count": summary.get("total_questions", 0),
                "valid_sql_count": summary.get("valid_sqls", 0),
                "invalid_sql_count": summary.get("invalid_sqls", 0),
                "success_rate": summary.get("success_rate", 0),
                "repair_stats": summary.get("repair_stats", {}),
                "file_modification_stats": summary.get("file_modification_stats", {}),
                "average_execution_time": summary.get("average_execution_time", 0),
                "total_retries": summary.get("total_retries", 0),
                "duration": step_duration
            }
            self.workflow_state["statistics"]["step3_duration"] = step_duration

            success_rate = summary.get("success_rate", 0)
            valid_count = summary.get("valid_sqls", 0)
            total_count = summary.get("total_questions", 0)
            self.logger.info(f"✅ Step 3 completed: SQL validation success rate {success_rate:.1%} ({valid_count}/{total_count}), took {step_duration:.2f}s")

            # Report repair statistics
            repair_stats = summary.get("repair_stats", {})
            if repair_stats.get("attempted", 0) > 0:
                self.logger.info(f"🔧 Repair statistics: attempted {repair_stats['attempted']}, succeeded {repair_stats['successful']}, failed {repair_stats['failed']}")

            # Report file modification statistics
            file_stats = summary.get("file_modification_stats", {})
            if file_stats.get("modified", 0) > 0 or file_stats.get("deleted", 0) > 0:
                self.logger.info(f"📝 File modifications: updated {file_stats.get('modified', 0)} SQL statements, deleted {file_stats.get('deleted', 0)} invalid entries")
        except Exception as e:
            self.workflow_state["failed_steps"].append("sql_validation")
            self.logger.error(f"❌ Step 3 failed: {str(e)}")
            raise
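    # Shape of validation_result["summary"] as consumed above (assumed; values
    # are illustrative):
    # {
    #     "total_questions": 50, "valid_sqls": 47, "invalid_sqls": 3,
    #     "success_rate": 0.94,
    #     "repair_stats": {"attempted": 5, "successful": 2, "failed": 3},
    #     "file_modification_stats": {"modified": 2, "deleted": 3},
    #     "average_execution_time": 0.12, "total_retries": 4,
    # }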
    async def _execute_vector_table_management(self):
        """Run vector table management as a standalone step"""
        if not (self.backup_vector_tables or self.truncate_vector_tables):
            return

        self.logger.info("=" * 60)
        self.logger.info("🗂️ Starting vector table management")
        self.logger.info("=" * 60)

        vector_stats = None
        try:
            from data_pipeline.trainer.vector_table_manager import VectorTableManager
            vector_manager = VectorTableManager(
                task_output_dir=str(self.output_dir),
                task_id=self.task_id
            )

            # Run the vector table management
            vector_stats = vector_manager.execute_vector_management(
                backup=self.backup_vector_tables,
                truncate=self.truncate_vector_tables
            )

            # Record the result in the workflow state (on success and failure alike)
            self.workflow_state["artifacts"]["vector_management"] = vector_stats

            if vector_stats.get("errors"):
                self.logger.warning(f"⚠️ Vector table management completed with errors: {'; '.join(vector_stats['errors'])}")
            else:
                self.logger.info("✅ Vector table management completed")
        except Exception as e:
            self.logger.error(f"❌ Vector table management failed: {e}")
            # Record a minimal status even when an exception occurs
            if vector_stats is None:
                vector_stats = {
                    "backup_performed": self.backup_vector_tables,
                    "truncate_performed": False,
                    "errors": [f"Execution exception: {str(e)}"],
                    "duration": 0
                }
            self.workflow_state["artifacts"]["vector_management"] = vector_stats
            raise
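    # Shape of vector_stats as recorded above and read by print_final_summary
    # (assumed; values are illustrative):
    # {
    #     "backup_performed": True,
    #     "truncate_performed": False,
    #     "tables_backed_up": {"langchain_pg_embedding": {"success": True, "file_size": "1.5 MB"}},
    #     "errors": [],
    #     "duration": 3.2,
    # }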
    async def _execute_step_4_training_data_load(self):
        """Step 4: load the training data"""
        self.workflow_state["current_step"] = "training_data_load"
        self.logger.info("=" * 60)
        self.logger.info("🎯 Step 4: Loading training data")
        self.logger.info("=" * 60)
        step_start_time = time.time()
        try:
            # The output directory must hold the training data to load
            training_data_dir = str(self.output_dir)
            self.logger.info(f"📁 Training data directory: {training_data_dir}")

            # Import the trainer module
            import sys
            import os
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from data_pipeline.trainer.run_training import process_training_files

            # Load the training data
            self.logger.info("🔄 Processing training files...")
            # Forward the vector table management parameters to the training step
            load_successful, vector_stats = process_training_files(
                training_data_dir, self.task_id,
                backup_vector_tables=self.backup_vector_tables,
                truncate_vector_tables=self.truncate_vector_tables,
                skip_training=self.skip_training
            )
            step_duration = time.time() - step_start_time

            # Record the vector table management result in the workflow state
            if vector_stats:
                if "artifacts" not in self.workflow_state:
                    self.workflow_state["artifacts"] = {}
                self.workflow_state["artifacts"]["vector_management"] = vector_stats

            if load_successful:
                # Collect statistics
                from data_pipeline.trainer.vanna_trainer import flush_training, shutdown_trainer

                # Flush the batch processor
                self.logger.info("🔄 Flushing the batch processor...")
                flush_training()
                shutdown_trainer()

                # Verify the load result
                try:
                    from core.vanna_llm_factory import create_vanna_instance
                    vn = create_vanna_instance()
                    training_data = vn.get_training_data()
                    if training_data is not None and not training_data.empty:
                        total_records = len(training_data)
                        self.logger.info(f"✅ Loaded {total_records} training data records")
                        # Count records per data type
                        if 'training_data_type' in training_data.columns:
                            type_counts = training_data['training_data_type'].value_counts().to_dict()
                        else:
                            type_counts = {}
                    else:
                        total_records = 0
                        type_counts = {}
                        self.logger.warning("⚠️ Could not verify the training data load result")
                except Exception as e:
                    self.logger.warning(f"⚠️ Error while verifying training data: {e}")
                    total_records = 0
                    type_counts = {}

                # Record the results
                self.workflow_state["completed_steps"].append("training_data_load")
                self.workflow_state["artifacts"]["training_data_load"] = {
                    "training_data_dir": training_data_dir,
                    "load_successful": True,
                    "total_records": total_records,
                    "data_type_counts": type_counts,
                    "duration": step_duration
                }
                self.workflow_state["statistics"]["step4_duration"] = step_duration
                self.logger.info(f"✅ Step 4 completed: training data loaded, took {step_duration:.2f}s")
            else:
                raise Exception("Training data load failed: no processable training files were found")
        except Exception as e:
            self.workflow_state["failed_steps"].append("training_data_load")
            self.logger.error(f"❌ Step 4 failed: {str(e)}")
            raise
    async def _generate_final_report(self) -> Dict[str, Any]:
        """Generate the final workflow report"""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]

        # Determine the final output file
        qs_artifacts = self.workflow_state["artifacts"].get("question_sql_generation", {})
        final_output_file = qs_artifacts.get("output_file", "")

        # Determine the final question count
        if "sql_validation" in self.workflow_state["artifacts"]:
            # If the validation step ran, use the post-validation count
            validation_artifacts = self.workflow_state["artifacts"]["sql_validation"]
            final_question_count = validation_artifacts.get("valid_sql_count", 0)
        else:
            # Otherwise use the generated count
            final_question_count = qs_artifacts.get("total_questions", 0)

        report = {
            "success": True,
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "total_steps": len(self.workflow_state["completed_steps"]),
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat(),
                "workflow_completed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat()
            },
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir),
                "enable_sql_validation": self.enable_sql_validation,
                "enable_llm_repair": self.enable_llm_repair,
                "modify_original_file": self.modify_original_file,
                "enable_training_data_load": self.enable_training_data_load
            },
            "processing_results": {
                "ddl_md_generation": self.workflow_state["artifacts"].get("ddl_md_generation", {}),
                "question_sql_generation": self.workflow_state["artifacts"].get("question_sql_generation", {}),
                "sql_validation": self.workflow_state["artifacts"].get("sql_validation", {}),
                "training_data_load": self.workflow_state["artifacts"].get("training_data_load", {})
            },
            "final_outputs": {
                "primary_output_file": final_output_file,
                "output_directory": str(self.output_dir),
                "final_question_count": final_question_count,
                "backup_files_created": self.modify_original_file
            },
            "performance_metrics": {
                "step1_duration": round(self.workflow_state["statistics"].get("step1_duration", 0), 2),
                "step2_duration": round(self.workflow_state["statistics"].get("step2_duration", 0), 2),
                "step3_duration": round(self.workflow_state["statistics"].get("step3_duration", 0), 2),
                "step4_duration": round(self.workflow_state["statistics"].get("step4_duration", 0), 2),
                "total_duration": round(total_duration, 2)
            }
        }
        return report
    async def _generate_error_report(self, error: Exception) -> Dict[str, Any]:
        """Generate the error report"""
        total_duration = self.workflow_state["end_time"] - self.workflow_state["start_time"]
        return {
            "success": False,
            "error": {
                "message": str(error),
                "type": type(error).__name__,
                "failed_step": self.workflow_state["current_step"]
            },
            "workflow_summary": {
                "total_duration": round(total_duration, 2),
                "completed_steps": self.workflow_state["completed_steps"],
                "failed_steps": self.workflow_state["failed_steps"],
                "workflow_started": datetime.fromtimestamp(self.workflow_state["start_time"]).isoformat() if self.workflow_state["start_time"] else None,
                "workflow_failed": datetime.fromtimestamp(self.workflow_state["end_time"]).isoformat() if self.workflow_state["end_time"] else None
            },
            "partial_results": self.workflow_state["artifacts"],
            "input_parameters": {
                "db_connection": self._mask_connection_string(self.db_connection),
                "table_list_file": self.table_list_file,
                "business_context": self.business_context,
                "db_name": self.db_name,
                "output_directory": str(self.output_dir)
            }
        }
    def _mask_connection_string(self, conn_str: str) -> str:
        """Mask sensitive information in the connection string"""
        import re
        return re.sub(r':[^:@]+@', ':***@', conn_str)
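    # Illustrative (hypothetical credentials):
    #   "postgresql://user:secret@host:5432/db" -> "postgresql://user:***@host:5432/db"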
    def print_final_summary(self, report: Dict[str, Any]):
        """Print the final summary"""
        self.logger.info("=" * 80)
        self.logger.info("📊 Workflow execution summary")
        self.logger.info("=" * 80)

        if report["success"]:
            summary = report["workflow_summary"]
            results = report["processing_results"]
            outputs = report["final_outputs"]
            metrics = report["performance_metrics"]

            self.logger.info("✅ Workflow executed successfully")
            self.logger.info(f"⏱️ Total duration: {summary['total_duration']} seconds")
            self.logger.info(f"📝 Completed steps: {len(summary['completed_steps'])}/{summary['total_steps']}")

            # Look up and report the embedding model information
            try:
                from common.utils import get_current_model_info
                model_info = get_current_model_info()
                self.logger.info(f"🤖 Embedding model in use: {model_info['embedding_model']} ({model_info['embedding_type']})")
            except Exception as e:
                self.logger.info(f"🤖 Embedding model in use: unknown (failed to fetch info: {e})")

            # Parse and report the source database information
            try:
                db_info = self._parse_db_connection(self.db_connection)
                self.logger.info(f"🗄️ Source database name: {db_info['dbname']}")
                self.logger.info(f"🏠 Source database hostname: {db_info['host']}:{db_info['port']}")
            except Exception as e:
                self.logger.info(f"🗄️ Source database name: {self.db_name}")
                self.logger.info(f"🏠 Source database hostname: unknown (parse failed: {e})")

            # DDL/MD generation results, with detailed file statistics
            if "ddl_md_generation" in results:
                ddl_md = results["ddl_md_generation"]
                self.logger.info(f"📋 DDL/MD generation: {ddl_md.get('processed_successfully', 0)} tables processed successfully")
                # Try to gather detailed file statistics
                try:
                    # Count the files actually present in the output directory
                    output_path = Path(self.output_dir)
                    if output_path.exists():
                        md_files = list(output_path.glob("*.md"))
                        ddl_files = list(output_path.glob("*.ddl"))
                        md_count = len([f for f in md_files if not f.name.startswith('metadata')])  # exclude metadata.md
                        ddl_count = len(ddl_files)
                        self.logger.info(f"📁 Generated files: {md_count} MD files, {ddl_count} DDL files")
                    else:
                        self.logger.info("📁 Generated files: statistics unavailable")
                except Exception as e:
                    self.logger.info(f"📁 Generated files: counting failed ({e})")

            # Question-SQL generation results
            if "question_sql_generation" in results:
                qs = results["question_sql_generation"]
                self.logger.info(f"🤖 Question-SQL generation: {qs.get('total_questions', 0)} question-SQL pairs")

            # SQL validation results
            if "sql_validation" in results:
                validation = results["sql_validation"]
                success_rate = validation.get('success_rate', 0)
                self.logger.info(f"🔍 SQL validation: {success_rate:.1%} success rate ({validation.get('valid_sql_count', 0)}/{validation.get('original_sql_count', 0)})")

            self.logger.info(f"📁 Output directory: {outputs['output_directory']}")
            self.logger.info(f"📄 Question/SQL pair file: {outputs['primary_output_file']}")
            self.logger.info(f"❓ Final question count: {outputs['final_question_count']}")

            # Echo the configuration parameters
            self.logger.info("⚙️ Execution configuration:")
            self.logger.info(f"   🔍 SQL validation: {'enabled' if self.enable_sql_validation else 'disabled'}")
            self.logger.info(f"   🔧 LLM repair: {'enabled' if self.enable_llm_repair else 'disabled'}")
            self.logger.info(f"   📝 File modification: {'enabled' if self.modify_original_file else 'disabled'}")
            if not self.enable_training_data_load:
                self.logger.info("   ⏭️ Training data load: skipped")
            else:
                self.logger.info("   📚 Training data load: enabled")

            # Vector table management summary.
            # Note: the report itself does not carry the workflow state, so read
            # the statistics from the orchestrator's own state.
            vector_stats = self.workflow_state.get("artifacts", {}).get("vector_management")
            if vector_stats:
                self.logger.info("📊 Vector table management:")
                if vector_stats.get("backup_performed", False):
                    tables_count = len(vector_stats.get("tables_backed_up", {}))
                    total_size = sum(
                        self._parse_file_size(info.get("file_size", "0 B"))
                        for info in vector_stats.get("tables_backed_up", {}).values()
                        if info.get("success", False)
                    )
                    self.logger.info(f"   ✅ Backup: {tables_count} tables, total size: {self._format_size(total_size)}")
                else:
                    self.logger.info("   - Backup: not performed")
                if vector_stats.get("truncate_performed", False):
                    self.logger.info("   ✅ Truncate: table langchain_pg_embedding cleared")
                else:
                    self.logger.info("   - Truncate: not performed")
                duration = vector_stats.get("duration", 0)
                self.logger.info(f"   ⏱️ Duration: {duration:.1f}s")
            else:
                self.logger.info("📊 Vector table management: not run (parameters not enabled)")
        else:
            error = report["error"]
            summary = report["workflow_summary"]

            self.logger.error("❌ Workflow execution failed")
            self.logger.error(f"💥 Failure reason: {error['message']}")
            self.logger.error(f"💥 Failed step: {error['failed_step']}")
            self.logger.error(f"⏱️ Elapsed time: {summary['total_duration']} seconds")
            self.logger.error(f"✅ Completed steps: {', '.join(summary['completed_steps']) if summary['completed_steps'] else 'none'}")

        self.logger.info("=" * 80)
    def _parse_file_size(self, size_str: str) -> float:
        """Parse a file size string into a byte count"""
        import re
        # Regex matching a number followed by a size unit
        match = re.match(r'(\d+\.?\d*)\s*([KMGT]?B)', size_str.upper())
        if not match:
            return 0.0
        size, unit = match.groups()
        size = float(size)
        unit_multipliers = {
            'B': 1,
            'KB': 1024,
            'MB': 1024**2,
            'GB': 1024**3,
            'TB': 1024**4
        }
        return size * unit_multipliers.get(unit, 1)
    def _format_size(self, size_bytes: float) -> str:
        """Format a byte count as a human-readable size string"""
        if size_bytes == 0:
            return "0 B"
        size_names = ["B", "KB", "MB", "GB"]
        i = 0
        size = float(size_bytes)
        while size >= 1024.0 and i < len(size_names) - 1:
            size /= 1024.0
            i += 1
        return f"{size:.1f} {size_names[i]}"
    def _parse_db_connection(self, db_connection: str) -> Dict[str, str]:
        """
        Parse a PostgreSQL connection string.

        Args:
            db_connection: PostgreSQL connection string in the form postgresql://user:password@host:port/dbname

        Returns:
            A dict with the database connection parameters
        """
        import re
        # Regex for parsing the connection string
        pattern = r'postgresql://([^:]+):([^@]+)@([^:]+):(\d+)/(.+)'
        match = re.match(pattern, db_connection)
        if not match:
            raise ValueError(f"Invalid PostgreSQL connection string format: {db_connection}")
        user, password, host, port, dbname = match.groups()
        return {
            'user': user,
            'password': password,
            'host': host,
            'port': port,
            'dbname': dbname
        }
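    # Illustrative (hypothetical credentials):
    #   _parse_db_connection("postgresql://user:pass@localhost:5432/highway_db")
    #   -> {'user': 'user', 'password': 'pass', 'host': 'localhost',
    #       'port': '5432', 'dbname': 'highway_db'}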
# Convenience command-line interface
def setup_argument_parser():
    """Set up the command-line argument parser"""
    import argparse
    parser = argparse.ArgumentParser(
        description="Schema workflow orchestrator - end-to-end schema processing pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Complete workflow (creates a task subdirectory under the given directory)
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/highway_db" \\
    --table-list tables.txt \\
    --business-context "Highway service area management system" \\
    --output-dir ./data_pipeline/training_data/

  # Skip SQL validation
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/ecommerce_db" \\
    --table-list tables.txt \\
    --business-context "E-commerce system" \\
    --skip-validation

  # Disable LLM repair
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
    --table-list tables.txt \\
    --business-context "Management system" \\
    --disable-llm-repair

  # Skip training file processing (vector table management only)
  python -m data_pipeline.schema_workflow \\
    --db-connection "postgresql://user:pass@localhost:5432/management_db" \\
    --table-list tables.txt \\
    --business-context "Management system" \\
    --skip-training
        """
    )

    # Required arguments
    parser.add_argument(
        "--db-connection",
        required=True,
        help="Database connection string (postgresql://user:pass@host:port/dbname)"
    )
    parser.add_argument(
        "--table-list",
        required=True,
        help="Path to the table list file"
    )
    parser.add_argument(
        "--business-context",
        required=True,
        help="Business context description"
    )

    # Optional arguments
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Base output directory; a task subdirectory is created inside it (default: ./data_pipeline/training_data/)"
    )
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip the SQL validation step"
    )
    parser.add_argument(
        "--disable-llm-repair",
        action="store_true",
        help="Disable LLM-based repair"
    )
    parser.add_argument(
        "--no-modify-file",
        action="store_true",
        help="Do not modify the original JSON file (only produce a report)"
    )
    parser.add_argument(
        "--backup-vector-tables",
        action="store_true",
        help="Back up vector table data into the task directory"
    )
    parser.add_argument(
        "--truncate-vector-tables",
        action="store_true",
        help="Truncate vector table data (automatically enables backup)"
    )
    parser.add_argument(
        "--skip-training",
        action="store_true",
        help="Skip training file processing and run only vector table management"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose log output"
    )
    parser.add_argument(
        "--log-file",
        help="Path to the log file"
    )
    return parser
async def main():
    """Command-line entry point"""
    import sys
    import os

    parser = setup_argument_parser()
    args = parser.parse_args()

    # Configure logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file
    )

    # Script mode: generate one task_id up front and reuse it for the logger and
    # the orchestrator, so all log output for this run lands in the same task.
    script_task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logger = get_logger("SchemaWorkflow", script_task_id)

    # Validate the input file
    if not os.path.exists(args.table_list):
        logger.error(f"Error: table list file does not exist: {args.table_list}")
        sys.exit(1)

    try:
        # Create and run the workflow orchestrator
        orchestrator = SchemaWorkflowOrchestrator(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=args.business_context,
            output_dir=args.output_dir,
            task_id=script_task_id,
            enable_sql_validation=not args.skip_validation,
            enable_llm_repair=not args.disable_llm_repair,
            modify_original_file=not args.no_modify_file,
            enable_training_data_load=True,
            backup_vector_tables=args.backup_vector_tables,
            truncate_vector_tables=args.truncate_vector_tables,
            skip_training=args.skip_training
        )

        logger.info("🚀 Starting schema workflow orchestration...")
        logger.info(f"📁 Output directory: {orchestrator.output_dir}")
        logger.info(f"📋 Table list: {args.table_list}")
        logger.info(f"🏢 Business context: {args.business_context}")
        logger.info(f"💾 Database: {orchestrator.db_name}")
        logger.info(f"🔍 SQL validation: {'enabled' if not args.skip_validation else 'disabled'}")
        logger.info(f"🔧 LLM repair: {'enabled' if not args.disable_llm_repair else 'disabled'}")
        logger.info(f"🎯 Training data load: {'enabled' if not args.skip_training else 'disabled'}")

        # Run the complete workflow
        report = await orchestrator.execute_complete_workflow()

        # Print the detailed summary
        orchestrator.print_final_summary(report)

        # Report the outcome and set the exit code
        if report["success"]:
            if report["processing_results"].get("sql_validation", {}).get("success_rate", 1.0) >= 0.8:
                logger.info("\n🎉 Workflow completed successfully!")
                exit_code = 0  # full success
            else:
                logger.warning("\n⚠️ Workflow completed, but the SQL validation success rate is low")
                exit_code = 1  # partial success
        else:
            logger.error("\n❌ Workflow execution failed")
            exit_code = 2  # failure
        sys.exit(exit_code)
    except KeyboardInterrupt:
        logger.info("\n\n⏹️ Interrupted by user, exiting")
        sys.exit(130)
    except Exception as e:
        logger.error(f"\n❌ Program execution failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    asyncio.run(main())
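# Programmatic usage (a minimal sketch; the connection string, file path, and
# business context below are hypothetical placeholders):
#
#   import asyncio
#   from data_pipeline.schema_workflow import SchemaWorkflowOrchestrator
#
#   orchestrator = SchemaWorkflowOrchestrator(
#       db_connection="postgresql://user:pass@localhost:5432/highway_db",
#       table_list_file="tables.txt",
#       business_context="Highway service area management system",
#   )
#   report = asyncio.run(orchestrator.execute_complete_workflow())
#   orchestrator.print_final_summary(report)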