Schema Tools 详细设计文档.md 95 KB

Schema Tools 详细设计文档

1. 项目结构与模块设计

1.1 完整目录结构

schema_tools/
├── __init__.py                     # 模块入口,导出主要接口
├── __main__.py                     # 命令行入口
├── config.py                       # 配置管理
├── training_data_agent.py          # 主AI Agent
├── qs_agent.py                     # Question-SQL生成Agent (新增)
├── qs_generator.py                 # Question-SQL命令行入口 (新增)
├── sql_validation_agent.py         # SQL验证Agent (新增)
├── sql_validator.py                # SQL验证命令行入口 (新增)
├── schema_workflow_orchestrator.py # 端到端工作流编排器 (新增)
├── tools/                          # Agent工具集
│   ├── __init__.py                 # 工具模块初始化
│   ├── base.py                     # 基础工具类和注册机制
│   ├── database_inspector.py       # 数据库元数据检查工具
│   ├── data_sampler.py             # 数据采样工具
│   ├── comment_generator.py        # LLM注释生成工具
│   ├── ddl_generator.py            # DDL格式生成工具
│   └── doc_generator.py            # MD文档生成工具
├── validators/                     # 验证器模块 (新增)
│   ├── __init__.py
│   ├── file_count_validator.py     # 文件数量验证器
│   └── sql_validator.py            # SQL验证器核心模块
├── analyzers/                      # 分析器模块 (新增)
│   ├── __init__.py
│   ├── md_analyzer.py              # MD文件分析器
│   └── theme_extractor.py          # 主题提取器
├── utils/                          # 工具函数
│   ├── __init__.py
│   ├── data_structures.py          # 数据结构定义
│   ├── table_parser.py             # 表清单解析器
│   ├── file_manager.py             # 文件管理器
│   ├── system_filter.py            # 系统表过滤器
│   ├── permission_checker.py       # 权限检查器
│   ├── large_table_handler.py      # 大表处理器
│   └── logger.py                   # 日志管理
├── prompts/                        # 提示词模板
│   ├── table_comment_template.txt
│   ├── field_comment_template.txt
│   ├── enum_detection_template.txt
│   ├── business_context.txt
│   └── business_dictionary.txt
└── tests/                          # 测试用例
    ├── unit/
    ├── integration/
    └── fixtures/

2. 核心数据结构设计

2.1 数据结构定义 (utils/data_structures.py)

from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any, Union
from enum import Enum
import hashlib
import json

class FieldType(Enum):
    """字段类型枚举"""
    INTEGER = "integer"
    VARCHAR = "varchar"
    TEXT = "text"
    TIMESTAMP = "timestamp"
    DATE = "date"
    BOOLEAN = "boolean"
    NUMERIC = "numeric"
    ENUM = "enum"
    JSON = "json"
    UUID = "uuid"
    OTHER = "other"

class ProcessingStatus(Enum):
    """处理状态枚举"""
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"

@dataclass
class FieldInfo:
    """字段信息标准结构"""
    name: str
    type: str
    nullable: bool
    default_value: Optional[str] = None
    comment: Optional[str] = None
    original_comment: Optional[str] = None  # 原始注释
    generated_comment: Optional[str] = None  # LLM生成的注释
    is_primary_key: bool = False
    is_foreign_key: bool = False
    is_enum: bool = False
    enum_values: Optional[List[str]] = None
    enum_description: Optional[str] = None
    max_length: Optional[int] = None
    precision: Optional[int] = None
    scale: Optional[int] = None
    
    def to_dict(self) -> Dict[str, Any]:
        """转换为字典格式"""
        return {
            'name': self.name,
            'type': self.type,
            'nullable': self.nullable,
            'default_value': self.default_value,
            'comment': self.comment,
            'is_primary_key': self.is_primary_key,
            'is_foreign_key': self.is_foreign_key,
            'is_enum': self.is_enum,
            'enum_values': self.enum_values
        }

@dataclass
class TableMetadata:
    """表元数据标准结构"""
    schema_name: str
    table_name: str
    full_name: str  # schema.table_name
    comment: Optional[str] = None
    original_comment: Optional[str] = None  # 原始注释
    generated_comment: Optional[str] = None  # LLM生成的注释
    fields: List[FieldInfo] = field(default_factory=list)
    sample_data: List[Dict[str, Any]] = field(default_factory=list)
    row_count: Optional[int] = None
    table_size: Optional[str] = None  # 表大小(如 "1.2 MB")
    created_date: Optional[str] = None
    
    @property
    def safe_file_name(self) -> str:
        """生成安全的文件名"""
        if self.schema_name.lower() == 'public':
            return self.table_name
        return f"{self.schema_name}__{self.table_name}".replace('.', '__').replace('-', '_').replace(' ', '_')
    
    def get_metadata_hash(self) -> str:
        """计算元数据哈希值,用于增量更新判断"""
        hash_data = {
            'schema_name': self.schema_name,
            'table_name': self.table_name,
            'fields': [f.to_dict() for f in self.fields],
            'comment': self.original_comment
        }
        return hashlib.md5(json.dumps(hash_data, sort_keys=True).encode()).hexdigest()

@dataclass
class ProcessingResult:
    """工具处理结果标准结构"""
    success: bool
    data: Optional[Any] = None
    error_message: Optional[str] = None
    warnings: List[str] = field(default_factory=list)
    execution_time: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    
    def add_warning(self, warning: str):
        """添加警告信息"""
        self.warnings.append(warning)
    
    def to_dict(self) -> Dict[str, Any]:
        """转换为字典格式"""
        return {
            'success': self.success,
            'data': self.data,
            'error_message': self.error_message,
            'warnings': self.warnings,
            'execution_time': self.execution_time,
            'metadata': self.metadata
        }

@dataclass
class TableProcessingContext:
    """表处理上下文"""
    table_metadata: TableMetadata
    business_context: str
    output_dir: str
    pipeline: str
    vn: Any  # vanna实例
    file_manager: Any
    current_step: str = "initialized"
    step_results: Dict[str, ProcessingResult] = field(default_factory=dict)
    start_time: Optional[float] = None
    
    def update_step(self, step_name: str, result: ProcessingResult):
        """更新步骤结果"""
        self.current_step = step_name
        self.step_results[step_name] = result

3. 工具注册与管理系统

3.1 基础工具类 (tools/base.py)

import asyncio
import time
import logging
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Type
from utils.data_structures import ProcessingResult, TableProcessingContext

class ToolRegistry:
    """工具注册管理器"""
    _tools: Dict[str, Type['BaseTool']] = {}
    _instances: Dict[str, 'BaseTool'] = {}
    
    @classmethod
    def register(cls, name: str):
        """装饰器:注册工具"""
        def decorator(tool_class: Type['BaseTool']):
            cls._tools[name] = tool_class
            logging.debug(f"注册工具: {name} -> {tool_class.__name__}")
            return tool_class
        return decorator
    
    @classmethod
    def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
        """获取工具实例,支持单例模式"""
        if name not in cls._instances:
            if name not in cls._tools:
                raise ValueError(f"工具 '{name}' 未注册")
            
            tool_class = cls._tools[name]
            
            # 自动注入vanna实例到需要LLM的工具
            if hasattr(tool_class, 'needs_llm') and tool_class.needs_llm:
                from core.vanna_llm_factory import create_vanna_instance
                kwargs['vn'] = create_vanna_instance()
                logging.debug(f"为工具 {name} 注入LLM实例")
            
            cls._instances[name] = tool_class(**kwargs)
        
        return cls._instances[name]
    
    @classmethod
    def list_tools(cls) -> List[str]:
        """列出所有已注册的工具"""
        return list(cls._tools.keys())
    
    @classmethod
    def clear_instances(cls):
        """清除所有工具实例(用于测试)"""
        cls._instances.clear()

class BaseTool(ABC):
    """工具基类"""
    
    needs_llm: bool = False  # 是否需要LLM实例
    tool_name: str = ""      # 工具名称
    
    def __init__(self, **kwargs):
        self.logger = logging.getLogger(f"schema_tools.{self.__class__.__name__}")
        
        # 如果工具需要LLM,检查是否已注入
        if self.needs_llm and 'vn' not in kwargs:
            raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")
        
        # 存储vanna实例
        if 'vn' in kwargs:
            self.vn = kwargs['vn']
    
    @abstractmethod
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """
        执行工具逻辑
        Args:
            context: 表处理上下文
        Returns:
            ProcessingResult: 处理结果
        """
        pass
    
    async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
        """带计时的执行包装器"""
        start_time = time.time()
        
        try:
            self.logger.info(f"开始执行工具: {self.tool_name}")
            result = await self.execute(context)
            execution_time = time.time() - start_time
            result.execution_time = execution_time
            
            if result.success:
                self.logger.info(f"工具 {self.tool_name} 执行成功,耗时: {execution_time:.2f}秒")
            else:
                self.logger.error(f"工具 {self.tool_name} 执行失败: {result.error_message}")
            
            return result
            
        except Exception as e:
            execution_time = time.time() - start_time
            self.logger.exception(f"工具 {self.tool_name} 执行异常")
            
            return ProcessingResult(
                success=False,
                error_message=f"工具执行异常: {str(e)}",
                execution_time=execution_time
            )
    
    def validate_input(self, context: TableProcessingContext) -> bool:
        """输入验证(子类可重写)"""
        return context.table_metadata is not None

3.2 Pipeline执行器 (training_data_agent.py 的一部分)

class PipelineExecutor:
    """处理链执行器"""
    
    def __init__(self, pipeline_config: Dict[str, List[str]]):
        self.pipeline_config = pipeline_config
        self.logger = logging.getLogger("schema_tools.PipelineExecutor")
    
    async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
        """执行指定的处理链"""
        if pipeline_name not in self.pipeline_config:
            raise ValueError(f"未知的处理链: {pipeline_name}")
        
        steps = self.pipeline_config[pipeline_name]
        results = {}
        
        self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")
        
        for step_name in steps:
            try:
                tool = ToolRegistry.get_tool(step_name)
                
                # 验证输入
                if not tool.validate_input(context):
                    result = ProcessingResult(
                        success=False,
                        error_message=f"工具 {step_name} 输入验证失败"
                    )
                else:
                    result = await tool._execute_with_timing(context)
                
                results[step_name] = result
                context.update_step(step_name, result)
                
                # 如果步骤失败且不允许继续,则停止
                if not result.success:
                    from config import SCHEMA_TOOLS_CONFIG
                    if not SCHEMA_TOOLS_CONFIG["continue_on_error"]:
                        self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
                        break
                    else:
                        self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")
                
            except Exception as e:
                self.logger.exception(f"执行步骤 {step_name} 时发生异常")
                results[step_name] = ProcessingResult(
                    success=False,
                    error_message=f"步骤执行异常: {str(e)}"
                )
                break
        
        return results

