training_data_agent.py

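"""Schema training-data generation agent.

Parses a table list file, checks database permissions, filters out system
tables, processes the remaining tables concurrently through the configured
tool pipeline, and writes a summary report.
"""
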
import asyncio
import time
import os
from typing import List, Dict, Any, Optional
from pathlib import Path

from data_pipeline.tools.base import ToolRegistry, PipelineExecutor
from data_pipeline.utils.data_structures import TableMetadata, TableProcessingContext, ProcessingResult
from data_pipeline.utils.file_manager import FileNameManager
from data_pipeline.utils.system_filter import SystemTableFilter
from data_pipeline.utils.permission_checker import DatabasePermissionChecker
from data_pipeline.utils.table_parser import TableListParser
from data_pipeline.config import SCHEMA_TOOLS_CONFIG
from data_pipeline.dp_logging import get_logger


class SchemaTrainingDataAgent:
    """AI agent for generating schema training data."""

    def __init__(self,
                 db_connection: str,
                 table_list_file: str,
                 business_context: Optional[str] = None,
                 output_dir: Optional[str] = None,
                 task_id: Optional[str] = None,
                 pipeline: str = "full"):
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context or "Database management system"
        self.pipeline = pipeline

        # Configuration
        self.config = SCHEMA_TOOLS_CONFIG
        self.output_dir = output_dir or self.config["output_directory"]

        # Components
        self.file_manager = FileNameManager(self.output_dir)
        self.system_filter = SystemTableFilter()
        self.table_parser = TableListParser()
        self.pipeline_executor = PipelineExecutor(self.config["available_pipelines"])

        # Statistics
        self.stats = {
            'total_tables': 0,
            'processed_tables': 0,
            'failed_tables': 0,
            'skipped_tables': 0,
            'start_time': None,
            'end_time': None
        }
        self.failed_tables = []
        self.task_id = task_id

        # Set up the task-scoped logger
        if task_id:
            self.logger = get_logger("SchemaTrainingDataAgent", task_id)
        else:
            # Script mode: generate a task_id when none was provided
            from datetime import datetime
            self.task_id = f"manual_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self.logger = get_logger("SchemaTrainingDataAgent", self.task_id)

    async def generate_training_data(self) -> Dict[str, Any]:
        """Main entry point: generate the training data."""
        try:
            self.stats['start_time'] = time.time()
            self.logger.info("🚀 Starting schema training data generation")

            # 1. Initialize
            await self._initialize()

            # 2. Check database permissions
            await self._check_database_permissions()

            # 3. Parse the table list
            tables = await self._parse_table_list()

            # 4. Filter out system tables
            user_tables = self._filter_system_tables(tables)

            # 5. Process the tables concurrently
            results = await self._process_tables_concurrently(user_tables)

            # 6. Record the end time
            self.stats['end_time'] = time.time()

            # 7. Generate the summary report
            report = self._generate_summary_report(results)
            self.logger.info("✅ Schema training data generation finished")
            return report
        except Exception:
            self.stats['end_time'] = time.time()
            self.logger.exception("❌ Schema training data generation failed")
            raise

    async def _initialize(self):
        """Initialize the agent."""
        # Create the output directories
        os.makedirs(self.output_dir, exist_ok=True)
        if self.config["create_subdirectories"]:
            os.makedirs(os.path.join(self.output_dir, "ddl"), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, "docs"), exist_ok=True)
        # The logs directory is always created
        # os.makedirs(os.path.join(self.output_dir, "logs"), exist_ok=True)

        # Initialize the database inspector tool
        database_tool = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await database_tool._create_connection_pool()
        self.logger.info(f"Initialization complete; output directory: {self.output_dir}")

    async def _check_database_permissions(self):
        """Check database permissions."""
        if not self.config["check_permissions"]:
            return

        inspector = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        # Make sure the connection pool exists
        if not inspector.connection_pool:
            await inspector._create_connection_pool()

        checker = DatabasePermissionChecker(inspector)
        permissions = await checker.check_permissions()

        if not permissions['connect']:
            raise Exception("Unable to connect to the database")
        if self.config["require_select_permission"] and not permissions['select_data']:
            if not self.config["allow_readonly_database"]:
                raise Exception("Insufficient database query permissions")
            else:
                self.logger.warning("Database is read-only or permissions are limited; some features may be affected")

        self.logger.info(f"Database permission check complete: {permissions}")

    async def _parse_table_list(self) -> List[str]:
        """Parse the table list file."""
        tables = self.table_parser.parse_file(self.table_list_file)
        self.stats['total_tables'] = len(tables)
        self.logger.info(f"📋 Read {len(tables)} tables from the list file")
        return tables

    def _filter_system_tables(self, tables: List[str]) -> List[str]:
        """Filter out system tables."""
        if not self.config["filter_system_tables"]:
            return tables

        user_tables = self.system_filter.filter_user_tables(tables)
        filtered_count = len(tables) - len(user_tables)
        if filtered_count > 0:
            self.logger.info(f"🔍 Filtered out {filtered_count} system tables, kept {len(user_tables)} user tables")
            self.stats['skipped_tables'] += filtered_count
        return user_tables

    async def _process_tables_concurrently(self, tables: List[str]) -> List[Dict[str, Any]]:
        """Process tables concurrently."""
        max_concurrent = self.config["max_concurrent_tables"]
        semaphore = asyncio.Semaphore(max_concurrent)
        self.logger.info(f"🔄 Processing {len(tables)} tables concurrently (max concurrency: {max_concurrent})")

        # Create one task per table
        tasks = [
            self._process_single_table_with_semaphore(semaphore, table_spec)
            for table_spec in tables
        ]

        # Run them concurrently
        results = await asyncio.gather(*tasks, return_exceptions=True)
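        # NOTE: with return_exceptions=True, an exception raised by one task is
        # returned as a value instead of cancelling the whole gather, so one
        # failed table cannot abort the others. Such exception objects are not
        # dicts and are dropped by the isinstance() filter below.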

        # Tally the results
        successful = sum(1 for r in results if isinstance(r, dict) and r.get('success', False))
        failed = len(results) - successful
        self.stats['processed_tables'] = successful
        self.stats['failed_tables'] = failed
        self.logger.info(f"📊 Processing complete: {successful} succeeded, {failed} failed")
        return [r for r in results if isinstance(r, dict)]

    async def _process_single_table_with_semaphore(self, semaphore: asyncio.Semaphore, table_spec: str) -> Dict[str, Any]:
        """Process a single table while holding the semaphore."""
        async with semaphore:
            return await self._process_single_table(table_spec)

    async def _process_single_table(self, table_spec: str) -> Dict[str, Any]:
        """Process a single table."""
        start_time = time.time()
        try:
            # Parse the table name; a bare name defaults to the public schema
            if '.' in table_spec:
                schema_name, table_name = table_spec.split('.', 1)
            else:
                schema_name, table_name = 'public', table_spec
            full_name = f"{schema_name}.{table_name}"
            self.logger.info(f"🔍 Processing table: {full_name}")

            # Build the table metadata
            table_metadata = TableMetadata(
                schema_name=schema_name,
                table_name=table_name,
                full_name=full_name
            )

            # Build the processing context
            context = TableProcessingContext(
                table_metadata=table_metadata,
                business_context=self.business_context,
                output_dir=self.output_dir,
                pipeline=self.pipeline,
                vn=None,  # injected later by the tools
                file_manager=self.file_manager,
                db_connection=self.db_connection,  # pass the database connection along
                start_time=start_time
            )

            # Run the processing pipeline
            step_results = await self.pipeline_executor.execute_pipeline(self.pipeline, context)

            # Overall success requires every step to succeed
            success = all(result.success for result in step_results.values())
            execution_time = time.time() - start_time

            if success:
                self.logger.info(f"✅ Table {full_name} processed successfully in {execution_time:.2f}s")
            else:
                self.logger.error(f"❌ Table {full_name} failed after {execution_time:.2f}s")
                self.failed_tables.append(full_name)

            return {
                'success': success,
                'table_name': full_name,
                'execution_time': execution_time,
                'step_results': {k: v.to_dict() for k, v in step_results.items()},
                'metadata': {
                    'fields_count': len(table_metadata.fields),
                    'row_count': table_metadata.row_count,
                    'enum_fields': len([f for f in table_metadata.fields if f.is_enum])
                }
            }
        except Exception as e:
            execution_time = time.time() - start_time
            error_msg = f"Error while processing table {table_spec}: {str(e)}"
            self.logger.exception(error_msg)
            self.failed_tables.append(table_spec)
            return {
                'success': False,
                'table_name': table_spec,
                'execution_time': execution_time,
                'error_message': error_msg,
                'step_results': {}
            }

    def _generate_summary_report(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate the summary report."""
        total_time = self.stats['end_time'] - self.stats['start_time']

        # Compute the statistics
        successful_results = [r for r in results if r.get('success', False)]
        failed_results = [r for r in results if not r.get('success', False)]
        total_fields = sum(r.get('metadata', {}).get('fields_count', 0) for r in successful_results)
        total_enum_fields = sum(r.get('metadata', {}).get('enum_fields', 0) for r in successful_results)
        avg_execution_time = sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0

        report = {
            'summary': {
                'total_tables': self.stats['total_tables'],
                'processed_successfully': len(successful_results),
                'failed': len(failed_results),
                'skipped_system_tables': self.stats['skipped_tables'],
                'total_execution_time': total_time,
                'average_table_time': avg_execution_time
            },
            'statistics': {
                'total_fields_processed': total_fields,
                'enum_fields_detected': total_enum_fields,
                # the "full" pipeline emits two files (ddl + docs) per table, others one
                'files_generated': len(successful_results) * (2 if self.pipeline == 'full' else 1)
            },
            'failed_tables': self.failed_tables,
            'detailed_results': results,
            'configuration': {
                'pipeline': self.pipeline,
                'business_context': self.business_context,
                'output_directory': self.output_dir,
                'max_concurrent_tables': self.config['max_concurrent_tables']
            }
        }

        # Log the summary
        self.logger.info("📊 Processing summary:")
        self.logger.info(f"  ✅ Succeeded: {report['summary']['processed_successfully']} tables")
        self.logger.info(f"  ❌ Failed: {report['summary']['failed']} tables")
        self.logger.info(f"  ⏭️ Skipped: {report['summary']['skipped_system_tables']} system tables")
        self.logger.info(f"  📁 Files generated: {report['statistics']['files_generated']}")
        self.logger.info(f"  🕐 Total time: {total_time:.2f}s")

        if self.failed_tables:
            self.logger.warning(f"❌ Failed tables: {', '.join(self.failed_tables)}")

        # Write the file-name mapping report
        self.file_manager.write_mapping_report()
        return report

    async def check_database_permissions(self) -> Dict[str, bool]:
        """Check database permissions (for external callers)."""
        inspector = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await inspector._create_connection_pool()
        checker = DatabasePermissionChecker(inspector)
        return await checker.check_permissions()
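

# --- Usage sketch ---
# A minimal way to drive the agent from a script. The DSN, table list path,
# and business context below are illustrative placeholders, not values that
# ship with this module; the table list is assumed to contain one table name
# per line ("schema.table" or a bare name, which defaults to "public").
if __name__ == "__main__":
    async def _demo():
        agent = SchemaTrainingDataAgent(
            db_connection="postgresql://user:pass@localhost:5432/mydb",  # placeholder DSN
            table_list_file="tables.txt",  # placeholder path
            business_context="E-commerce order management",  # illustrative
            pipeline="full",
        )
        report = await agent.generate_training_data()
        print(f"Processed {report['summary']['processed_successfully']} tables "
              f"in {report['summary']['total_execution_time']:.2f}s")

    asyncio.run(_demo())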