training_data_agent.py

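"""Schema training data generation agent.

Orchestrates the pipeline end to end: output directory setup, database
permission checks, table list parsing, system table filtering, bounded
concurrent per-table processing, and a final summary report.
"""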
import asyncio
import time
import logging
import os
from typing import List, Dict, Any, Optional
from pathlib import Path

from schema_tools.tools.base import ToolRegistry, PipelineExecutor
from schema_tools.utils.data_structures import TableMetadata, TableProcessingContext, ProcessingResult
from schema_tools.utils.file_manager import FileNameManager
from schema_tools.utils.system_filter import SystemTableFilter
from schema_tools.utils.permission_checker import DatabasePermissionChecker
from schema_tools.utils.table_parser import TableListParser
from schema_tools.utils.logger import setup_logging
from schema_tools.config import SCHEMA_TOOLS_CONFIG


class SchemaTrainingDataAgent:
    """AI agent that generates schema training data."""

    def __init__(self,
                 db_connection: str,
                 table_list_file: str,
                 business_context: Optional[str] = None,
                 output_dir: Optional[str] = None,
                 pipeline: str = "full"):
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context or "Database management system"
        self.pipeline = pipeline

        # Configuration
        self.config = SCHEMA_TOOLS_CONFIG
        self.output_dir = output_dir or self.config["output_directory"]

        # Initialize components
        self.file_manager = FileNameManager(self.output_dir)
        self.system_filter = SystemTableFilter()
        self.table_parser = TableListParser()
        self.pipeline_executor = PipelineExecutor(self.config["available_pipelines"])

        # Statistics
        self.stats = {
            'total_tables': 0,
            'processed_tables': 0,
            'failed_tables': 0,
            'skipped_tables': 0,
            'start_time': None,
            'end_time': None
        }
        self.failed_tables = []
        self.logger = logging.getLogger("schema_tools.Agent")

    async def generate_training_data(self) -> Dict[str, Any]:
        """Main entry point: generate training data."""
        try:
            self.stats['start_time'] = time.time()
            self.logger.info("🚀 Starting schema training data generation")

            # 1. Initialize
            await self._initialize()

            # 2. Check database permissions
            await self._check_database_permissions()

            # 3. Parse the table list
            tables = await self._parse_table_list()

            # 4. Filter out system tables
            user_tables = self._filter_system_tables(tables)

            # 5. Process tables concurrently
            results = await self._process_tables_concurrently(user_tables)

            # 6. Record the end time
            self.stats['end_time'] = time.time()

            # 7. Generate the summary report
            report = self._generate_summary_report(results)

            self.logger.info("✅ Schema training data generation complete")
            return report

        except Exception:
            self.stats['end_time'] = time.time()
            self.logger.exception("❌ Schema training data generation failed")
            raise

    async def _initialize(self):
        """Initialize the agent."""
        # Create output directories
        os.makedirs(self.output_dir, exist_ok=True)
        if self.config["create_subdirectories"]:
            os.makedirs(os.path.join(self.output_dir, "ddl"), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, "docs"), exist_ok=True)
        # The logs directory is always created
        os.makedirs(os.path.join(self.output_dir, "logs"), exist_ok=True)

        # Initialize the database inspector tool
        database_tool = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await database_tool._create_connection_pool()

        self.logger.info(f"Initialization complete, output directory: {self.output_dir}")

    async def _check_database_permissions(self):
        """Check database permissions."""
        if not self.config["check_permissions"]:
            return

        inspector = ToolRegistry.get_tool("database_inspector")
        checker = DatabasePermissionChecker(inspector)
        permissions = await checker.check_permissions()

        if not permissions['connect']:
            raise Exception("Cannot connect to the database")

        if self.config["require_select_permission"] and not permissions['select_data']:
            if not self.config["allow_readonly_database"]:
                raise Exception("Insufficient database SELECT permission")
            else:
                self.logger.warning("Database is read-only or has restricted permissions; some features may be affected")

        self.logger.info(f"Database permission check complete: {permissions}")

    async def _parse_table_list(self) -> List[str]:
        """Parse the table list file."""
        tables = self.table_parser.parse_file(self.table_list_file)
        self.stats['total_tables'] = len(tables)
        self.logger.info(f"📋 Read {len(tables)} tables from the list file")
        return tables

    def _filter_system_tables(self, tables: List[str]) -> List[str]:
        """Filter out system tables."""
        if not self.config["filter_system_tables"]:
            return tables

        user_tables = self.system_filter.filter_user_tables(tables)
        filtered_count = len(tables) - len(user_tables)

        if filtered_count > 0:
            self.logger.info(f"🔍 Filtered out {filtered_count} system tables, kept {len(user_tables)} user tables")
            self.stats['skipped_tables'] += filtered_count

        return user_tables

    async def _process_tables_concurrently(self, tables: List[str]) -> List[Dict[str, Any]]:
        """Process tables concurrently."""
        max_concurrent = self.config["max_concurrent_tables"]
        semaphore = asyncio.Semaphore(max_concurrent)

        self.logger.info(f"🔄 Processing {len(tables)} tables concurrently (max concurrency: {max_concurrent})")

        # Create one task per table, each gated by the semaphore
        tasks = [
            self._process_single_table_with_semaphore(semaphore, table_spec)
            for table_spec in tables
        ]

        # Run concurrently; return_exceptions=True keeps one failure from cancelling the rest
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Tally results
        successful = sum(1 for r in results if isinstance(r, dict) and r.get('success', False))
        failed = len(results) - successful

        self.stats['processed_tables'] = successful
        self.stats['failed_tables'] = failed

        self.logger.info(f"📊 Processing complete: {successful} succeeded, {failed} failed")
        return [r for r in results if isinstance(r, dict)]

    async def _process_single_table_with_semaphore(self, semaphore: asyncio.Semaphore, table_spec: str) -> Dict[str, Any]:
        """Process a single table, bounded by the semaphore."""
        async with semaphore:
            return await self._process_single_table(table_spec)

    async def _process_single_table(self, table_spec: str) -> Dict[str, Any]:
        """Process a single table."""
        start_time = time.time()

        try:
            # Parse the table name; a bare name defaults to the public schema
            if '.' in table_spec:
                schema_name, table_name = table_spec.split('.', 1)
            else:
                schema_name, table_name = 'public', table_spec

            full_name = f"{schema_name}.{table_name}"
            self.logger.info(f"🔍 Processing table: {full_name}")

            # Create table metadata
            table_metadata = TableMetadata(
                schema_name=schema_name,
                table_name=table_name,
                full_name=full_name
            )

            # Create the processing context
            context = TableProcessingContext(
                table_metadata=table_metadata,
                business_context=self.business_context,
                output_dir=self.output_dir,
                pipeline=self.pipeline,
                vn=None,  # injected later by the tools
                file_manager=self.file_manager,
                start_time=start_time
            )

            # Run the processing pipeline
            step_results = await self.pipeline_executor.execute_pipeline(self.pipeline, context)

            # The table succeeds only if every pipeline step succeeded
            success = all(result.success for result in step_results.values())
            execution_time = time.time() - start_time

            if success:
                self.logger.info(f"✅ Table {full_name} processed successfully in {execution_time:.2f}s")
            else:
                self.logger.error(f"❌ Table {full_name} failed after {execution_time:.2f}s")
                self.failed_tables.append(full_name)

            return {
                'success': success,
                'table_name': full_name,
                'execution_time': execution_time,
                'step_results': {k: v.to_dict() for k, v in step_results.items()},
                'metadata': {
                    'fields_count': len(table_metadata.fields),
                    'row_count': table_metadata.row_count,
                    'enum_fields': len([f for f in table_metadata.fields if f.is_enum])
                }
            }

        except Exception as e:
            execution_time = time.time() - start_time
            error_msg = f"Exception while processing table {table_spec}: {e}"
            self.logger.exception(error_msg)
            self.failed_tables.append(table_spec)

            return {
                'success': False,
                'table_name': table_spec,
                'execution_time': execution_time,
                'error_message': error_msg,
                'step_results': {}
            }

    def _generate_summary_report(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate the summary report."""
        total_time = self.stats['end_time'] - self.stats['start_time']

        # Compute statistics
        successful_results = [r for r in results if r.get('success', False)]
        failed_results = [r for r in results if not r.get('success', False)]

        total_fields = sum(r.get('metadata', {}).get('fields_count', 0) for r in successful_results)
        total_enum_fields = sum(r.get('metadata', {}).get('enum_fields', 0) for r in successful_results)
        avg_execution_time = sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0

        report = {
            'summary': {
                'total_tables': self.stats['total_tables'],
                'processed_successfully': len(successful_results),
                'failed': len(failed_results),
                'skipped_system_tables': self.stats['skipped_tables'],
                'total_execution_time': total_time,
                'average_table_time': avg_execution_time
            },
            'statistics': {
                'total_fields_processed': total_fields,
                'enum_fields_detected': total_enum_fields,
                # The full pipeline emits two files per table (DDL + docs), others one
                'files_generated': len(successful_results) * (2 if self.pipeline == 'full' else 1)
            },
            'failed_tables': self.failed_tables,
            'detailed_results': results,
            'configuration': {
                'pipeline': self.pipeline,
                'business_context': self.business_context,
                'output_directory': self.output_dir,
                'max_concurrent_tables': self.config['max_concurrent_tables']
            }
        }

        # Log the summary
        self.logger.info("📊 Processing summary:")
        self.logger.info(f"  ✅ Succeeded: {report['summary']['processed_successfully']} tables")
        self.logger.info(f"  ❌ Failed: {report['summary']['failed']} tables")
        self.logger.info(f"  ⏭️ Skipped: {report['summary']['skipped_system_tables']} system tables")
        self.logger.info(f"  📁 Files generated: {report['statistics']['files_generated']}")
        self.logger.info(f"  🕐 Total time: {total_time:.2f}s")

        if self.failed_tables:
            self.logger.warning(f"❌ Failed tables: {', '.join(self.failed_tables)}")

        # Write the filename mapping report
        self.file_manager.write_mapping_report()

        return report

    async def check_database_permissions(self) -> Dict[str, bool]:
        """Check database permissions (public API for external callers)."""
        inspector = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await inspector._create_connection_pool()
        checker = DatabasePermissionChecker(inspector)
        return await checker.check_permissions()
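

# A minimal usage sketch, not part of the original module: the DSN, table list
# path, and business context below are hypothetical placeholders. It assumes
# SCHEMA_TOOLS_CONFIG and the registered tools are configured for your database;
# generate_training_data() is the real entry point shown above.
if __name__ == "__main__":
    async def _demo():
        agent = SchemaTrainingDataAgent(
            db_connection="postgresql://user:pass@localhost:5432/mydb",  # hypothetical DSN
            table_list_file="tables.txt",  # hypothetical: e.g. one "schema.table" per line
            business_context="E-commerce order management",  # hypothetical
            pipeline="full",
        )
        report = await agent.generate_training_data()
        print(f"Succeeded: {report['summary']['processed_successfully']}, "
              f"failed: {report['summary']['failed']}")

    asyncio.run(_demo())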