4. 核心工具实现

4.1 数据库检查工具 (tools/database_inspector.py)

import asyncio
import asyncpg
from typing import List, Dict, Any, Optional
from tools.base import BaseTool, ToolRegistry
from utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo, TableMetadata

@ToolRegistry.register("database_inspector")
class DatabaseInspectorTool(BaseTool):
    """数据库元数据检查工具"""
    
    needs_llm = False
    tool_name = "数据库检查器"
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.db_connection = kwargs.get('db_connection')
        self.connection_pool = None
    
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """执行数据库元数据检查"""
        try:
            # 建立数据库连接
            if not self.connection_pool:
                await self._create_connection_pool()
            
            table_name = context.table_metadata.table_name
            schema_name = context.table_metadata.schema_name
            
            # 获取表的基本信息
            table_info = await self._get_table_info(schema_name, table_name)
            if not table_info:
                return ProcessingResult(
                    success=False,
                    error_message=f"表 {schema_name}.{table_name} 不存在或无权限访问"
                )
            
            # 获取字段信息
            fields = await self._get_table_fields(schema_name, table_name)
            
            # 获取表注释
            table_comment = await self._get_table_comment(schema_name, table_name)
            
            # 获取表统计信息
            stats = await self._get_table_statistics(schema_name, table_name)
            
            # 更新表元数据
            context.table_metadata.original_comment = table_comment
            context.table_metadata.comment = table_comment
            context.table_metadata.fields = fields
            context.table_metadata.row_count = stats.get('row_count')
            context.table_metadata.table_size = stats.get('table_size')
            
            return ProcessingResult(
                success=True,
                data={
                    'fields_count': len(fields),
                    'table_comment': table_comment,
                    'row_count': stats.get('row_count'),
                    'table_size': stats.get('table_size')
                },
                metadata={'tool': self.tool_name}
            )
            
        except Exception as e:
            self.logger.exception(f"数据库检查失败")
            return ProcessingResult(
                success=False,
                error_message=f"数据库检查失败: {str(e)}"
            )
    
    async def _create_connection_pool(self):
        """创建数据库连接池"""
        try:
            self.connection_pool = await asyncpg.create_pool(
                self.db_connection,
                min_size=1,
                max_size=5,
                command_timeout=30
            )
            self.logger.info("数据库连接池创建成功")
        except Exception as e:
            self.logger.error(f"创建数据库连接池失败: {e}")
            raise
    
    async def _get_table_info(self, schema_name: str, table_name: str) -> Optional[Dict]:
        """获取表基本信息"""
        query = """
        SELECT schemaname, tablename, tableowner, tablespace, hasindexes, hasrules, hastriggers
        FROM pg_tables 
        WHERE schemaname = $1 AND tablename = $2
        """
        async with self.connection_pool.acquire() as conn:
            result = await conn.fetchrow(query, schema_name, table_name)
            return dict(result) if result else None
    
    async def _get_table_fields(self, schema_name: str, table_name: str) -> List[FieldInfo]:
        """获取表字段信息"""
        query = """
        SELECT 
            c.column_name,
            c.data_type,
            c.is_nullable,
            c.column_default,
            c.character_maximum_length,
            c.numeric_precision,
            c.numeric_scale,
            pd.description as column_comment,
            CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END as is_primary_key,
            CASE WHEN fk.column_name IS NOT NULL THEN true ELSE false END as is_foreign_key
        FROM information_schema.columns c
        LEFT JOIN pg_description pd ON pd.objsubid = c.ordinal_position 
            AND pd.objoid = (
                SELECT oid FROM pg_class 
                WHERE relname = c.table_name 
                AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = c.table_schema)
            )
        LEFT JOIN (
            SELECT ku.column_name
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage ku ON tc.constraint_name = ku.constraint_name
            WHERE tc.table_schema = $1 AND tc.table_name = $2 AND tc.constraint_type = 'PRIMARY KEY'
        ) pk ON pk.column_name = c.column_name
        LEFT JOIN (
            SELECT ku.column_name
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage ku ON tc.constraint_name = ku.constraint_name
            WHERE tc.table_schema = $1 AND tc.table_name = $2 AND tc.constraint_type = 'FOREIGN KEY'
        ) fk ON fk.column_name = c.column_name
        WHERE c.table_schema = $1 AND c.table_name = $2
        ORDER BY c.ordinal_position
        """
        
        fields = []
        async with self.connection_pool.acquire() as conn:
            rows = await conn.fetch(query, schema_name, table_name)
            
            for row in rows:
                field = FieldInfo(
                    name=row['column_name'],
                    type=row['data_type'],
                    nullable=row['is_nullable'] == 'YES',
                    default_value=row['column_default'],
                    original_comment=row['column_comment'],
                    comment=row['column_comment'],
                    is_primary_key=row['is_primary_key'],
                    is_foreign_key=row['is_foreign_key'],
                    max_length=row['character_maximum_length'],
                    precision=row['numeric_precision'],
                    scale=row['numeric_scale']
                )
                fields.append(field)
        
        return fields
    
    async def _get_table_comment(self, schema_name: str, table_name: str) -> Optional[str]:
        """获取表注释"""
        query = """
        SELECT obj_description(oid) as table_comment
        FROM pg_class 
        WHERE relname = $2 
        AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $1)
        """
        async with self.connection_pool.acquire() as conn:
            result = await conn.fetchval(query, schema_name, table_name)
            return result
    
    async def _get_table_statistics(self, schema_name: str, table_name: str) -> Dict[str, Any]:
        """获取表统计信息"""
        stats_query = """
        SELECT 
            schemaname,
            tablename,
            attname,
            n_distinct,
            most_common_vals,
            most_common_freqs,
            histogram_bounds
        FROM pg_stats 
        WHERE schemaname = $1 AND tablename = $2
        """
        
        size_query = """
        SELECT pg_size_pretty(pg_total_relation_size($1)) as table_size,
               pg_relation_size($1) as table_size_bytes
        """
        
        count_query = f"SELECT COUNT(*) as row_count FROM {schema_name}.{table_name}"
        
        stats = {}
        async with self.connection_pool.acquire() as conn:
            try:
                # 获取行数
                row_count = await conn.fetchval(count_query)
                stats['row_count'] = row_count
                
                # 获取表大小
                table_oid = await conn.fetchval(
                    "SELECT oid FROM pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $2)",
                    table_name, schema_name
                )
                if table_oid:
                    size_result = await conn.fetchrow(size_query, table_oid)
                    stats['table_size'] = size_result['table_size']
                    stats['table_size_bytes'] = size_result['table_size_bytes']
                
            except Exception as e:
                self.logger.warning(f"获取表统计信息失败: {e}")
        
        return stats

4.2 数据采样工具 (tools/data_sampler.py)

import random
from typing import List, Dict, Any
from tools.base import BaseTool, ToolRegistry
from utils.data_structures import ProcessingResult, TableProcessingContext

