# training_data_agent.py
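"""
Schema training data generation agent.

Orchestrates the end-to-end flow: database permission checks, table-list
parsing, system-table filtering, concurrent per-table pipeline execution,
and a final summary report.
"""
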
import asyncio
import time
import logging
import os
from typing import List, Dict, Any, Optional
from pathlib import Path

from schema_tools.tools.base import ToolRegistry, PipelineExecutor
from schema_tools.utils.data_structures import TableMetadata, TableProcessingContext, ProcessingResult
from schema_tools.utils.file_manager import FileNameManager
from schema_tools.utils.system_filter import SystemTableFilter
from schema_tools.utils.permission_checker import DatabasePermissionChecker
from schema_tools.utils.table_parser import TableListParser
from schema_tools.utils.logger import setup_logging
from schema_tools.config import SCHEMA_TOOLS_CONFIG


class SchemaTrainingDataAgent:
    """AI agent that generates schema training data."""

    def __init__(self,
                 db_connection: str,
                 table_list_file: str,
                 business_context: Optional[str] = None,
                 output_dir: Optional[str] = None,
                 pipeline: str = "full"):
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context or "Database management system"
        self.pipeline = pipeline

        # Configuration
        self.config = SCHEMA_TOOLS_CONFIG
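        # Keys this class reads from SCHEMA_TOOLS_CONFIG (all of them appear
        # below): output_directory, create_subdirectories, available_pipelines,
        # check_permissions, require_select_permission,
        # allow_readonly_database, filter_system_tables, max_concurrent_tables.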
        self.output_dir = output_dir or self.config["output_directory"]

        # Initialize components
        self.file_manager = FileNameManager(self.output_dir)
        self.system_filter = SystemTableFilter()
        self.table_parser = TableListParser()
        self.pipeline_executor = PipelineExecutor(self.config["available_pipelines"])

        # Statistics
        self.stats = {
            'total_tables': 0,
            'processed_tables': 0,
            'failed_tables': 0,
            'skipped_tables': 0,
            'start_time': None,
            'end_time': None
        }
        self.failed_tables = []

        self.logger = logging.getLogger("schema_tools.Agent")

    async def generate_training_data(self) -> Dict[str, Any]:
        """Main entry point: generate the training data."""
        try:
            self.stats['start_time'] = time.time()
            self.logger.info("🚀 Starting schema training data generation")

            # 1. Initialize
            await self._initialize()

            # 2. Check database permissions
            await self._check_database_permissions()

            # 3. Parse the table list
            tables = await self._parse_table_list()

            # 4. Filter out system tables
            user_tables = self._filter_system_tables(tables)

            # 5. Process the tables concurrently
            results = await self._process_tables_concurrently(user_tables)

            # 6. Record the end time
            self.stats['end_time'] = time.time()

            # 7. Generate the summary report
            report = self._generate_summary_report(results)

            self.logger.info("✅ Schema training data generation completed")
            return report

        except Exception:
            self.stats['end_time'] = time.time()
            self.logger.exception("❌ Schema training data generation failed")
            raise

    async def _initialize(self):
        """Initialize the agent."""
        # Create the output directory tree
        os.makedirs(self.output_dir, exist_ok=True)
        if self.config["create_subdirectories"]:
            os.makedirs(os.path.join(self.output_dir, "ddl"), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, "docs"), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, "logs"), exist_ok=True)

        # Initialize the database inspector tool and its connection pool
        database_tool = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await database_tool._create_connection_pool()

        self.logger.info(f"Initialization complete, output directory: {self.output_dir}")
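
    # Output layout when create_subdirectories is enabled (the per-directory
    # contents are inferred, not stated in this module):
    #   <output_dir>/ddl/   presumably the generated DDL files
    #   <output_dir>/docs/  presumably the generated documentation files
    #   <output_dir>/logs/  presumably the processing logs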

    async def _check_database_permissions(self):
        """Check database permissions."""
        if not self.config["check_permissions"]:
            return

        inspector = ToolRegistry.get_tool("database_inspector")
        checker = DatabasePermissionChecker(inspector)
        permissions = await checker.check_permissions()

        if not permissions['connect']:
            raise Exception("Unable to connect to the database")

        if self.config["require_select_permission"] and not permissions['select_data']:
            if not self.config["allow_readonly_database"]:
                raise Exception("Insufficient database SELECT permission")
            else:
                self.logger.warning("Database is read-only or permissions are limited; some features may be affected")

        self.logger.info(f"Database permission check completed: {permissions}")
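
    # The permissions dict returned by DatabasePermissionChecker includes at
    # least the 'connect' and 'select_data' flags consumed above.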

    async def _parse_table_list(self) -> List[str]:
        """Parse the table list file."""
        tables = self.table_parser.parse_file(self.table_list_file)
        self.stats['total_tables'] = len(tables)
        self.logger.info(f"📋 Read {len(tables)} tables from the list file")
        return tables
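
    # NOTE: the exact list-file syntax is defined by TableListParser; judging
    # from the parsing in _process_single_table below, entries are either
    # schema-qualified ("sales.orders") or bare ("users", which defaults to
    # the "public" schema). A hypothetical example file:
    #
    #   public.users
    #   sales.orders
    #   inventory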

    def _filter_system_tables(self, tables: List[str]) -> List[str]:
        """Filter out system tables."""
        if not self.config["filter_system_tables"]:
            return tables

        user_tables = self.system_filter.filter_user_tables(tables)
        filtered_count = len(tables) - len(user_tables)

        if filtered_count > 0:
            self.logger.info(f"🔍 Filtered out {filtered_count} system tables, keeping {len(user_tables)} user tables")
            self.stats['skipped_tables'] += filtered_count

        return user_tables
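
    # Concurrency model: an asyncio.Semaphore caps how many tables are
    # processed at once, while gather(..., return_exceptions=True) lets every
    # task run to completion even when some of them raise.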

    async def _process_tables_concurrently(self, tables: List[str]) -> List[Dict[str, Any]]:
        """Process tables concurrently."""
        max_concurrent = self.config["max_concurrent_tables"]
        semaphore = asyncio.Semaphore(max_concurrent)

        self.logger.info(f"🔄 Processing {len(tables)} tables concurrently (max concurrency: {max_concurrent})")

        # Create one task per table
        tasks = [
            self._process_single_table_with_semaphore(semaphore, table_spec)
            for table_spec in tables
        ]

        # Run concurrently; exceptions are returned instead of raised
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Tally the results
        successful = sum(1 for r in results if isinstance(r, dict) and r.get('success', False))
        failed = len(results) - successful

        self.stats['processed_tables'] = successful
        self.stats['failed_tables'] = failed

        self.logger.info(f"📊 Processing finished: {successful} succeeded, {failed} failed")
        return [r for r in results if isinstance(r, dict)]

    async def _process_single_table_with_semaphore(self, semaphore: asyncio.Semaphore, table_spec: str) -> Dict[str, Any]:
        """Process a single table while holding the concurrency semaphore."""
        async with semaphore:
            return await self._process_single_table(table_spec)
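
    # Each per-table result is a plain dict so the summary report can
    # aggregate and serialize it directly:
    #   success: bool, table_name: str, execution_time: float,
    #   step_results: {step_name: result.to_dict()},
    #   plus 'metadata' on success or 'error_message' on failure.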

    async def _process_single_table(self, table_spec: str) -> Dict[str, Any]:
        """Process a single table."""
        start_time = time.time()
        try:
            # Parse the table name; an unqualified name defaults to 'public'
            if '.' in table_spec:
                schema_name, table_name = table_spec.split('.', 1)
            else:
                schema_name, table_name = 'public', table_spec
            full_name = f"{schema_name}.{table_name}"

            self.logger.info(f"🔍 Processing table: {full_name}")

            # Build the table metadata
            table_metadata = TableMetadata(
                schema_name=schema_name,
                table_name=table_name,
                full_name=full_name
            )

            # Build the processing context
            context = TableProcessingContext(
                table_metadata=table_metadata,
                business_context=self.business_context,
                output_dir=self.output_dir,
                pipeline=self.pipeline,
                vn=None,  # injected inside the tools
                file_manager=self.file_manager,
                start_time=start_time
            )

            # Execute the processing pipeline
            step_results = await self.pipeline_executor.execute_pipeline(self.pipeline, context)

            # The table succeeds only if every pipeline step succeeded
            success = all(result.success for result in step_results.values())
            execution_time = time.time() - start_time

            if success:
                self.logger.info(f"✅ Table {full_name} processed successfully in {execution_time:.2f}s")
            else:
                self.logger.error(f"❌ Table {full_name} failed after {execution_time:.2f}s")
                self.failed_tables.append(full_name)

            return {
                'success': success,
                'table_name': full_name,
                'execution_time': execution_time,
                'step_results': {k: v.to_dict() for k, v in step_results.items()},
                'metadata': {
                    'fields_count': len(table_metadata.fields),
                    'row_count': table_metadata.row_count,
                    'enum_fields': len([f for f in table_metadata.fields if f.is_enum])
                }
            }

        except Exception as e:
            execution_time = time.time() - start_time
            error_msg = f"Table {table_spec} processing raised an exception: {e}"
            self.logger.exception(error_msg)
            self.failed_tables.append(table_spec)

            return {
                'success': False,
                'table_name': table_spec,
                'execution_time': execution_time,
                'error_message': error_msg,
                'step_results': {}
            }

    def _generate_summary_report(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate the summary report."""
        total_time = self.stats['end_time'] - self.stats['start_time']

        # Compute the statistics
        successful_results = [r for r in results if r.get('success', False)]
        failed_results = [r for r in results if not r.get('success', False)]

        total_fields = sum(r.get('metadata', {}).get('fields_count', 0) for r in successful_results)
        total_enum_fields = sum(r.get('metadata', {}).get('enum_fields', 0) for r in successful_results)
        avg_execution_time = sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0

        report = {
            'summary': {
                'total_tables': self.stats['total_tables'],
                'processed_successfully': len(successful_results),
                'failed': len(failed_results),
                'skipped_system_tables': self.stats['skipped_tables'],
                'total_execution_time': total_time,
                'average_table_time': avg_execution_time
            },
            'statistics': {
                'total_fields_processed': total_fields,
                'enum_fields_detected': total_enum_fields,
                # the 'full' pipeline is counted as two files per table,
                # any other pipeline as one
                'files_generated': len(successful_results) * (2 if self.pipeline == 'full' else 1)
            },
            'failed_tables': self.failed_tables,
            'detailed_results': results,
            'configuration': {
                'pipeline': self.pipeline,
                'business_context': self.business_context,
                'output_directory': self.output_dir,
                'max_concurrent_tables': self.config['max_concurrent_tables']
            }
        }

        # Log the summary
        self.logger.info("📊 Processing summary:")
        self.logger.info(f"  ✅ Succeeded: {report['summary']['processed_successfully']} tables")
        self.logger.info(f"  ❌ Failed: {report['summary']['failed']} tables")
        self.logger.info(f"  ⏭️ Skipped: {report['summary']['skipped_system_tables']} system tables")
        self.logger.info(f"  📁 Files generated: {report['statistics']['files_generated']}")
        self.logger.info(f"  🕐 Total time: {total_time:.2f}s")

        if self.failed_tables:
            self.logger.warning(f"❌ Failed tables: {', '.join(self.failed_tables)}")

        # Write the filename mapping report
        self.file_manager.write_mapping_report()

        return report
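
    # Public variant of _check_database_permissions: it builds its own
    # inspector and connection pool so callers can run a standalone
    # permission check before invoking generate_training_data().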

    async def check_database_permissions(self) -> Dict[str, bool]:
        """Check database permissions (public API for external callers)."""
        inspector = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await inspector._create_connection_pool()
        checker = DatabasePermissionChecker(inspector)
        return await checker.check_permissions()
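

# Minimal usage sketch. The connection string and file paths below are
# hypothetical placeholders, not values shipped with this module.
if __name__ == "__main__":
    async def _demo():
        agent = SchemaTrainingDataAgent(
            db_connection="postgresql://user:password@localhost:5432/mydb",  # placeholder DSN
            table_list_file="tables.txt",             # placeholder table list
            business_context="E-commerce platform",   # placeholder context
            pipeline="full",
        )
        report = await agent.generate_training_data()
        print(f"Succeeded: {report['summary']['processed_successfully']}, "
              f"failed: {report['summary']['failed']}")

    asyncio.run(_demo())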