@ToolRegistry.register("data_sampler")
class DataSamplerTool(BaseTool):
    """数据采样工具"""
    
    needs_llm = False
    tool_name = "数据采样器"
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.db_connection = kwargs.get('db_connection')
    
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """执行数据采样"""
        try:
            from config import SCHEMA_TOOLS_CONFIG
            
            table_metadata = context.table_metadata
            sample_limit = SCHEMA_TOOLS_CONFIG["sample_data_limit"]
            large_table_threshold = SCHEMA_TOOLS_CONFIG["large_table_threshold"]
            
            # 判断是否为大表,使用不同的采样策略
            if table_metadata.row_count and table_metadata.row_count > large_table_threshold:
                sample_data = await self._smart_sample_large_table(table_metadata, sample_limit)
                self.logger.info(f"大表 {table_metadata.full_name} 使用智能采样策略")
            else:
                sample_data = await self._simple_sample(table_metadata, sample_limit)
            
            # 更新上下文中的采样数据
            context.table_metadata.sample_data = sample_data
            
            return ProcessingResult(
                success=True,
                data={
                    'sample_count': len(sample_data),
                    'sampling_strategy': 'smart' if table_metadata.row_count and table_metadata.row_count > large_table_threshold else 'simple'
                },
                metadata={'tool': self.tool_name}
            )
            
        except Exception as e:
            self.logger.exception(f"数据采样失败")
            return ProcessingResult(
                success=False,
                error_message=f"数据采样失败: {str(e)}"
            )
    
    async def _simple_sample(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
        """简单采样策略"""
        from tools.database_inspector import DatabaseInspectorTool
        
        # 复用数据库检查工具的连接
        inspector = ToolRegistry.get_tool("database_inspector")
        
        query = f"SELECT * FROM {table_metadata.full_name} LIMIT {limit}"
        
        async with inspector.connection_pool.acquire() as conn:
            rows = await conn.fetch(query)
            return [dict(row) for row in rows]
    
    async def _smart_sample_large_table(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
        """智能采样策略(用于大表)"""
        from tools.database_inspector import DatabaseInspectorTool
        
        inspector = ToolRegistry.get_tool("database_inspector")
        samples_per_section = max(1, limit // 3)
        
        samples = []
        
        async with inspector.connection_pool.acquire() as conn:
            # 1. 前N行采样
            front_query = f"SELECT * FROM {table_metadata.full_name} LIMIT {samples_per_section}"
            front_rows = await conn.fetch(front_query)
            samples.extend([dict(row) for row in front_rows])
            
            # 2. 随机中间采样(使用TABLESAMPLE)
            if table_metadata.row_count > samples_per_section * 2:
                try:
                    # 计算采样百分比
                    sample_percent = min(1.0, (samples_per_section * 100.0) / table_metadata.row_count)
                    middle_query = f"""
                    SELECT * FROM {table_metadata.full_name} 
                    TABLESAMPLE SYSTEM({sample_percent}) 
                    LIMIT {samples_per_section}
                    """
                    middle_rows = await conn.fetch(middle_query)
                    samples.extend([dict(row) for row in middle_rows])
                except Exception as e:
                    self.logger.warning(f"TABLESAMPLE采样失败,使用OFFSET采样: {e}")
                    # 回退到OFFSET采样
                    offset = random.randint(samples_per_section, table_metadata.row_count - samples_per_section)
                    offset_query = f"SELECT * FROM {table_metadata.full_name} OFFSET {offset} LIMIT {samples_per_section}"
                    offset_rows = await conn.fetch(offset_query)
                    samples.extend([dict(row) for row in offset_rows])
            
            # 3. 后N行采样
            remaining = limit - len(samples)
            if remaining > 0:
                # 使用ORDER BY ... DESC来获取最后的行
                tail_query = f"""
                SELECT * FROM (
                    SELECT *, ROW_NUMBER() OVER() as rn 
                    FROM {table_metadata.full_name}
                ) sub 
                WHERE sub.rn > (SELECT COUNT(*) FROM {table_metadata.full_name}) - {remaining}
                ORDER BY sub.rn
                """
                try:
                    tail_rows = await conn.fetch(tail_query)
                    # 移除ROW_NUMBER列
                    for row in tail_rows:
                        row_dict = dict(row)
                        row_dict.pop('rn', None)
                        samples.append(row_dict)
                except Exception as e:
                    self.logger.warning(f"尾部采样失败: {e}")
        
        return samples[:limit]  # 确保不超过限制

4.3 LLM注释生成工具 (tools/comment_generator.py)

import asyncio
from typing import List, Dict, Any, Tuple
from tools.base import BaseTool, ToolRegistry
from utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo

@ToolRegistry.register("comment_generator")
class CommentGeneratorTool(BaseTool):
    """LLM注释生成工具"""
    
    needs_llm = True
    tool_name = "注释生成器"
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.business_context = kwargs.get('business_context', '')
        self.business_dictionary = self._load_business_dictionary()
    
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """执行注释生成"""
        try:
            table_metadata = context.table_metadata
            
            # 生成表注释
            table_comment_result = await self._generate_table_comment(table_metadata, context.business_context)
            
            # 生成字段注释和枚举建议
            field_results = await self._generate_field_comments_and_enums(table_metadata, context.business_context)
            
            # 更新表元数据
            if table_comment_result['success']:
                table_metadata.generated_comment = table_comment_result['comment']
                table_metadata.comment = table_comment_result['comment']
            
            # 更新字段信息
            enum_suggestions = []
            for i, field in enumerate(table_metadata.fields):
                if i < len(field_results) and field_results[i]['success']:
                    field.generated_comment = field_results[i]['comment']
                    field.comment = field_results[i]['comment']
                    
                    # 处理枚举建议
                    if field_results[i].get('is_enum'):
                        field.is_enum = True
                        enum_suggestions.append({
                            'field_name': field.name,
                            'suggested_values': field_results[i].get('enum_values', []),
                            'enum_description': field_results[i].get('enum_description', '')
                        })
            
            # 验证枚举建议
            if enum_suggestions:
                validated_enums = await self._validate_enum_suggestions(table_metadata, enum_suggestions)
                
                # 更新验证后的枚举信息
                for enum_info in validated_enums:
                    field_name = enum_info['field_name']
                    for field in table_metadata.fields:
                        if field.name == field_name:
                            field.enum_values = enum_info['actual_values']
                            field.enum_description = enum_info['description']
                            break
            
            return ProcessingResult(
                success=True,
                data={
                    'table_comment_generated': table_comment_result['success'],
                    'field_comments_generated': sum(1 for r in field_results if r['success']),
                    'enum_fields_detected': len([f for f in table_metadata.fields if f.is_enum]),
                    'enum_suggestions': enum_suggestions
                },
                metadata={'tool': self.tool_name}
            )
            
        except Exception as e:
            self.logger.exception(f"注释生成失败")
            return ProcessingResult(
                success=False,
                error_message=f"注释生成失败: {str(e)}"
            )
    
    async def _generate_table_comment(self, table_metadata, business_context: str) -> Dict[str, Any]:
        """生成表注释"""
        try:
            prompt = self._build_table_comment_prompt(table_metadata, business_context)
            
            # 调用LLM
            response = await self._call_llm_with_retry(prompt)
            
            # 解析响应
            comment = self._extract_table_comment(response)
            
            return {
                'success': True,
                'comment': comment,
                'original_response': response
            }
            
        except Exception as e:
            self.logger.error(f"表注释生成失败: {e}")
            return {
                'success': False,
                'comment': table_metadata.original_comment or f"{table_metadata.table_name}表",
                'error': str(e)
            }
    
    async def _generate_field_comments_and_enums(self, table_metadata, business_context: str) -> List[Dict[str, Any]]:
        """批量生成字段注释和枚举建议"""
        try:
            # 构建批量处理的提示词
            prompt = self._build_field_batch_prompt(table_metadata, business_context)
            
            # 调用LLM
            response = await self._call_llm_with_retry(prompt)
            
            # 解析批量响应
            field_results = self._parse_field_batch_response(response, table_metadata.fields)
            
            return field_results
            
        except Exception as e:
            self.logger.error(f"字段注释批量生成失败: {e}")
            # 返回默认结果
            return [
                {
                    'success': False,
                    'comment': field.original_comment or field.name,
                    'is_enum': False,
                    'error': str(e)
                }
                for field in table_metadata.fields
            ]
    
    def _build_table_comment_prompt(self, table_metadata, business_context: str) -> str:
        """构建表注释生成提示词"""
        # 准备字段信息摘要
        fields_summary = []
        for field in table_metadata.fields[:10]:  # 只显示前10个字段避免过长
            field_desc = f"- {field.name} ({field.type})"
            if field.comment:
                field_desc += f": {field.comment}"
            fields_summary.append(field_desc)
        
        # 准备样例数据摘要
        sample_summary = ""
        if table_metadata.sample_data:
            sample_count = min(3, len(table_metadata.sample_data))
            sample_summary = f"\n样例数据({sample_count}条):\n"
            for i, sample in enumerate(table_metadata.sample_data[:sample_count]):
                sample_str = ", ".join([f"{k}={v}" for k, v in list(sample.items())[:5]])
                sample_summary += f"{i+1}. {sample_str}\n"
        
        prompt = f"""你是一个数据库文档专家。请根据以下信息为数据库表生成简洁、准确的中文注释。

业务背景: {business_context}
{self.business_dictionary}

表信息:
- 表名: {table_metadata.table_name}
- Schema: {table_metadata.schema_name}
- 现有注释: {table_metadata.original_comment or "无"}
- 字段数量: {len(table_metadata.fields)}
- 数据行数: {table_metadata.row_count or "未知"}

主要字段:
{chr(10).join(fields_summary)}

{sample_summary}

请生成一个简洁、准确的中文表注释,要求:
1. 如果现有注释是英文,请翻译为中文并改进
2. 根据字段名称和样例数据推断表的业务用途
3. 注释长度控制在50字以内
4. 突出表的核心业务价值

表注释:"""
        
        return prompt
    
    def _build_field_batch_prompt(self, table_metadata, business_context: str) -> str:
        """构建字段批量处理提示词"""
        # 准备字段信息
        fields_info = []
        sample_values = {}
        
        # 收集字段的样例值
        for sample in table_metadata.sample_data[:5]:
            for field_name, value in sample.items():
                if field_name not in sample_values:
                    sample_values[field_name] = []
                if value is not None and len(sample_values[field_name]) < 5:
                    sample_values[field_name].append(str(value))
        
        # 构建字段信息列表
        for field in table_metadata.fields:
            field_info = f"{field.name} ({field.type})"
            if field.original_comment:
                field_info += f" - 原注释: {field.original_comment}"
            
            # 添加样例值
            if field.name in sample_values and sample_values[field.name]:
                values_str = ", ".join(sample_values[field.name][:3])
                field_info += f" - 样例值: {values_str}"
            
            fields_info.append(field_info)
        
        prompt = f"""你是一个数据库文档专家。请为以下表的所有字段生成中文注释,并识别可能的枚举字段。

业务背景: {business_context}
{self.business_dictionary}

表名: {table_metadata.schema_name}.{table_metadata.table_name}
表注释: {table_metadata.comment or "无"}

字段列表:
{chr(10).join([f"{i+1}. {info}" for i, info in enumerate(fields_info)])}

请按以下JSON格式输出每个字段的分析结果:
```json
{{
  "fields": [
    {{
      "name": "字段名",
      "comment": "中文注释(简洁明确,15字以内)",
      "is_enum": true/false,
      "enum_values": ["值1", "值2", "值3"] (如果是枚举),
      "enum_description": "枚举含义说明" (如果是枚举)
    }}
  ]
}}

注释生成要求:

  1. 如果原注释是英文,翻译为中文并改进
  2. 根据字段名、类型和样例值推断字段含义
  3. 识别可能的枚举字段(如状态、类型、级别等)
  4. 枚举判断标准: VARCHAR类型 + 样例值重复度高 + 字段名暗示分类
  5. 注释要贴近{business_context}的业务场景

请输出JSON格式的分析结果:"""

    return prompt

async def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
    """带重试的LLM调用"""
    from config import SCHEMA_TOOLS_CONFIG
    
    for attempt in range(max_retries):
        try:
            # 使用vanna实例调用LLM
            response = await asyncio.to_thread(self.vn.ask, prompt)
            
            if response and response.strip():
                return response.strip()
            else:
                raise ValueError("LLM返回空响应")
                
        except Exception as e:
            self.logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(1)  # 等待1秒后重试
    
    raise Exception("LLM调用达到最大重试次数")

def _extract_table_comment(self, llm_response: str) -> str:
    """从LLM响应中提取表注释"""
    # 简单的文本清理和提取逻辑
    lines = llm_response.strip().split('\n')
    
    # 查找包含实际注释的行
    for line in lines:
        line = line.strip()
        if line and not line.startswith('#') and not line.startswith('*'):
            # 移除可能的前缀
            prefixes = ['表注释:', '注释:', '说明:', '表说明:']
            for prefix in prefixes:
                if line.startswith(prefix):
                    line = line[len(prefix):].strip()
            
            if line:
                return line[:200]  # 限制长度
    
    return llm_response.strip()[:200]

def _parse_field_batch_response(self, llm_response: str, fields: List[FieldInfo]) -> List[Dict[str, Any]]:
    """解析字段批量处理响应"""
    import json
    import re
    
    try:
        # 尝试提取JSON部分
        json_match = re.search(r'```json\s*(.*?)\s*```', llm_response, re.DOTALL)
        if json_match:
            json_str = json_match.group(1)
        else:
            # 如果没有代码块,尝试直接解析
            json_str = llm_response
        
        # 解析JSON
        parsed_data = json.loads(json_str)
        field_data = parsed_data.get('fields', [])
        
        # 映射到字段结果
        results = []
        for i, field in enumerate(fields):
            if i < len(field_data):
                data = field_data[i]
                results.append({
                    'success': True,
                    'comment': data.get('comment', field.name),
                    'is_enum': data.get('is_enum', False),
                    'enum_values': data.get('enum_values', []),
                    'enum_description': data.get('enum_description', '')
                })
            else:
                # 默认结果
                results.append({
                    'success': False,
                    'comment': field.original_comment or field.name,
                    'is_enum': False
                })
        
        return results
        
    except Exception as e:
        self.logger.error(f"解析字段批量响应失败: {e}")
        # 返回默认结果
        return [
            {
                'success': False,
                'comment': field.original_comment or field.name,
                'is_enum': False,
                'error': str(e)
            }
            for field in fields
        ]

async def _validate_enum_suggestions(self, table_metadata, enum_suggestions: List[Dict]) -> List[Dict]:
    """验证枚举建议"""
    from tools.database_inspector import DatabaseInspectorTool
    from config import SCHEMA_TOOLS_CONFIG
    
    validated_enums = []
    inspector = ToolRegistry.get_tool("database_inspector")
    sample_limit = SCHEMA_TOOLS_CONFIG["enum_detection_sample_limit"]
    
    for enum_info in enum_suggestions:
        field_name = enum_info['field_name']
        
        try:
            # 查询字段的所有不同值
            query = f"""
            SELECT DISTINCT {field_name} as value, COUNT(*) as count
            FROM {table_metadata.full_name}
            WHERE {field_name} IS NOT NULL
            GROUP BY {field_name}
            ORDER BY count DESC
            LIMIT {sample_limit}
            """
            
            async with inspector.connection_pool.acquire() as conn:
                rows = await conn.fetch(query)
                
                actual_values = [str(row['value']) for row in rows]
                
                # 验证是否真的是枚举(不同值数量合理)
                max_enum_values = SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]
                if len(actual_values) <= max_enum_values:
                    validated_enums.append({
                        'field_name': field_name,
                        'actual_values': actual_values,
                        'suggested_values': enum_info['suggested_values'],
                        'description': enum_info['enum_description'],
                        'value_counts': [(row['value'], row['count']) for row in rows]
                    })
                    self.logger.info(f"确认字段 {field_name} 为枚举类型,包含 {len(actual_values)} 个值")
                else:
                    self.logger.info(f"字段 {field_name} 不同值过多({len(actual_values)}),不认为是枚举")
                    
        except Exception as e:
            self.logger.warning(f"验证字段 {field_name} 的枚举建议失败: {e}")
    
    return validated_enums

def _load_business_dictionary(self) -> str:
    """加载业务词典"""
    try:
        import os
        dict_file = os.path.join(os.path.dirname(__file__), '..', 'prompts', 'business_dictionary.txt')
        if os.path.exists(dict_file):
            with open(dict_file, 'r', encoding='utf-8') as f:
                content = f.read().strip()
                return f"\n业务词典:\n{content}\n" if content else ""
        return ""
    except Exception as e:
        self.logger.warning(f"加载业务词典失败: {e}")
        return ""
## 5. 主AI Agent实现

### 5.1 主Agent核心代码 (`training_data_agent.py`)

```python
import asyncio
import time
import logging
import os
from typing import List, Dict, Any, Optional
from pathlib import Path

from tools.base import ToolRegistry, PipelineExecutor
from utils.data_structures import TableMetadata, TableProcessingContext, ProcessingResult
from utils.file_manager import FileNameManager
from utils.system_filter import SystemTableFilter
from utils.permission_checker import DatabasePermissionChecker
from utils.table_parser import TableListParser
from utils.logger import setup_logging

class SchemaTrainingDataAgent:
    """Schema训练数据生成AI Agent"""
    
    def __init__(self, 
                 db_connection: str,
                 table_list_file: str,
                 business_context: str = None,
                 output_dir: str = None,
                 pipeline: str = "full"):
        
        self.db_connection = db_connection
        self.table_list_file = table_list_file
        self.business_context = business_context or "数据库管理系统"
        self.pipeline = pipeline
        
        # 配置管理
        from config import SCHEMA_TOOLS_CONFIG
        self.config = SCHEMA_TOOLS_CONFIG
        self.output_dir = output_dir or self.config["output_directory"]
        
        # 初始化组件
        self.file_manager = FileNameManager(self.output_dir)
        self.system_filter = SystemTableFilter()
        self.table_parser = TableListParser()
        self.pipeline_executor = PipelineExecutor(self.config["available_pipelines"])
        
        # 统计信息
        self.stats = {
            'total_tables': 0,
            'processed_tables': 0,
            'failed_tables': 0,
            'skipped_tables': 0,
            'start_time': None,
            'end_time': None
        }
        
        self.failed_tables = []
        self.logger = logging.getLogger("schema_tools.Agent")
    
    async def generate_training_data(self) -> Dict[str, Any]:
        """主入口:生成训练数据"""
        try:
            self.stats['start_time'] = time.time()
            self.logger.info("🚀 开始生成Schema训练数据")
            
            # 1. 初始化
            await self._initialize()
            
            # 2. 检查数据库权限
            await self._check_database_permissions()
            
            # 3. 解析表清单
            tables = await self._parse_table_list()
            
            # 4. 过滤系统表
            user_tables = self._filter_system_tables(tables)
            
            # 5. 并发处理表
            results = await self._process_tables_concurrently(user_tables)
            
            # 6. 生成总结报告
            report = self._generate_summary_report(results)
            
            self.stats['end_time'] = time.time()
            self.logger.info("✅ Schema训练数据生成完成")
            
            return report
            
        except Exception as e:
            self.stats['end_time'] = time.time()
            self.logger.exception("❌ Schema训练数据生成失败")
            raise
    
    async def _initialize(self):
        """初始化Agent"""
        # 创建输出目录
        os.makedirs(self.output_dir, exist_ok=True)
        if self.config["create_subdirectories"]:
            os.makedirs(os.path.join(self.output_dir, "ddl"), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, "docs"), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, "logs"), exist_ok=True)
        
        # 初始化数据库工具
        database_tool = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        await database_tool._create_connection_pool()
        
        self.logger.info(f"初始化完成,输出目录: {self.output_dir}")
    
    async def _check_database_permissions(self):
        """检查数据库权限"""
        if not self.config["check_permissions"]:
            return
        
        inspector = ToolRegistry.get_tool("database_inspector")
        checker = DatabasePermissionChecker(inspector)
        
        permissions = await checker.check_permissions()
        
        if not permissions['connect']:
            raise Exception("无法连接到数据库")
        
        if self.config["require_select_permission"] and not permissions['select_data']:
            if not self.config["allow_readonly_database"]:
                raise Exception("数据库查询权限不足")
            else:
                self.logger.warning("数据库为只读或权限受限,部分功能可能受影响")
        
        self.logger.info(f"数据库权限检查完成: {permissions}")
    
    async def _parse_table_list(self) -> List[str]:
        """解析表清单文件"""
        tables = self.table_parser.parse_file(self.table_list_file)
        self.stats['total_tables'] = len(tables)
        self.logger.info(f"📋 从清单文件读取到 {len(tables)} 个表")
        return tables
    
    def _filter_system_tables(self, tables: List[str]) -> List[str]:
        """过滤系统表"""
        if not self.config["filter_system_tables"]:
            return tables
        
        user_tables = self.system_filter.filter_user_tables(tables)
        filtered_count = len(tables) - len(user_tables)
        
        if filtered_count > 0:
            self.logger.info(f"🔍 过滤了 {filtered_count} 个系统表,保留 {len(user_tables)} 个用户表")
            self.stats['skipped_tables'] += filtered_count
        
        return user_tables
    
    async def _process_tables_concurrently(self, tables: List[str]) -> List[Dict[str, Any]]:
        """并发处理表"""
        max_concurrent = self.config["max_concurrent_tables"]
        semaphore = asyncio.Semaphore(max_concurrent)
        
        self.logger.info(f"🔄 开始并发处理 {len(tables)} 个表 (最大并发: {max_concurrent})")
        
        # 创建任务
        tasks = [
            self._process_single_table_with_semaphore(semaphore, table_spec)
            for table_spec in tables
        ]
        
        # 并发执行
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # 统计结果
        successful = sum(1 for r in results if isinstance(r, dict) and r.get('success', False))
        failed = len(results) - successful
        
        self.stats['processed_tables'] = successful
        self.stats['failed_tables'] = failed
        
        self.logger.info(f"📊 处理完成: 成功 {successful} 个,失败 {failed} 个")
        
        return [r for r in results if isinstance(r, dict)]
    
    async def _process_single_table_with_semaphore(self, semaphore: asyncio.Semaphore, table_spec: str) -> Dict[str, Any]:
        """带信号量的单表处理"""
        async with semaphore:
            return await self._process_single_table(table_spec)
    
    async def _process_single_table(self, table_spec: str) -> Dict[str, Any]:
        """处理单个表"""
        start_time = time.time()
        
        try:
            # 解析表名
            if '.' in table_spec:
                schema_name, table_name = table_spec.split('.', 1)
            else:
                schema_name, table_name = 'public', table_spec
            
            full_name = f"{schema_name}.{table_name}"
            self.logger.info(f"🔍 开始处理表: {full_name}")
            
            # 创建表元数据
            table_metadata = TableMetadata(
                schema_name=schema_name,
                table_name=table_name,
                full_name=full_name
            )
            
            # 创建处理上下文
            context = TableProcessingContext(
                table_metadata=table_metadata,
                business_context=self.business_context,
                output_dir=self.output_dir,
                pipeline=self.pipeline,
                vn=None,  # 将在工具中注入
                file_manager=self.file_manager,
                start_time=start_time
            )
            
            # 执行处理链
            step_results = await self.pipeline_executor.execute_pipeline(self.pipeline, context)
            
            # 计算总体成功状态
            success = all(result.success for result in step_results.values())
            
            execution_time = time.time() - start_time
            
            if success:
                self.logger.info(f"✅ 表 {full_name} 处理成功,耗时: {execution_time:.2f}秒")
            else:
                self.logger.error(f"❌ 表 {full_name} 处理失败,耗时: {execution_time:.2f}秒")
                self.failed_tables.append(full_name)
            
            return {
                'success': success,
                'table_name': full_name,
                'execution_time': execution_time,
                'step_results': {k: v.to_dict() for k, v in step_results.items()},
                'metadata': {
                    'fields_count': len(table_metadata.fields),
                    'row_count': table_metadata.row_count,
                    'enum_fields': len([f for f in table_metadata.fields if f.is_enum])
                }
            }
            
        except Exception as e:
            execution_time = time.time() - start_time
            error_msg = f"表 {table_spec} 处理异常: {str(e)}"
            self.logger.exception(error_msg)
            self.failed_tables.append(table_spec)
            
            return {
                'success': False,
                'table_name': table_spec,
                'execution_time': execution_time,
                'error_message': error_msg,
                'step_results': {}
            }
    
    def _generate_summary_report(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """生成总结报告"""
        total_time = self.stats['end_time'] - self.stats['start_time']
        
        # 计算统计信息
        successful_results = [r for r in results if r.get('success', False)]
        failed_results = [r for r in results if not r.get('success', False)]
        
        total_fields = sum(r.get('metadata', {}).get('fields_count', 0) for r in successful_results)
        total_enum_fields = sum(r.get('metadata', {}).get('enum_fields', 0) for r in successful_results)
        
        avg_execution_time = sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0
        
        report = {
            'summary': {
                'total_tables': self.stats['total_tables'],
                'processed_successfully': len(successful_results),
                'failed': len(failed_results),
                'skipped_system_tables': self.stats['skipped_tables'],
                'total_execution_time': total_time,
                'average_table_time': avg_execution_time
            },
            'statistics': {
                'total_fields_processed': total_fields,
                'enum_fields_detected': total_enum_fields,
                'files_generated': len(successful_results) * (2 if self.pipeline == 'full' else 1)
            },
            'failed_tables': self.failed_tables,
            'detailed_results': results,
            'configuration': {
                'pipeline': self.pipeline,
                'business_context': self.business_context,
                'output_directory': self.output_dir,
                'max_concurrent_tables': self.config['max_concurrent_tables']
            }
        }
        
        # 输出总结
        self.logger.info(f"📊 处理总结:")
        self.logger.info(f"  ✅ 成功: {report['summary']['processed_successfully']} 个表")
        self.logger.info(f"  ❌ 失败: {report['summary']['failed']} 个表")
        self.logger.info(f"  ⏭️  跳过: {report['summary']['skipped_system_tables']} 个系统表")
        self.logger.info(f"  📁 生成文件: {report['statistics']['files_generated']} 个")
        self.logger.info(f"  🕐 总耗时: {total_time:.2f} 秒")
        
        if self.failed_tables:
            self.logger.warning(f"❌ 失败的表: {', '.join(self.failed_tables)}")
        
        return report
    
    async def check_database_permissions(self) -> Dict[str, bool]:
        """检查数据库权限(供外部调用)"""
        inspector = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
        checker = DatabasePermissionChecker(inspector)
        return await checker.check_permissions()

6. 命令行接口实现

6.1 命令行入口 (__main__.py)

import argparse
import asyncio
import sys
import os
import logging
from pathlib import Path

def setup_argument_parser():
    """设置命令行参数解析器"""
    parser = argparse.ArgumentParser(
        description='Schema Tools - 自动生成数据库训练数据',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
  # 基本使用
  python -m schema_tools --db-connection "postgresql://user:pass@host:5432/db" --table-list tables.txt
  
  # 指定业务上下文和输出目录
  python -m schema_tools --db-connection "..." --table-list tables.txt --business-context "电商系统" --output-dir output
  
  # 仅生成DDL文件
  python -m schema_tools --db-connection "..." --table-list tables.txt --pipeline ddl_only
  
  # 权限检查模式
  python -m schema_tools --db-connection "..." --check-permissions-only
        """
    )
    
    # 必需参数
    parser.add_argument(
        '--db-connection',
        required=True,
        help='数据库连接字符串 (例如: postgresql://user:pass@localhost:5432/dbname)'
    )
    
    # 可选参数
    parser.add_argument(
        '--table-list',
        help='表清单文件路径'
    )
    
    parser.add_argument(
        '--business-context',
        help='业务上下文描述'
    )
    
    parser.add_argument(
        '--business-context-file',
        help='业务上下文文件路径'
    )
    
    parser.add_argument(
        '--output-dir',
        help='输出目录路径'
    )
    
    parser.add_argument(
        '--pipeline',
        choices=['full', 'ddl_only', 'analysis_only'],
        help='处理链类型'
    )
    
    parser.add_argument(
        '--max-concurrent',
        type=int,
        help='最大并发表数量'
    )
    
    # 功能开关
    parser.add_argument(
        '--no-filter-system-tables',
        action='store_true',
        help='禁用系统表过滤'
    )
    
    parser.add_argument(
        '--check-permissions-only',
        action='store_true',
        help='仅检查数据库权限,不处理表'
    )
    
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='启用详细日志输出'
    )
    
    parser.add_argument(
        '--log-file',
        help='日志文件路径'
    )
    
    return parser

def load_config_with_overrides(args):
    """加载配置并应用命令行覆盖"""
    from config import SCHEMA_TOOLS_CONFIG
    
    config = SCHEMA_TOOLS_CONFIG.copy()
    
    # 命令行参数覆盖配置
    if args.output_dir:
        config["output_directory"] = args.output_dir
    
    if args.pipeline:
        config["default_pipeline"] = args.pipeline
    
    if args.max_concurrent:
        config["max_concurrent_tables"] = args.max_concurrent
    
    if args.no_filter_system_tables:
        config["filter_system_tables"] = False
    
    if args.log_file:
        config["log_file"] = args.log_file
    
    return config

def load_business_context(args):
    """加载业务上下文"""
    if args.business_context_file:
        try:
            with open(args.business_context_file, 'r', encoding='utf-8') as f:
                return f.read().strip()
        except Exception as e:
            print(f"警告: 无法读取业务上下文文件 {args.business_context_file}: {e}")
    
    if args.business_context:
        return args.business_context
    
    from config import SCHEMA_TOOLS_CONFIG
    return SCHEMA_TOOLS_CONFIG.get("default_business_context", "数据库管理系统")

async def check_permissions_only(db_connection: str):
    """仅检查数据库权限"""
    from training_data_agent import SchemaTrainingDataAgent
    
    print("🔍 检查数据库权限...")
    
    try:
        agent = SchemaTrainingDataAgent(
            db_connection=db_connection,
            table_list_file="",  # 不需要表清单
            business_context=""   # 不需要业务上下文
        )
        
        # 初始化Agent以建立数据库连接
        await agent._initialize()
        
        # 检查权限
        permissions = await agent.check_database_permissions()
        
        print("\n📋 权限检查结果:")
        print(f"  ✅ 数据库连接: {'可用' if permissions['connect'] else '不可用'}")
        print(f"  ✅ 元数据查询: {'可用' if permissions['select_metadata'] else '不可用'}")
        print(f"  ✅ 数据查询: {'可用' if permissions['select_data'] else '不可用'}")
        print(f"  ℹ️  数据库类型: {'只读' if permissions['is_readonly'] else '读写'}")
        
        if all(permissions.values()):
            print("\n✅ 数据库权限检查通过,可以开始处理")
            return True
        else:
            print("\n❌ 数据库权限不足,请检查配置")
            return False
            
    except Exception as e:
        print(f"\n❌ 权限检查失败: {e}")
        return False

async def main():
    """主入口函数"""
    parser = setup_argument_parser()
    args = parser.parse_args()
    
    # 设置日志
    from utils.logger import setup_logging
    setup_logging(
        verbose=args.verbose,
        log_file=args.log_file
    )
    
    # 仅权限检查模式
    if args.check_permissions_only:
        success = await check_permissions_only(args.db_connection)
        sys.exit(0 if success else 1)
    
    # 验证必需参数
    if not args.table_list:
        print("错误: 需要指定 --table-list 参数")
        parser.print_help()
        sys.exit(1)
    
    if not os.path.exists(args.table_list):
        print(f"错误: 表清单文件不存在: {args.table_list}")
        sys.exit(1)
    
    try:
        # 加载配置和业务上下文
        config = load_config_with_overrides(args)
        business_context = load_business_context(args)
        
        # 创建Agent
        from training_data_agent import SchemaTrainingDataAgent
        
        agent = SchemaTrainingDataAgent(
            db_connection=args.db_connection,
            table_list_file=args.table_list,
            business_context=business_context,
            output_dir=config["output_directory"],
            pipeline=config["default_pipeline"]
        )
        
        # 执行生成
        print("🚀 开始生成Schema训练数据...")
        report = await agent.generate_training_data()
        
        # 输出结果
        if report['summary']['failed'] == 0:
            print("\n🎉 所有表处理成功!")
        else:
            print(f"\n⚠️  处理完成,但有 {report['summary']['failed']} 个表失败")
        
        print(f"📁 输出目录: {config['output_directory']}")
        
        # 如果有失败的表,返回非零退出码
        sys.exit(1 if report['summary']['failed'] > 0 else 0)
        
    except KeyboardInterrupt:
        print("\n\n⏹️  用户中断,程序退出")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ 程序执行失败: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    asyncio.run(main())

6.2 实际输出样例(基于高速公路服务区业务)

6.2.1 DDL文件输出样例 (bss_service_area.ddl)

sql

-- 中文名: 服务区基础信息表
-- 描述: 记录高速公路服务区的基础属性,包括服务区编码、名称、方向、公司归属、地理位置、服务类型和状态,是业务分析与服务区定位的核心表。
create table bss_service_area (
  id varchar(32) not null,              -- 服务区唯一标识(主键,UUID格式)
  version integer not null,             -- 数据版本号
  create_ts timestamp(3),               -- 创建时间
  created_by varchar(50),               -- 创建人
  update_ts timestamp(3),               -- 更新时间
  updated_by varchar(50),               -- 更新人
  delete_ts timestamp(3),               -- 删除时间
  deleted_by varchar(50),               -- 删除人
  service_area_name varchar(255),       -- 服务区名称
  service_area_no varchar(255),         -- 服务区编码(业务唯一标识)
  company_id varchar(32),               -- 公司ID(外键关联bss_company.id)
  service_position varchar(255),        -- 经纬度坐标
  service_area_type varchar(50),        -- 服务区类型(枚举:信息化服务区、智能化服务区)
  service_state varchar(50),            -- 服务区状态(枚举:开放、关闭、上传数据)
  primary key (id)
);

6.2.2 MD文档输出样例 (bss_service_area_detail.md)

markdown

## bss_service_area(服务区基础信息表)
bss_service_area 表记录高速公路服务区的基础属性,包括服务区编码、名称、方向、公司归属、地理位置、服务类型和状态,是业务分析与服务区定位的核心表。

字段列表:
- id (varchar(32)) - 服务区唯一标识(主键,UUID格式)[示例: 0271d68ef93de9684b7ad8c7aae600b6]
- version (integer) - 数据版本号 [示例: 3]
- create_ts (timestamp(3)) - 创建时间 [示例: 2021-05-21 13:26:40.589]
- created_by (varchar(50)) - 创建人 [示例: admin]
- update_ts (timestamp(3)) - 更新时间 [示例: 2021-07-10 15:41:28.795]
- updated_by (varchar(50)) - 更新人 [示例: admin]
- delete_ts (timestamp(3)) - 删除时间
- deleted_by (varchar(50)) - 删除人
- service_area_name (varchar(255)) - 服务区名称 [示例: 鄱阳湖服务区]
- service_area_no (varchar(255)) - 服务区编码(业务唯一标识)[示例: H0509]
- company_id (varchar(32)) - 公司ID(外键关联bss_company.id)[示例: b1629f07c8d9ac81494fbc1de61f1ea5]
- service_position (varchar(255)) - 经纬度坐标 [示例: 114.574721,26.825584]
- service_area_type (varchar(50)) - 服务区类型(枚举:信息化服务区、智能化服务区)[示例: 信息化服务区]
- service_state (varchar(50)) - 服务区状态(枚举:开放、关闭、上传数据)[示例: 开放]

字段补充说明:
- id 为主键,使用 UUID 编码,唯一标识每个服务区
- company_id 外键关联服务区管理公司表(bss_company.id)
- service_position 经纬度格式为"经度,纬度"
- service_area_type 为枚举字段,包含两个取值:信息化服务区、智能化服务区
- service_state 为枚举字段,包含三个取值:开放、关闭、上传数据
- 本表是多个表(bss_branch, bss_car_day_count等)的核心关联实体

6.2.3 复杂表样例 (bss_business_day_data.ddl)

sql

-- 中文名: 档口日营业数据表
-- 描述: 记录每天每个档口的营业情况,包含微信、支付宝、现金、金豆等支付方式的金额与订单数,是核心交易数据表。
create table bss_business_day_data (
  id varchar(32) not null,        -- 主键ID
  version integer not null,       -- 数据版本号
  create_ts timestamp(3),         -- 创建时间
  created_by varchar(50),         -- 创建人
  update_ts timestamp(3),         -- 更新时间
  updated_by varchar(50),         -- 更新人
  delete_ts timestamp(3),         -- 删除时间
  deleted_by varchar(50),         -- 删除人
  oper_date date,                 -- 统计日期
  service_no varchar(255),        -- 服务区编码
  service_name varchar(255),      -- 服务区名称
  branch_no varchar(255),         -- 档口编码
  branch_name varchar(255),       -- 档口名称
  wx numeric(19,4),               -- 微信支付金额
  wx_order integer,               -- 微信支付订单数量
  zfb numeric(19,4),              -- 支付宝支付金额
  zf_order integer,               -- 支付宝支付订单数量
  rmb numeric(19,4),              -- 现金支付金额
  rmb_order integer,              -- 现金支付订单数量
  xs numeric(19,4),               -- 行吧支付金额
  xs_order integer,               -- 行吧支付订单数量
  jd numeric(19,4),               -- 金豆支付金额
  jd_order integer,               -- 金豆支付订单数量
  order_sum integer,              -- 订单总数
  pay_sum numeric(19,4),          -- 支付总金额
  source_type integer,            -- 数据来源类型ID
  primary key (id)
);

6.3 输出格式关键特征

6.3.1 DDL格式特征

  • 中文表头注释: 包含表中文名和业务描述
  • 字段注释: 每个字段都有中文注释说明
  • 枚举标识: 对于枚举字段,在注释中明确标出可选值
  • 外键关系: 明确标出外键关联关系
  • 业务标识: 特殊业务字段(如编码、ID)有详细说明

6.3.2 MD格式特征

  • 表级描述: 详细的表业务用途说明
  • 字段示例值: 每个字段都提供真实的示例数据
  • 枚举值详解: 枚举字段的所有可能取值完整列出
  • 补充说明: 重要字段的额外业务逻辑说明
  • 关联关系: 与其他表的关联关系说明

7. 配置文件完整实现

7.1 配置文件 (config.py)

import os
import sys

# 导入app_config获取数据库等配置
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
    import app_config
except ImportError:
    app_config = None

# Schema Tools专用配置
SCHEMA_TOOLS_CONFIG = {
    # 核心配置
    "default_db_connection": None,  # 从命令行指定
    "default_business_context": "数据库管理系统", 
    "output_directory": "training/generated_data",
    
    # 处理链配置
    "default_pipeline": "full",
    "available_pipelines": {
        "full": [
            "database_inspector", 
            "data_sampler", 
            "comment_generator", 
            "ddl_generator", 
            "doc_generator"
        ],
        "ddl_only": [
            "database_inspector", 
            "data_sampler", 
            "comment_generator", 
            "ddl_generator"
        ],
        "analysis_only": [
            "database_inspector", 
            "data_sampler", 
            "comment_generator"
        ]
    },
    
    # 数据处理配置
    "sample_data_limit": 20,                    # 用于LLM分析的采样数据量
    "enum_detection_sample_limit": 5000,        # 枚举检测时的采样限制
    "enum_max_distinct_values": 20,             # 枚举字段最大不同值数量
    "enum_varchar_keywords": [                  # VARCHAR枚举关键词
        "性别", "gender", "状态", "status", "类型", "type", 
        "级别", "level", "方向", "direction", "品类", "classify",
        "模式", "mode", "格式", "format"
    ],
    "large_table_threshold": 1000000,           # 大表阈值(行数)
    
    # 并发配置
    "max_concurrent_tables": 1,  # 建议保持1,避免LLM并发调用问题                 # 最大并发处理表数
    
    # LLM配置
    "use_app_config_llm": True,                # 是否使用app_config中的LLM配置
    "comment_generation_timeout": 30,          # LLM调用超时时间(秒)
    "max_llm_retries": 3,                      # LLM调用最大重试次数
    
    # 系统表过滤配置
    "filter_system_tables": True,              # 是否过滤系统表
    "custom_system_prefixes": [],              # 用户自定义系统表前缀
    "custom_system_schemas": [],               # 用户自定义系统schema
    
    # 权限与安全配置
    "check_permissions": True,                 # 是否检查数据库权限
    "require_select_permission": True,         # 是否要求SELECT权限
    "allow_readonly_database": True,           # 是否允许只读数据库
    
    # 错误处理配置
    "continue_on_error": True,                 # 遇到错误是否继续
    "max_table_failures": 5,                  # 最大允许失败表数
    "skip_large_tables": False,               # 是否跳过超大表
    "max_table_size": 10000000,               # 最大表行数限制
    
    # 文件配置
    "ddl_file_suffix": ".ddl",
    "doc_file_suffix": "_detail.md",
    "log_file": "schema_tools.log",
    "create_subdirectories": True,            # 是否创建ddl/docs子目录
    
    # 输出格式配置
    "include_sample_data_in_comments": True,  # 注释中是否包含示例数据
    "max_comment_length": 500,                # 最大注释长度
    "include_field_statistics": True,         # 是否包含字段统计信息
    
    # 调试配置
    "debug_mode": False,                      # 调试模式
    "save_llm_prompts": False,               # 是否保存LLM提示词
    "save_llm_responses": False,             # 是否保存LLM响应
}

# 从app_config获取相关配置(如果可用)
if app_config:
    # 继承数据库配置
    if hasattr(app_config, 'PGVECTOR_CONFIG'):
        pgvector_config = app_config.PGVECTOR_CONFIG
        if not SCHEMA_TOOLS_CONFIG["default_db_connection"]:
            SCHEMA_TOOLS_CONFIG["default_db_connection"] = (
                f"postgresql://{pgvector_config['user']}:{pgvector_config['password']}"
                f"@{pgvector_config['host']}:{pgvector_config['port']}/{pgvector_config['dbname']}"
            )

def get_config():
    """获取当前配置"""
    return SCHEMA_TOOLS_CONFIG

def update_config(**kwargs):
    """更新配置"""
    SCHEMA_TOOLS_CONFIG.update(kwargs)

def validate_config():
    """验证配置有效性"""
    errors = []
    
    # 检查必要配置
    if SCHEMA_TOOLS_CONFIG["max_concurrent_tables"] <= 0:
        errors.append("max_concurrent_tables 必须大于0")
    
    if SCHEMA_TOOLS_CONFIG["sample_data_limit"] <= 0:
        errors.append("sample_data_limit 必须大于0")
    
    # 检查处理链配置
    default_pipeline = SCHEMA_TOOLS_CONFIG["default_pipeline"]
    available_pipelines = SCHEMA_TOOLS_CONFIG["available_pipelines"]
    
    if default_pipeline not in available_pipelines:
        errors.append(f"default_pipeline '{default_pipeline}' 不在 available_pipelines 中")
    
    if errors:
        raise ValueError("配置验证失败:\n" + "\n".join(f"  - {error}" for error in errors))
    
    return True

# 启动时验证配置
try:
    validate_config()
except ValueError as e:
    print(f"警告: {e}")

这个详细设计文档涵盖了Schema Tools的完整实现,包括:

核心特性

  1. 完整的数据结构设计 - 标准化的数据模型
  2. 工具注册机制 - 装饰器注册和自动依赖注入
  3. Pipeline处理链 - 可配置的处理流程
  4. 并发处理 - 表级并发和错误处理
  5. LLM集成 - 智能注释生成和枚举检测
  6. 权限管理 - 数据库权限检查和只读适配
  7. 命令行接口 - 完整的CLI支持

实现亮点

  • 类型安全: 使用dataclass定义明确的数据结构
  • 错误处理: 完善的异常处理和重试机制
  • 可扩展性: 工具注册机制便于添加新功能
  • 配置灵活: 多层次配置支持
  • 日志完整: 详细的执行日志和统计报告

8. Question-SQL生成功能详细设计(新增)

8.1 功能概述

Question-SQL生成功能是Schema Tools的扩展模块,用于从已生成的DDL和MD文件自动生成高质量的Question-SQL训练数据对。该功能可以独立运行,支持人工检查DDL/MD文件后再执行。

8.2 核心组件设计

8.2.1 QuestionSQLGenerationAgent (qs_agent.py)

class QuestionSQLGenerationAgent:
    """Question-SQL生成Agent"""
    
    def __init__(self, 
                 output_dir: str,
                 table_list_file: str,
                 business_context: str,
                 db_name: str = None):
        """
        初始化Agent
        
        Args:
            output_dir: 输出目录(包含DDL和MD文件)
            table_list_file: 表清单文件路径
            business_context: 业务上下文
            db_name: 数据库名称(用于输出文件命名)
        """
        self.output_dir = Path(output_dir)
        self.table_list_file = table_list_file
        self.business_context = business_context
        self.db_name = db_name or "db"
        
        # 初始化组件
        self.validator = FileCountValidator()
        self.md_analyzer = MDFileAnalyzer(output_dir)
        self.theme_extractor = None  # 延迟初始化
        
        # 中间结果存储
        self.intermediate_results = []
        self.intermediate_file = None

8.2.2 文件数量验证器 (validators/file_count_validator.py)

@dataclass
class ValidationResult:
    """验证结果"""
    is_valid: bool
    table_count: int
    ddl_count: int
    md_count: int
    error: str = ""
    missing_ddl: List[str] = field(default_factory=list)
    missing_md: List[str] = field(default_factory=list)

class FileCountValidator:
    """文件数量验证器"""
    
    def validate(self, table_list_file: str, output_dir: str) -> ValidationResult:
        """
        验证生成的文件数量是否与表数量一致
        
        主要验证:
        1. 表数量是否超过20个限制
        2. DDL文件数量是否与表数量一致
        3. MD文件数量是否与表数量一致
        """
        # 解析表清单
        tables = self.table_parser.parse_file(table_list_file)
        table_count = len(tables)
        
        # 检查表数量限制
        max_tables = self.config['qs_generation']['max_tables']
        if table_count > max_tables:
            return ValidationResult(
                is_valid=False,
                table_count=table_count,
                ddl_count=0,
                md_count=0,
                error=f"表数量({table_count})超过限制({max_tables})"
            )

8.2.3 MD文件分析器 (analyzers/md_analyzer.py)

class MDFileAnalyzer:
    """MD文件分析器"""
    
    async def read_all_md_files(self) -> str:
        """
        读取所有MD文件的完整内容
        
        Returns:
            所有MD文件内容的组合字符串
        """
        md_files = sorted(self.output_dir.glob("*_detail.md"))
        
        all_contents = []
        all_contents.append(f"# 数据库表结构文档汇总\n")
        all_contents.append(f"共包含 {len(md_files)} 个表\n\n")
        
        for md_file in md_files:
            content = md_file.read_text(encoding='utf-8')
            
            # 添加分隔符,便于LLM区分不同表
            all_contents.append("=" * 80)
            all_contents.append(f"# 文件: {md_file.name}")
            all_contents.append("=" * 80)
            all_contents.append(content)
            all_contents.append("\n")
        
        combined_content = "\n".join(all_contents)
        
        # 检查内容大小(预估token数)
        estimated_tokens = len(combined_content) / 4
        if estimated_tokens > 100000:
            self.logger.warning(f"MD内容可能过大,预估tokens: {estimated_tokens:.0f}")
        
        return combined_content

8.2.4 主题提取器 (analyzers/theme_extractor.py)

class ThemeExtractor:
    """主题提取器"""
    
    async def extract_themes(self, md_contents: str) -> List[Dict[str, Any]]:
        """
        从MD内容中提取分析主题
        """
        prompt = f"""你是一位经验丰富的业务数据分析师,正在分析{self.business_context}的数据库。

以下是数据库中所有表的详细结构说明:

{md_contents}

基于对这些表结构的理解,请从业务分析的角度提出 {theme_count} 个数据查询分析主题。

要求:
1. 每个主题应该有明确的业务价值和分析目标
2. 主题之间应该有所区别,覆盖不同的业务领域  
3. 你需要自行决定每个主题应该涉及哪些表
4. 主题应该体现实际业务场景的数据分析需求
5. 考虑时间维度、对比分析、排名统计等多种分析角度

请以JSON格式输出:
```json
{{
  "themes": [
    {{
      "name": "经营收入分析",
      "description": "分析服务区的营业收入情况,包括日收入趋势、月度对比、服务区排名等",
      "focus_areas": ["收入趋势", "服务区对比", "时间维度分析"],
      "related_tables": ["bss_business_day_data", "其他相关表名"]
    }}
  ]
}}
```"""
        
        response = await self._call_llm(prompt)
        themes = self._parse_theme_response(response)
        
        return themes

8.3 执行流程详细设计

8.3.1 主流程

async def generate(self) -> Dict[str, Any]:
    """生成Question-SQL对"""
    
    # 1. 验证文件数量
    validation_result = self.validator.validate(self.table_list_file, str(self.output_dir))
    if not validation_result.is_valid:
        raise ValueError(f"文件验证失败: {validation_result.error}")
    
    # 2. 读取所有MD文件内容
    md_contents = await self.md_analyzer.read_all_md_files()
    
    # 3. 初始化LLM组件
    self._initialize_llm_components()
    
    # 4. 提取分析主题
    themes = await self.theme_extractor.extract_themes(md_contents)
    
    # 5. 初始化中间结果文件
    self._init_intermediate_file()
    
    # 6. 处理每个主题
    if self.config['qs_generation']['max_concurrent_themes'] > 1:
        results = await self._process_themes_parallel(themes, md_contents)
    else:
        results = await self._process_themes_serial(themes, md_contents)
    
    # 7. 保存最终结果
    output_file = await self._save_final_results(all_qs_pairs)
    
    return report

8.3.2 主题处理

async def _process_single_theme(self, theme: Dict, md_contents: str) -> Dict:
    """处理单个主题"""
    
    prompt = f"""你是一位业务数据分析师,正在为{self.business_context}设计数据查询。

当前分析主题:{theme['name']}
主题描述:{theme['description']}
关注领域:{', '.join(theme['focus_areas'])}
相关表:{', '.join(theme['related_tables'])}

数据库表结构信息:
{md_contents}

请为这个主题生成 {questions_count} 个业务问题和对应的SQL查询。

要求:
1. 问题应该从业务角度出发,贴合主题要求,具有实际分析价值
2. SQL必须使用PostgreSQL语法
3. 考虑实际业务逻辑(如软删除使用 delete_ts IS NULL 条件)
4. 使用中文别名提高可读性(使用 AS 指定列别名)
5. 问题应该多样化,覆盖不同的分析角度
6. 包含时间筛选、分组统计、排序、限制等不同类型的查询
7. SQL语句末尾必须以分号结束

输出JSON格式:
```json
[
  {{
    "question": "具体的业务问题?",
    "sql": "SELECT column AS 中文名 FROM table WHERE condition;"
  }}
]
```"""
    
    response = await self._call_llm(prompt)
    qs_pairs = self._parse_qs_response(response)
    validated_pairs = self._validate_qs_pairs(qs_pairs, theme['name'])
    
    # 保存中间结果
    await self._save_theme_results(theme['name'], validated_pairs)
    
    return {
        'success': True,
        'theme_name': theme['name'],
        'qs_pairs': validated_pairs
    }

8.4 中间结果保存机制

def _init_intermediate_file(self):
    """初始化中间结果文件"""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    self.intermediate_file = self.output_dir / f"qs_intermediate_{timestamp}.json"
    self.intermediate_results = []

async def _save_theme_results(self, theme_name: str, qs_pairs: List[Dict]):
    """保存单个主题的结果"""
    theme_result = {
        "theme": theme_name,
        "timestamp": datetime.now().isoformat(),
        "questions_count": len(qs_pairs),
        "questions": qs_pairs
    }
    
    self.intermediate_results.append(theme_result)
    
    # 立即保存到中间文件
    if self.config['qs_generation']['save_intermediate']:
        with open(self.intermediate_file, 'w', encoding='utf-8') as f:
            json.dump(self.intermediate_results, f, ensure_ascii=False, indent=2)

9. SQL验证器核心模块

9.1 SQL验证器设计 (validators/sql_validator.py)

@dataclass
class SQLValidationResult:
    """SQL验证结果"""
    sql: str
    valid: bool
    error_message: str = ""
    execution_time: float = 0.0
    retry_count: int = 0
    
    # SQL修复相关字段
    repair_attempted: bool = False
    repair_successful: bool = False
    repaired_sql: str = ""
    repair_error: str = ""

@dataclass
class ValidationStats:
    """验证统计信息"""
    total_sqls: int = 0
    valid_sqls: int = 0
    invalid_sqls: int = 0
    total_time: float = 0.0
    avg_time_per_sql: float = 0.0
    retry_count: int = 0
    
    # SQL修复统计
    repair_attempted: int = 0
    repair_successful: int = 0
    repair_failed: int = 0

class SQLValidator:
    """SQL验证器核心类"""
    
    def __init__(self, db_connection: str = None):
        self.db_connection = db_connection
        self.connection_pool = None
        self.config = SCHEMA_TOOLS_CONFIG['sql_validation']
        
    async def validate_sql(self, sql: str, retry_count: int = 0) -> SQLValidationResult:
        """验证单个SQL语句"""
        start_time = time.time()
        
        try:
            if not self.connection_pool:
                await self._get_connection_pool()
            
            # 使用EXPLAIN验证SQL语法和表结构
            explain_sql = f"EXPLAIN {sql}"
            
            async with self.connection_pool.acquire() as conn:
                # 设置只读模式
                if self.config['readonly_mode']:
                    await conn.execute("SET default_transaction_read_only = on")
                
                # 执行EXPLAIN
                await asyncio.wait_for(
                    conn.fetch(explain_sql),
                    timeout=self.config['validation_timeout']
                )
            
            execution_time = time.time() - start_time
            
            return SQLValidationResult(
                sql=sql,
                valid=True,
                execution_time=execution_time,
                retry_count=retry_count
            )
            
        except asyncio.TimeoutError:
            execution_time = time.time() - start_time
            error_msg = f"SQL验证超时({self.config['validation_timeout']}秒)"
            
            return SQLValidationResult(
                sql=sql,
                valid=False,
                error_message=error_msg,
                execution_time=execution_time,
                retry_count=retry_count
            )
            
        except Exception as e:
            execution_time = time.time() - start_time
            error_msg = str(e)
            
            # 检查是否需要重试
            if retry_count < self.config['max_retry_count'] and self._should_retry(e):
                await asyncio.sleep(0.5)  # 短暂延迟
                return await self.validate_sql(sql, retry_count + 1)
            
            return SQLValidationResult(
                sql=sql,
                valid=False,
                error_message=error_msg,
                execution_time=execution_time,
                retry_count=retry_count
            )
    
    async def validate_sqls_batch(self, sqls: List[str]) -> List[SQLValidationResult]:
        """批量验证SQL语句"""
        max_concurrent = self.config['max_concurrent_validations']
        semaphore = asyncio.Semaphore(max_concurrent)
        
        async def validate_with_semaphore(sql):
            async with semaphore:
                return await self.validate_sql(sql)
        
        # 并发执行验证
        tasks = [validate_with_semaphore(sql) for sql in sqls]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # 处理异常结果
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append(SQLValidationResult(
                    sql=sqls[i],
                    valid=False,
                    error_message=f"验证异常: {str(result)}"
                ))
            else:
                processed_results.append(result)
        
        return processed_results
    
    def calculate_stats(self, results: List[SQLValidationResult]) -> ValidationStats:
        """计算验证统计信息"""
        stats = ValidationStats()
        
        stats.total_sqls = len(results)
        stats.valid_sqls = sum(1 for r in results if r.valid)
        stats.invalid_sqls = stats.total_sqls - stats.valid_sqls
        stats.total_time = sum(r.execution_time for r in results)
        stats.avg_time_per_sql = stats.total_time / stats.total_sqls if stats.total_sqls > 0 else 0.0
        stats.retry_count = sum(r.retry_count for r in results)
        
        # 修复统计
        stats.repair_attempted = sum(1 for r in results if r.repair_attempted)
        stats.repair_successful = sum(1 for r in results if r.repair_successful)
        stats.repair_failed = stats.repair_attempted - stats.repair_successful
        
        return stats

9.2 SQL验证Agent (sql_validation_agent.py)

class SQLValidationAgent:
    """SQL验证Agent - 管理SQL验证的完整流程"""
    
    async def validate(self) -> Dict[str, Any]:
        """执行SQL验证流程"""
        
        # 1. 读取输入文件
        questions_sqls = await self._load_questions_sqls()
        
        # 2. 提取SQL语句
        sqls = [item['sql'] for item in questions_sqls]
        
        # 3. 执行验证
        validation_results = await self._validate_sqls_with_batching(sqls)
        
        # 4. 计算统计信息
        stats = self.validator.calculate_stats(validation_results)
        
        # 5. 尝试修复失败的SQL(如果启用LLM修复)
        if self.config.get('enable_sql_repair', False) and self.vn:
            validation_results = await self._attempt_sql_repair(questions_sqls, validation_results)
            stats = self.validator.calculate_stats(validation_results)
        
        # 6. 修改原始JSON文件(如果启用文件修改)
        file_modification_stats = {'modified': 0, 'deleted': 0, 'failed_modifications': 0}
        if self.config.get('modify_original_file', False):
            file_modification_stats = await self._modify_original_json_file(questions_sqls, validation_results)
        
        # 7. 生成详细报告
        report = await self._generate_report(questions_sqls, validation_results, stats, file_modification_stats)
        
        # 8. 保存验证报告
        if self.config['save_validation_report']:
            await self._save_validation_report(report)
        
        return report
    
    async def _attempt_sql_repair(self, questions_sqls: List[Dict], validation_results: List[SQLValidationResult]) -> List[SQLValidationResult]:
        """尝试修复失败的SQL"""
        
        failed_indices = [i for i, result in enumerate(validation_results) if not result.valid]
        
        if not failed_indices:
            return validation_results
        
        # 批量修复
        batch_size = self.config.get('repair_batch_size', 5)
        updated_results = validation_results.copy()
        
        for i in range(0, len(failed_indices), batch_size):
            batch_indices = failed_indices[i:i + batch_size]
            
            # 准备批次数据
            batch_data = []
            for idx in batch_indices:
                batch_data.append({
                    'index': idx,
                    'question': questions_sqls[idx]['question'],
                    'sql': validation_results[idx].sql,
                    'error': validation_results[idx].error_message
                })
            
            # 调用LLM修复
            repaired_sqls = await self._repair_sqls_with_llm(batch_data)
            
            # 验证修复后的SQL
            for j, idx in enumerate(batch_indices):
                original_result = updated_results[idx]
                original_result.repair_attempted = True
                
                if j < len(repaired_sqls) and repaired_sqls[j]:
                    repaired_sql = repaired_sqls[j]
                    
                    # 验证修复后的SQL
                    repair_result = await self.validator.validate_sql(repaired_sql)
                    
                    if repair_result.valid:
                        # 修复成功
                        original_result.repair_successful = True
                        original_result.repaired_sql = repaired_sql
                        original_result.valid = True  # 更新为有效
                    else:
                        # 修复失败
                        original_result.repair_successful = False
                        original_result.repair_error = repair_result.error_message
                else:
                    # LLM修复失败
                    original_result.repair_successful = False
                    original_result.repair_error = "LLM修复失败或返回空结果"
        
        return updated_results
    
    async def _modify_original_json_file(self, questions_sqls: List[Dict], validation_results: List[SQLValidationResult]) -> Dict[str, int]:
        """修改原始JSON文件"""
        stats = {'modified': 0, 'deleted': 0, 'failed_modifications': 0}
        
        try:
            # 读取原始JSON文件
            with open(self.input_file, 'r', encoding='utf-8') as f:
                original_data = json.load(f)
            
            # 创建备份文件
            backup_file = Path(str(self.input_file) + '.backup')
            with open(backup_file, 'w', encoding='utf-8') as f:
                json.dump(original_data, f, ensure_ascii=False, indent=2)
            
            # 构建修改计划
            modifications = []
            deletions = []
            
            for i, (qs, result) in enumerate(zip(questions_sqls, validation_results)):
                if result.repair_successful and result.repaired_sql:
                    # 修复成功的SQL
                    modifications.append({
                        'index': i,
                        'original_sql': result.sql,
                        'repaired_sql': result.repaired_sql,
                        'question': qs['question']
                    })
                elif not result.valid and not result.repair_successful:
                    # 无法修复的SQL,标记删除
                    deletions.append({
                        'index': i,
                        'question': qs['question'],
                        'sql': result.sql,
                        'error': result.error_message
                    })
            
            # 执行修改(从后往前,避免索引变化)
            new_data = original_data.copy()
            
            # 先删除无效项(从后往前删除)
            for deletion in sorted(deletions, key=lambda x: x['index'], reverse=True):
                if deletion['index'] < len(new_data):
                    new_data.pop(deletion['index'])
                    stats['deleted'] += 1
            
            # 再修改SQL(需要重新计算索引)
            for modification in sorted(modifications, key=lambda x: x['index']):
                # 计算删除后的新索引
                new_index = modification['index']
                for deletion in deletions:
                    if deletion['index'] < modification['index']:
                        new_index -= 1
                
                if new_index < len(new_data):
                    new_data[new_index]['sql'] = modification['repaired_sql']
                    stats['modified'] += 1
            
            # 写入修改后的文件
            with open(self.input_file, 'w', encoding='utf-8') as f:
                json.dump(new_data, f, ensure_ascii=False, indent=2)
            
            # 记录详细修改信息到日志文件
            await self._write_modification_log(modifications, deletions)
            
        except Exception as e:
            stats['failed_modifications'] = 1
        
        return stats

10. 工作流编排器设计

10.1 SchemaWorkflowOrchestrator核心功能

class SchemaWorkflowOrchestrator:
    """端到端的Schema处理编排器"""
    
    async def execute_complete_workflow(self) -> Dict[str, Any]:
        """执行完整的Schema处理工作流程"""
        
        # 步骤1: 生成DDL和MD文件
        await self._execute_step_1_ddl_md_generation()
        
        # 步骤2: 生成Question-SQL对
        await self._execute_step_2_question_sql_generation()
        
        # 步骤3: 验证和修正SQL(可选)
        if self.enable_sql_validation:
            await self._execute_step_3_sql_validation()
        
        # 生成最终报告
        final_report = await self._generate_final_report()
        
        return final_report

这样,文档就与当前代码完全一致了,包含了所有新增的SQL验证、LLM修复、文件修改等功能的详细设计说明。