```
schema_tools/
├── __init__.py # 模块入口,导出主要接口
├── __main__.py # 命令行入口
├── config.py # 配置管理
├── training_data_agent.py # 主AI Agent
├── qs_agent.py # Question-SQL生成Agent (新增)
├── qs_generator.py # Question-SQL命令行入口 (新增)
├── sql_validation_agent.py # SQL验证Agent (新增)
├── sql_validator.py # SQL验证命令行入口 (新增)
├── schema_workflow_orchestrator.py # 端到端工作流编排器 (新增)
├── tools/ # Agent工具集
│ ├── __init__.py # 工具模块初始化
│ ├── base.py # 基础工具类和注册机制
│ ├── database_inspector.py # 数据库元数据检查工具
│ ├── data_sampler.py # 数据采样工具
│ ├── comment_generator.py # LLM注释生成工具
│ ├── ddl_generator.py # DDL格式生成工具
│ └── doc_generator.py # MD文档生成工具
├── validators/ # 验证器模块 (新增)
│ ├── __init__.py
│ ├── file_count_validator.py # 文件数量验证器
│ └── sql_validator.py # SQL验证器核心模块
├── analyzers/ # 分析器模块 (新增)
│ ├── __init__.py
│ ├── md_analyzer.py # MD文件分析器
│ └── theme_extractor.py # 主题提取器
├── utils/ # 工具函数
│ ├── __init__.py
│ ├── data_structures.py # 数据结构定义
│ ├── table_parser.py # 表清单解析器
│ ├── file_manager.py # 文件管理器
│ ├── system_filter.py # 系统表过滤器
│ ├── permission_checker.py # 权限检查器
│ ├── large_table_handler.py # 大表处理器
│ └── logger.py # 日志管理
├── prompts/ # 提示词模板
│ ├── table_comment_template.txt
│ ├── field_comment_template.txt
│ ├── enum_detection_template.txt
│ ├── business_context.txt
│ └── business_dictionary.txt
└── tests/ # 测试用例
├── unit/
├── integration/
└── fixtures/
```

`utils/data_structures.py`

```python
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any, Union
from enum import Enum
import hashlib
import json
class FieldType(Enum):
"""字段类型枚举"""
INTEGER = "integer"
VARCHAR = "varchar"
TEXT = "text"
TIMESTAMP = "timestamp"
DATE = "date"
BOOLEAN = "boolean"
NUMERIC = "numeric"
ENUM = "enum"
JSON = "json"
UUID = "uuid"
OTHER = "other"
class ProcessingStatus(Enum):
"""处理状态枚举"""
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
@dataclass
class FieldInfo:
"""字段信息标准结构"""
name: str
type: str
nullable: bool
default_value: Optional[str] = None
comment: Optional[str] = None
original_comment: Optional[str] = None # 原始注释
generated_comment: Optional[str] = None # LLM生成的注释
is_primary_key: bool = False
is_foreign_key: bool = False
is_enum: bool = False
enum_values: Optional[List[str]] = None
enum_description: Optional[str] = None
max_length: Optional[int] = None
precision: Optional[int] = None
scale: Optional[int] = None
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
'name': self.name,
'type': self.type,
'nullable': self.nullable,
'default_value': self.default_value,
'comment': self.comment,
'is_primary_key': self.is_primary_key,
'is_foreign_key': self.is_foreign_key,
'is_enum': self.is_enum,
'enum_values': self.enum_values
}
@dataclass
class TableMetadata:
"""表元数据标准结构"""
schema_name: str
table_name: str
full_name: str # schema.table_name
comment: Optional[str] = None
original_comment: Optional[str] = None # 原始注释
generated_comment: Optional[str] = None # LLM生成的注释
fields: List[FieldInfo] = field(default_factory=list)
sample_data: List[Dict[str, Any]] = field(default_factory=list)
row_count: Optional[int] = None
table_size: Optional[str] = None # 表大小(如 "1.2 MB")
created_date: Optional[str] = None
@property
def safe_file_name(self) -> str:
"""生成安全的文件名"""
if self.schema_name.lower() == 'public':
return self.table_name
return f"{self.schema_name}__{self.table_name}".replace('.', '__').replace('-', '_').replace(' ', '_')
def get_metadata_hash(self) -> str:
"""计算元数据哈希值,用于增量更新判断"""
hash_data = {
'schema_name': self.schema_name,
'table_name': self.table_name,
'fields': [f.to_dict() for f in self.fields],
'comment': self.original_comment
}
return hashlib.md5(json.dumps(hash_data, sort_keys=True).encode()).hexdigest()
@dataclass
class ProcessingResult:
"""工具处理结果标准结构"""
success: bool
data: Optional[Any] = None
error_message: Optional[str] = None
warnings: List[str] = field(default_factory=list)
execution_time: Optional[float] = None
metadata: Dict[str, Any] = field(default_factory=dict)
def add_warning(self, warning: str):
"""添加警告信息"""
self.warnings.append(warning)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
'success': self.success,
'data': self.data,
'error_message': self.error_message,
'warnings': self.warnings,
'execution_time': self.execution_time,
'metadata': self.metadata
}
@dataclass
class TableProcessingContext:
"""表处理上下文"""
table_metadata: TableMetadata
business_context: str
output_dir: str
pipeline: str
vn: Any # vanna实例
file_manager: Any
current_step: str = "initialized"
step_results: Dict[str, ProcessingResult] = field(default_factory=dict)
start_time: Optional[float] = None
def update_step(self, step_name: str, result: ProcessingResult):
"""更新步骤结果"""
self.current_step = step_name
self.step_results[step_name] = result
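
# 用法示意(字段与取值仅为示例,非真实库表):
# meta = TableMetadata(schema_name="public", table_name="bss_company", full_name="public.bss_company")
# meta.fields.append(FieldInfo(name="id", type="varchar", nullable=False, is_primary_key=True))
# meta.safe_file_name        -> "bss_company"
# meta.get_metadata_hash()   -> 基于表名、字段定义与原始注释计算的MD5值,用于增量更新判断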
```

`tools/base.py`

```python
import asyncio
import time
import logging
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Type
from utils.data_structures import ProcessingResult, TableProcessingContext
class ToolRegistry:
"""工具注册管理器"""
_tools: Dict[str, Type['BaseTool']] = {}
_instances: Dict[str, 'BaseTool'] = {}
@classmethod
def register(cls, name: str):
"""装饰器:注册工具"""
def decorator(tool_class: Type['BaseTool']):
cls._tools[name] = tool_class
logging.debug(f"注册工具: {name} -> {tool_class.__name__}")
return tool_class
return decorator
@classmethod
def get_tool(cls, name: str, **kwargs) -> 'BaseTool':
"""获取工具实例,支持单例模式"""
if name not in cls._instances:
if name not in cls._tools:
raise ValueError(f"工具 '{name}' 未注册")
tool_class = cls._tools[name]
# 自动注入vanna实例到需要LLM的工具
if hasattr(tool_class, 'needs_llm') and tool_class.needs_llm:
from core.vanna_llm_factory import create_vanna_instance
kwargs['vn'] = create_vanna_instance()
logging.debug(f"为工具 {name} 注入LLM实例")
cls._instances[name] = tool_class(**kwargs)
return cls._instances[name]
@classmethod
def list_tools(cls) -> List[str]:
"""列出所有已注册的工具"""
return list(cls._tools.keys())
@classmethod
def clear_instances(cls):
"""清除所有工具实例(用于测试)"""
cls._instances.clear()
class BaseTool(ABC):
"""工具基类"""
needs_llm: bool = False # 是否需要LLM实例
tool_name: str = "" # 工具名称
def __init__(self, **kwargs):
self.logger = logging.getLogger(f"schema_tools.{self.__class__.__name__}")
# 如果工具需要LLM,检查是否已注入
if self.needs_llm and 'vn' not in kwargs:
raise ValueError(f"工具 {self.__class__.__name__} 需要LLM实例但未提供")
# 存储vanna实例
if 'vn' in kwargs:
self.vn = kwargs['vn']
@abstractmethod
async def execute(self, context: TableProcessingContext) -> ProcessingResult:
"""
执行工具逻辑
Args:
context: 表处理上下文
Returns:
ProcessingResult: 处理结果
"""
pass
async def _execute_with_timing(self, context: TableProcessingContext) -> ProcessingResult:
"""带计时的执行包装器"""
start_time = time.time()
try:
self.logger.info(f"开始执行工具: {self.tool_name}")
result = await self.execute(context)
execution_time = time.time() - start_time
result.execution_time = execution_time
if result.success:
self.logger.info(f"工具 {self.tool_name} 执行成功,耗时: {execution_time:.2f}秒")
else:
self.logger.error(f"工具 {self.tool_name} 执行失败: {result.error_message}")
return result
except Exception as e:
execution_time = time.time() - start_time
self.logger.exception(f"工具 {self.tool_name} 执行异常")
return ProcessingResult(
success=False,
error_message=f"工具执行异常: {str(e)}",
execution_time=execution_time
)
def validate_input(self, context: TableProcessingContext) -> bool:
"""输入验证(子类可重写)"""
return context.table_metadata is not None
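
# 用法示意:注册并获取一个自定义工具(EchoTool 为本文档假设的示例,非实际代码)
@ToolRegistry.register("echo_tool")
class EchoTool(BaseTool):
    """示例工具:仅回显当前表名,演示注册与执行流程"""
    needs_llm = False
    tool_name = "回显工具"

    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        return ProcessingResult(success=True, data={'table': context.table_metadata.full_name})

# tool = ToolRegistry.get_tool("echo_tool")
# result = await tool._execute_with_timing(context)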
```

`training_data_agent.py`(的一部分)

```python
class PipelineExecutor:
"""处理链执行器"""
def __init__(self, pipeline_config: Dict[str, List[str]]):
self.pipeline_config = pipeline_config
self.logger = logging.getLogger("schema_tools.PipelineExecutor")
async def execute_pipeline(self, pipeline_name: str, context: TableProcessingContext) -> Dict[str, ProcessingResult]:
"""执行指定的处理链"""
if pipeline_name not in self.pipeline_config:
raise ValueError(f"未知的处理链: {pipeline_name}")
steps = self.pipeline_config[pipeline_name]
results = {}
self.logger.info(f"开始执行处理链 '{pipeline_name}': {' -> '.join(steps)}")
for step_name in steps:
try:
tool = ToolRegistry.get_tool(step_name)
# 验证输入
if not tool.validate_input(context):
result = ProcessingResult(
success=False,
error_message=f"工具 {step_name} 输入验证失败"
)
else:
result = await tool._execute_with_timing(context)
results[step_name] = result
context.update_step(step_name, result)
# 如果步骤失败且不允许继续,则停止
if not result.success:
from config import SCHEMA_TOOLS_CONFIG
if not SCHEMA_TOOLS_CONFIG["continue_on_error"]:
self.logger.error(f"步骤 {step_name} 失败,停止处理链执行")
break
else:
self.logger.warning(f"步骤 {step_name} 失败,继续执行下一步")
except Exception as e:
self.logger.exception(f"执行步骤 {step_name} 时发生异常")
results[step_name] = ProcessingResult(
success=False,
error_message=f"步骤执行异常: {str(e)}"
)
break
return results
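
# 用法示意(与主Agent中的调用方式一致):
# executor = PipelineExecutor(SCHEMA_TOOLS_CONFIG["available_pipelines"])
# step_results = await executor.execute_pipeline("ddl_only", context)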
```

`tools/database_inspector.py`

```python
import asyncio
import asyncpg
from typing import List, Dict, Any, Optional
from tools.base import BaseTool, ToolRegistry
from utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo, TableMetadata
@ToolRegistry.register("database_inspector")
class DatabaseInspectorTool(BaseTool):
"""数据库元数据检查工具"""
needs_llm = False
tool_name = "数据库检查器"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.db_connection = kwargs.get('db_connection')
self.connection_pool = None
async def execute(self, context: TableProcessingContext) -> ProcessingResult:
"""执行数据库元数据检查"""
try:
# 建立数据库连接
if not self.connection_pool:
await self._create_connection_pool()
table_name = context.table_metadata.table_name
schema_name = context.table_metadata.schema_name
# 获取表的基本信息
table_info = await self._get_table_info(schema_name, table_name)
if not table_info:
return ProcessingResult(
success=False,
error_message=f"表 {schema_name}.{table_name} 不存在或无权限访问"
)
# 获取字段信息
fields = await self._get_table_fields(schema_name, table_name)
# 获取表注释
table_comment = await self._get_table_comment(schema_name, table_name)
# 获取表统计信息
stats = await self._get_table_statistics(schema_name, table_name)
# 更新表元数据
context.table_metadata.original_comment = table_comment
context.table_metadata.comment = table_comment
context.table_metadata.fields = fields
context.table_metadata.row_count = stats.get('row_count')
context.table_metadata.table_size = stats.get('table_size')
return ProcessingResult(
success=True,
data={
'fields_count': len(fields),
'table_comment': table_comment,
'row_count': stats.get('row_count'),
'table_size': stats.get('table_size')
},
metadata={'tool': self.tool_name}
)
except Exception as e:
self.logger.exception(f"数据库检查失败")
return ProcessingResult(
success=False,
error_message=f"数据库检查失败: {str(e)}"
)
async def _create_connection_pool(self):
"""创建数据库连接池"""
try:
self.connection_pool = await asyncpg.create_pool(
self.db_connection,
min_size=1,
max_size=5,
command_timeout=30
)
self.logger.info("数据库连接池创建成功")
except Exception as e:
self.logger.error(f"创建数据库连接池失败: {e}")
raise
async def _get_table_info(self, schema_name: str, table_name: str) -> Optional[Dict]:
"""获取表基本信息"""
query = """
SELECT schemaname, tablename, tableowner, tablespace, hasindexes, hasrules, hastriggers
FROM pg_tables
WHERE schemaname = $1 AND tablename = $2
"""
async with self.connection_pool.acquire() as conn:
result = await conn.fetchrow(query, schema_name, table_name)
return dict(result) if result else None
async def _get_table_fields(self, schema_name: str, table_name: str) -> List[FieldInfo]:
"""获取表字段信息"""
query = """
SELECT
c.column_name,
c.data_type,
c.is_nullable,
c.column_default,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale,
pd.description as column_comment,
CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END as is_primary_key,
CASE WHEN fk.column_name IS NOT NULL THEN true ELSE false END as is_foreign_key
FROM information_schema.columns c
LEFT JOIN pg_description pd ON pd.objsubid = c.ordinal_position
AND pd.objoid = (
SELECT oid FROM pg_class
WHERE relname = c.table_name
AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = c.table_schema)
)
LEFT JOIN (
SELECT ku.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage ku ON tc.constraint_name = ku.constraint_name
WHERE tc.table_schema = $1 AND tc.table_name = $2 AND tc.constraint_type = 'PRIMARY KEY'
) pk ON pk.column_name = c.column_name
LEFT JOIN (
SELECT ku.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage ku ON tc.constraint_name = ku.constraint_name
WHERE tc.table_schema = $1 AND tc.table_name = $2 AND tc.constraint_type = 'FOREIGN KEY'
) fk ON fk.column_name = c.column_name
WHERE c.table_schema = $1 AND c.table_name = $2
ORDER BY c.ordinal_position
"""
fields = []
async with self.connection_pool.acquire() as conn:
rows = await conn.fetch(query, schema_name, table_name)
for row in rows:
field = FieldInfo(
name=row['column_name'],
type=row['data_type'],
nullable=row['is_nullable'] == 'YES',
default_value=row['column_default'],
original_comment=row['column_comment'],
comment=row['column_comment'],
is_primary_key=row['is_primary_key'],
is_foreign_key=row['is_foreign_key'],
max_length=row['character_maximum_length'],
precision=row['numeric_precision'],
scale=row['numeric_scale']
)
fields.append(field)
return fields
async def _get_table_comment(self, schema_name: str, table_name: str) -> Optional[str]:
"""获取表注释"""
query = """
SELECT obj_description(oid) as table_comment
FROM pg_class
WHERE relname = $2
AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $1)
"""
async with self.connection_pool.acquire() as conn:
result = await conn.fetchval(query, schema_name, table_name)
return result
async def _get_table_statistics(self, schema_name: str, table_name: str) -> Dict[str, Any]:
"""获取表统计信息"""
stats_query = """
SELECT
schemaname,
tablename,
attname,
n_distinct,
most_common_vals,
most_common_freqs,
histogram_bounds
FROM pg_stats
WHERE schemaname = $1 AND tablename = $2
"""
size_query = """
SELECT pg_size_pretty(pg_total_relation_size($1)) as table_size,
pg_relation_size($1) as table_size_bytes
"""
count_query = f"SELECT COUNT(*) as row_count FROM {schema_name}.{table_name}"
stats = {}
async with self.connection_pool.acquire() as conn:
try:
# 获取行数
row_count = await conn.fetchval(count_query)
stats['row_count'] = row_count
# 获取表大小
table_oid = await conn.fetchval(
"SELECT oid FROM pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $2)",
table_name, schema_name
)
if table_oid:
size_result = await conn.fetchrow(size_query, table_oid)
stats['table_size'] = size_result['table_size']
stats['table_size_bytes'] = size_result['table_size_bytes']
except Exception as e:
self.logger.warning(f"获取表统计信息失败: {e}")
return stats
```

`tools/data_sampler.py`

```python
import random
from typing import List, Dict, Any
from tools.base import BaseTool, ToolRegistry
from utils.data_structures import ProcessingResult, TableProcessingContext, TableMetadata
@ToolRegistry.register("data_sampler")
class DataSamplerTool(BaseTool):
"""数据采样工具"""
needs_llm = False
tool_name = "数据采样器"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.db_connection = kwargs.get('db_connection')
async def execute(self, context: TableProcessingContext) -> ProcessingResult:
"""执行数据采样"""
try:
from config import SCHEMA_TOOLS_CONFIG
table_metadata = context.table_metadata
sample_limit = SCHEMA_TOOLS_CONFIG["sample_data_limit"]
large_table_threshold = SCHEMA_TOOLS_CONFIG["large_table_threshold"]
# 判断是否为大表,使用不同的采样策略
if table_metadata.row_count and table_metadata.row_count > large_table_threshold:
sample_data = await self._smart_sample_large_table(table_metadata, sample_limit)
self.logger.info(f"大表 {table_metadata.full_name} 使用智能采样策略")
else:
sample_data = await self._simple_sample(table_metadata, sample_limit)
# 更新上下文中的采样数据
context.table_metadata.sample_data = sample_data
return ProcessingResult(
success=True,
data={
'sample_count': len(sample_data),
'sampling_strategy': 'smart' if table_metadata.row_count and table_metadata.row_count > large_table_threshold else 'simple'
},
metadata={'tool': self.tool_name}
)
except Exception as e:
self.logger.exception(f"数据采样失败")
return ProcessingResult(
success=False,
error_message=f"数据采样失败: {str(e)}"
)
async def _simple_sample(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
"""简单采样策略"""
from tools.database_inspector import DatabaseInspectorTool
# 复用数据库检查工具的连接
inspector = ToolRegistry.get_tool("database_inspector")
query = f"SELECT * FROM {table_metadata.full_name} LIMIT {limit}"
async with inspector.connection_pool.acquire() as conn:
rows = await conn.fetch(query)
return [dict(row) for row in rows]
async def _smart_sample_large_table(self, table_metadata: TableMetadata, limit: int) -> List[Dict[str, Any]]:
"""智能采样策略(用于大表)"""
from tools.database_inspector import DatabaseInspectorTool
inspector = ToolRegistry.get_tool("database_inspector")
samples_per_section = max(1, limit // 3)
samples = []
async with inspector.connection_pool.acquire() as conn:
# 1. 前N行采样
front_query = f"SELECT * FROM {table_metadata.full_name} LIMIT {samples_per_section}"
front_rows = await conn.fetch(front_query)
samples.extend([dict(row) for row in front_rows])
# 2. 随机中间采样(使用TABLESAMPLE)
if table_metadata.row_count > samples_per_section * 2:
try:
# 计算采样百分比
sample_percent = min(1.0, (samples_per_section * 100.0) / table_metadata.row_count)
middle_query = f"""
SELECT * FROM {table_metadata.full_name}
TABLESAMPLE SYSTEM({sample_percent})
LIMIT {samples_per_section}
"""
middle_rows = await conn.fetch(middle_query)
samples.extend([dict(row) for row in middle_rows])
except Exception as e:
self.logger.warning(f"TABLESAMPLE采样失败,使用OFFSET采样: {e}")
# 回退到OFFSET采样
offset = random.randint(samples_per_section, table_metadata.row_count - samples_per_section)
offset_query = f"SELECT * FROM {table_metadata.full_name} OFFSET {offset} LIMIT {samples_per_section}"
offset_rows = await conn.fetch(offset_query)
samples.extend([dict(row) for row in offset_rows])
# 3. 后N行采样
remaining = limit - len(samples)
if remaining > 0:
# 使用ORDER BY ... DESC来获取最后的行
tail_query = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER() as rn
FROM {table_metadata.full_name}
) sub
WHERE sub.rn > (SELECT COUNT(*) FROM {table_metadata.full_name}) - {remaining}
ORDER BY sub.rn
"""
try:
tail_rows = await conn.fetch(tail_query)
# 移除ROW_NUMBER列
for row in tail_rows:
row_dict = dict(row)
row_dict.pop('rn', None)
samples.append(row_dict)
except Exception as e:
self.logger.warning(f"尾部采样失败: {e}")
return samples[:limit] # 确保不超过限制
```

`tools/comment_generator.py`

````python
import asyncio
from typing import List, Dict, Any, Tuple
from tools.base import BaseTool, ToolRegistry
from utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo
@ToolRegistry.register("comment_generator")
class CommentGeneratorTool(BaseTool):
"""LLM注释生成工具"""
needs_llm = True
tool_name = "注释生成器"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.business_context = kwargs.get('business_context', '')
self.business_dictionary = self._load_business_dictionary()
async def execute(self, context: TableProcessingContext) -> ProcessingResult:
"""执行注释生成"""
try:
table_metadata = context.table_metadata
# 生成表注释
table_comment_result = await self._generate_table_comment(table_metadata, context.business_context)
# 生成字段注释和枚举建议
field_results = await self._generate_field_comments_and_enums(table_metadata, context.business_context)
# 更新表元数据
if table_comment_result['success']:
table_metadata.generated_comment = table_comment_result['comment']
table_metadata.comment = table_comment_result['comment']
# 更新字段信息
enum_suggestions = []
for i, field in enumerate(table_metadata.fields):
if i < len(field_results) and field_results[i]['success']:
field.generated_comment = field_results[i]['comment']
field.comment = field_results[i]['comment']
# 处理枚举建议
if field_results[i].get('is_enum'):
field.is_enum = True
enum_suggestions.append({
'field_name': field.name,
'suggested_values': field_results[i].get('enum_values', []),
'enum_description': field_results[i].get('enum_description', '')
})
# 验证枚举建议
if enum_suggestions:
validated_enums = await self._validate_enum_suggestions(table_metadata, enum_suggestions)
# 更新验证后的枚举信息
for enum_info in validated_enums:
field_name = enum_info['field_name']
for field in table_metadata.fields:
if field.name == field_name:
field.enum_values = enum_info['actual_values']
field.enum_description = enum_info['description']
break
return ProcessingResult(
success=True,
data={
'table_comment_generated': table_comment_result['success'],
'field_comments_generated': sum(1 for r in field_results if r['success']),
'enum_fields_detected': len([f for f in table_metadata.fields if f.is_enum]),
'enum_suggestions': enum_suggestions
},
metadata={'tool': self.tool_name}
)
except Exception as e:
self.logger.exception(f"注释生成失败")
return ProcessingResult(
success=False,
error_message=f"注释生成失败: {str(e)}"
)
async def _generate_table_comment(self, table_metadata, business_context: str) -> Dict[str, Any]:
"""生成表注释"""
try:
prompt = self._build_table_comment_prompt(table_metadata, business_context)
# 调用LLM
response = await self._call_llm_with_retry(prompt)
# 解析响应
comment = self._extract_table_comment(response)
return {
'success': True,
'comment': comment,
'original_response': response
}
except Exception as e:
self.logger.error(f"表注释生成失败: {e}")
return {
'success': False,
'comment': table_metadata.original_comment or f"{table_metadata.table_name}表",
'error': str(e)
}
async def _generate_field_comments_and_enums(self, table_metadata, business_context: str) -> List[Dict[str, Any]]:
"""批量生成字段注释和枚举建议"""
try:
# 构建批量处理的提示词
prompt = self._build_field_batch_prompt(table_metadata, business_context)
# 调用LLM
response = await self._call_llm_with_retry(prompt)
# 解析批量响应
field_results = self._parse_field_batch_response(response, table_metadata.fields)
return field_results
except Exception as e:
self.logger.error(f"字段注释批量生成失败: {e}")
# 返回默认结果
return [
{
'success': False,
'comment': field.original_comment or field.name,
'is_enum': False,
'error': str(e)
}
for field in table_metadata.fields
]
def _build_table_comment_prompt(self, table_metadata, business_context: str) -> str:
"""构建表注释生成提示词"""
# 准备字段信息摘要
fields_summary = []
for field in table_metadata.fields[:10]: # 只显示前10个字段避免过长
field_desc = f"- {field.name} ({field.type})"
if field.comment:
field_desc += f": {field.comment}"
fields_summary.append(field_desc)
# 准备样例数据摘要
sample_summary = ""
if table_metadata.sample_data:
sample_count = min(3, len(table_metadata.sample_data))
sample_summary = f"\n样例数据({sample_count}条):\n"
for i, sample in enumerate(table_metadata.sample_data[:sample_count]):
sample_str = ", ".join([f"{k}={v}" for k, v in list(sample.items())[:5]])
sample_summary += f"{i+1}. {sample_str}\n"
prompt = f"""你是一个数据库文档专家。请根据以下信息为数据库表生成简洁、准确的中文注释。
业务背景: {business_context}
{self.business_dictionary}
表信息:
- 表名: {table_metadata.table_name}
- Schema: {table_metadata.schema_name}
- 现有注释: {table_metadata.original_comment or "无"}
- 字段数量: {len(table_metadata.fields)}
- 数据行数: {table_metadata.row_count or "未知"}
主要字段:
{chr(10).join(fields_summary)}
{sample_summary}
请生成一个简洁、准确的中文表注释,要求:
1. 如果现有注释是英文,请翻译为中文并改进
2. 根据字段名称和样例数据推断表的业务用途
3. 注释长度控制在50字以内
4. 突出表的核心业务价值
表注释:"""
return prompt
def _build_field_batch_prompt(self, table_metadata, business_context: str) -> str:
"""构建字段批量处理提示词"""
# 准备字段信息
fields_info = []
sample_values = {}
# 收集字段的样例值
for sample in table_metadata.sample_data[:5]:
for field_name, value in sample.items():
if field_name not in sample_values:
sample_values[field_name] = []
if value is not None and len(sample_values[field_name]) < 5:
sample_values[field_name].append(str(value))
# 构建字段信息列表
for field in table_metadata.fields:
field_info = f"{field.name} ({field.type})"
if field.original_comment:
field_info += f" - 原注释: {field.original_comment}"
# 添加样例值
if field.name in sample_values and sample_values[field.name]:
values_str = ", ".join(sample_values[field.name][:3])
field_info += f" - 样例值: {values_str}"
fields_info.append(field_info)
prompt = f"""你是一个数据库文档专家。请为以下表的所有字段生成中文注释,并识别可能的枚举字段。
业务背景: {business_context}
{self.business_dictionary}
表名: {table_metadata.schema_name}.{table_metadata.table_name}
表注释: {table_metadata.comment or "无"}
字段列表:
{chr(10).join([f"{i+1}. {info}" for i, info in enumerate(fields_info)])}
请按以下JSON格式输出每个字段的分析结果:
```json
{{
"fields": [
{{
"name": "字段名",
"comment": "中文注释(简洁明确,15字以内)",
"is_enum": true/false,
"enum_values": ["值1", "值2", "值3"] (如果是枚举),
"enum_description": "枚举含义说明" (如果是枚举)
}}
]
}}
```

注释生成要求:
请输出JSON格式的分析结果:"""
return prompt
async def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
"""带重试的LLM调用"""
from config import SCHEMA_TOOLS_CONFIG
for attempt in range(max_retries):
try:
# 使用vanna实例调用LLM
response = await asyncio.to_thread(self.vn.ask, prompt)
if response and response.strip():
return response.strip()
else:
raise ValueError("LLM返回空响应")
except Exception as e:
self.logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
if attempt == max_retries - 1:
raise
await asyncio.sleep(1) # 等待1秒后重试
raise Exception("LLM调用达到最大重试次数")
def _extract_table_comment(self, llm_response: str) -> str:
"""从LLM响应中提取表注释"""
# 简单的文本清理和提取逻辑
lines = llm_response.strip().split('\n')
# 查找包含实际注释的行
for line in lines:
line = line.strip()
if line and not line.startswith('#') and not line.startswith('*'):
# 移除可能的前缀
prefixes = ['表注释:', '注释:', '说明:', '表说明:']
for prefix in prefixes:
if line.startswith(prefix):
line = line[len(prefix):].strip()
if line:
return line[:200] # 限制长度
return llm_response.strip()[:200]
def _parse_field_batch_response(self, llm_response: str, fields: List[FieldInfo]) -> List[Dict[str, Any]]:
"""解析字段批量处理响应"""
import json
import re
try:
# 尝试提取JSON部分
json_match = re.search(r'```json\s*(.*?)\s*```', llm_response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# 如果没有代码块,尝试直接解析
json_str = llm_response
# 解析JSON
parsed_data = json.loads(json_str)
field_data = parsed_data.get('fields', [])
# 映射到字段结果
results = []
for i, field in enumerate(fields):
if i < len(field_data):
data = field_data[i]
results.append({
'success': True,
'comment': data.get('comment', field.name),
'is_enum': data.get('is_enum', False),
'enum_values': data.get('enum_values', []),
'enum_description': data.get('enum_description', '')
})
else:
# 默认结果
results.append({
'success': False,
'comment': field.original_comment or field.name,
'is_enum': False
})
return results
except Exception as e:
self.logger.error(f"解析字段批量响应失败: {e}")
# 返回默认结果
return [
{
'success': False,
'comment': field.original_comment or field.name,
'is_enum': False,
'error': str(e)
}
for field in fields
]
async def _validate_enum_suggestions(self, table_metadata, enum_suggestions: List[Dict]) -> List[Dict]:
"""验证枚举建议"""
from tools.database_inspector import DatabaseInspectorTool
from config import SCHEMA_TOOLS_CONFIG
validated_enums = []
inspector = ToolRegistry.get_tool("database_inspector")
sample_limit = SCHEMA_TOOLS_CONFIG["enum_detection_sample_limit"]
for enum_info in enum_suggestions:
field_name = enum_info['field_name']
try:
# 查询字段的所有不同值
query = f"""
SELECT DISTINCT {field_name} as value, COUNT(*) as count
FROM {table_metadata.full_name}
WHERE {field_name} IS NOT NULL
GROUP BY {field_name}
ORDER BY count DESC
LIMIT {sample_limit}
"""
async with inspector.connection_pool.acquire() as conn:
rows = await conn.fetch(query)
actual_values = [str(row['value']) for row in rows]
# 验证是否真的是枚举(不同值数量合理)
max_enum_values = SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]
if len(actual_values) <= max_enum_values:
validated_enums.append({
'field_name': field_name,
'actual_values': actual_values,
'suggested_values': enum_info['suggested_values'],
'description': enum_info['enum_description'],
'value_counts': [(row['value'], row['count']) for row in rows]
})
self.logger.info(f"确认字段 {field_name} 为枚举类型,包含 {len(actual_values)} 个值")
else:
self.logger.info(f"字段 {field_name} 不同值过多({len(actual_values)}),不认为是枚举")
except Exception as e:
self.logger.warning(f"验证字段 {field_name} 的枚举建议失败: {e}")
return validated_enums
def _load_business_dictionary(self) -> str:
"""加载业务词典"""
try:
import os
dict_file = os.path.join(os.path.dirname(__file__), '..', 'prompts', 'business_dictionary.txt')
if os.path.exists(dict_file):
with open(dict_file, 'r', encoding='utf-8') as f:
content = f.read().strip()
return f"\n业务词典:\n{content}\n" if content else ""
return ""
except Exception as e:
self.logger.warning(f"加载业务词典失败: {e}")
return ""
````

## 5. 主AI Agent实现
### 5.1 主Agent核心代码 (`training_data_agent.py`)
```python
import asyncio
import time
import logging
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
from tools.base import ToolRegistry, PipelineExecutor
from utils.data_structures import TableMetadata, TableProcessingContext, ProcessingResult
from utils.file_manager import FileNameManager
from utils.system_filter import SystemTableFilter
from utils.permission_checker import DatabasePermissionChecker
from utils.table_parser import TableListParser
from utils.logger import setup_logging
class SchemaTrainingDataAgent:
"""Schema训练数据生成AI Agent"""
def __init__(self,
db_connection: str,
table_list_file: str,
business_context: str = None,
output_dir: str = None,
pipeline: str = "full"):
self.db_connection = db_connection
self.table_list_file = table_list_file
self.business_context = business_context or "数据库管理系统"
self.pipeline = pipeline
# 配置管理
from config import SCHEMA_TOOLS_CONFIG
self.config = SCHEMA_TOOLS_CONFIG
self.output_dir = output_dir or self.config["output_directory"]
# 初始化组件
self.file_manager = FileNameManager(self.output_dir)
self.system_filter = SystemTableFilter()
self.table_parser = TableListParser()
self.pipeline_executor = PipelineExecutor(self.config["available_pipelines"])
# 统计信息
self.stats = {
'total_tables': 0,
'processed_tables': 0,
'failed_tables': 0,
'skipped_tables': 0,
'start_time': None,
'end_time': None
}
self.failed_tables = []
self.logger = logging.getLogger("schema_tools.Agent")
async def generate_training_data(self) -> Dict[str, Any]:
"""主入口:生成训练数据"""
try:
self.stats['start_time'] = time.time()
self.logger.info("🚀 开始生成Schema训练数据")
# 1. 初始化
await self._initialize()
# 2. 检查数据库权限
await self._check_database_permissions()
# 3. 解析表清单
tables = await self._parse_table_list()
# 4. 过滤系统表
user_tables = self._filter_system_tables(tables)
# 5. 并发处理表
results = await self._process_tables_concurrently(user_tables)
# 6. 生成总结报告
report = self._generate_summary_report(results)
self.stats['end_time'] = time.time()
self.logger.info("✅ Schema训练数据生成完成")
return report
except Exception as e:
self.stats['end_time'] = time.time()
self.logger.exception("❌ Schema训练数据生成失败")
raise
async def _initialize(self):
"""初始化Agent"""
# 创建输出目录
os.makedirs(self.output_dir, exist_ok=True)
if self.config["create_subdirectories"]:
os.makedirs(os.path.join(self.output_dir, "ddl"), exist_ok=True)
os.makedirs(os.path.join(self.output_dir, "docs"), exist_ok=True)
os.makedirs(os.path.join(self.output_dir, "logs"), exist_ok=True)
# 初始化数据库工具
database_tool = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
await database_tool._create_connection_pool()
self.logger.info(f"初始化完成,输出目录: {self.output_dir}")
async def _check_database_permissions(self):
"""检查数据库权限"""
if not self.config["check_permissions"]:
return
inspector = ToolRegistry.get_tool("database_inspector")
checker = DatabasePermissionChecker(inspector)
permissions = await checker.check_permissions()
if not permissions['connect']:
raise Exception("无法连接到数据库")
if self.config["require_select_permission"] and not permissions['select_data']:
if not self.config["allow_readonly_database"]:
raise Exception("数据库查询权限不足")
else:
self.logger.warning("数据库为只读或权限受限,部分功能可能受影响")
self.logger.info(f"数据库权限检查完成: {permissions}")
async def _parse_table_list(self) -> List[str]:
"""解析表清单文件"""
tables = self.table_parser.parse_file(self.table_list_file)
self.stats['total_tables'] = len(tables)
self.logger.info(f"📋 从清单文件读取到 {len(tables)} 个表")
return tables
def _filter_system_tables(self, tables: List[str]) -> List[str]:
"""过滤系统表"""
if not self.config["filter_system_tables"]:
return tables
user_tables = self.system_filter.filter_user_tables(tables)
filtered_count = len(tables) - len(user_tables)
if filtered_count > 0:
self.logger.info(f"🔍 过滤了 {filtered_count} 个系统表,保留 {len(user_tables)} 个用户表")
self.stats['skipped_tables'] += filtered_count
return user_tables
async def _process_tables_concurrently(self, tables: List[str]) -> List[Dict[str, Any]]:
"""并发处理表"""
max_concurrent = self.config["max_concurrent_tables"]
semaphore = asyncio.Semaphore(max_concurrent)
self.logger.info(f"🔄 开始并发处理 {len(tables)} 个表 (最大并发: {max_concurrent})")
# 创建任务
tasks = [
self._process_single_table_with_semaphore(semaphore, table_spec)
for table_spec in tables
]
# 并发执行
results = await asyncio.gather(*tasks, return_exceptions=True)
# 统计结果
successful = sum(1 for r in results if isinstance(r, dict) and r.get('success', False))
failed = len(results) - successful
self.stats['processed_tables'] = successful
self.stats['failed_tables'] = failed
self.logger.info(f"📊 处理完成: 成功 {successful} 个,失败 {failed} 个")
return [r for r in results if isinstance(r, dict)]
async def _process_single_table_with_semaphore(self, semaphore: asyncio.Semaphore, table_spec: str) -> Dict[str, Any]:
"""带信号量的单表处理"""
async with semaphore:
return await self._process_single_table(table_spec)
async def _process_single_table(self, table_spec: str) -> Dict[str, Any]:
"""处理单个表"""
start_time = time.time()
try:
# 解析表名
if '.' in table_spec:
schema_name, table_name = table_spec.split('.', 1)
else:
schema_name, table_name = 'public', table_spec
full_name = f"{schema_name}.{table_name}"
self.logger.info(f"🔍 开始处理表: {full_name}")
# 创建表元数据
table_metadata = TableMetadata(
schema_name=schema_name,
table_name=table_name,
full_name=full_name
)
# 创建处理上下文
context = TableProcessingContext(
table_metadata=table_metadata,
business_context=self.business_context,
output_dir=self.output_dir,
pipeline=self.pipeline,
vn=None, # 将在工具中注入
file_manager=self.file_manager,
start_time=start_time
)
# 执行处理链
step_results = await self.pipeline_executor.execute_pipeline(self.pipeline, context)
# 计算总体成功状态
success = all(result.success for result in step_results.values())
execution_time = time.time() - start_time
if success:
self.logger.info(f"✅ 表 {full_name} 处理成功,耗时: {execution_time:.2f}秒")
else:
self.logger.error(f"❌ 表 {full_name} 处理失败,耗时: {execution_time:.2f}秒")
self.failed_tables.append(full_name)
return {
'success': success,
'table_name': full_name,
'execution_time': execution_time,
'step_results': {k: v.to_dict() for k, v in step_results.items()},
'metadata': {
'fields_count': len(table_metadata.fields),
'row_count': table_metadata.row_count,
'enum_fields': len([f for f in table_metadata.fields if f.is_enum])
}
}
except Exception as e:
execution_time = time.time() - start_time
error_msg = f"表 {table_spec} 处理异常: {str(e)}"
self.logger.exception(error_msg)
self.failed_tables.append(table_spec)
return {
'success': False,
'table_name': table_spec,
'execution_time': execution_time,
'error_message': error_msg,
'step_results': {}
}
def _generate_summary_report(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""生成总结报告"""
total_time = self.stats['end_time'] - self.stats['start_time']
# 计算统计信息
successful_results = [r for r in results if r.get('success', False)]
failed_results = [r for r in results if not r.get('success', False)]
total_fields = sum(r.get('metadata', {}).get('fields_count', 0) for r in successful_results)
total_enum_fields = sum(r.get('metadata', {}).get('enum_fields', 0) for r in successful_results)
avg_execution_time = sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0
report = {
'summary': {
'total_tables': self.stats['total_tables'],
'processed_successfully': len(successful_results),
'failed': len(failed_results),
'skipped_system_tables': self.stats['skipped_tables'],
'total_execution_time': total_time,
'average_table_time': avg_execution_time
},
'statistics': {
'total_fields_processed': total_fields,
'enum_fields_detected': total_enum_fields,
'files_generated': len(successful_results) * (2 if self.pipeline == 'full' else 1)
},
'failed_tables': self.failed_tables,
'detailed_results': results,
'configuration': {
'pipeline': self.pipeline,
'business_context': self.business_context,
'output_directory': self.output_dir,
'max_concurrent_tables': self.config['max_concurrent_tables']
}
}
# 输出总结
self.logger.info(f"📊 处理总结:")
self.logger.info(f" ✅ 成功: {report['summary']['processed_successfully']} 个表")
self.logger.info(f" ❌ 失败: {report['summary']['failed']} 个表")
self.logger.info(f" ⏭️ 跳过: {report['summary']['skipped_system_tables']} 个系统表")
self.logger.info(f" 📁 生成文件: {report['statistics']['files_generated']} 个")
self.logger.info(f" 🕐 总耗时: {total_time:.2f} 秒")
if self.failed_tables:
self.logger.warning(f"❌ 失败的表: {', '.join(self.failed_tables)}")
return report
async def check_database_permissions(self) -> Dict[str, bool]:
"""检查数据库权限(供外部调用)"""
inspector = ToolRegistry.get_tool("database_inspector", db_connection=self.db_connection)
checker = DatabasePermissionChecker(inspector)
return await checker.check_permissions()
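
# 编程方式调用示意(参数均为示例值,命令行用法见 __main__.py):
# agent = SchemaTrainingDataAgent(
#     db_connection="postgresql://user:pass@localhost:5432/dbname",
#     table_list_file="tables.txt",
#     business_context="高速公路服务区管理系统",
#     pipeline="full"
# )
# report = asyncio.run(agent.generate_training_data())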
```

`__main__.py`

```python
import argparse
import asyncio
import sys
import os
import logging
from pathlib import Path
def setup_argument_parser():
"""设置命令行参数解析器"""
parser = argparse.ArgumentParser(
description='Schema Tools - 自动生成数据库训练数据',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例用法:
# 基本使用
python -m schema_tools --db-connection "postgresql://user:pass@host:5432/db" --table-list tables.txt
# 指定业务上下文和输出目录
python -m schema_tools --db-connection "..." --table-list tables.txt --business-context "电商系统" --output-dir output
# 仅生成DDL文件
python -m schema_tools --db-connection "..." --table-list tables.txt --pipeline ddl_only
# 权限检查模式
python -m schema_tools --db-connection "..." --check-permissions-only
"""
)
# 必需参数
parser.add_argument(
'--db-connection',
required=True,
help='数据库连接字符串 (例如: postgresql://user:pass@localhost:5432/dbname)'
)
# 可选参数
parser.add_argument(
'--table-list',
help='表清单文件路径'
)
parser.add_argument(
'--business-context',
help='业务上下文描述'
)
parser.add_argument(
'--business-context-file',
help='业务上下文文件路径'
)
parser.add_argument(
'--output-dir',
help='输出目录路径'
)
parser.add_argument(
'--pipeline',
choices=['full', 'ddl_only', 'analysis_only'],
help='处理链类型'
)
parser.add_argument(
'--max-concurrent',
type=int,
help='最大并发表数量'
)
# 功能开关
parser.add_argument(
'--no-filter-system-tables',
action='store_true',
help='禁用系统表过滤'
)
parser.add_argument(
'--check-permissions-only',
action='store_true',
help='仅检查数据库权限,不处理表'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='启用详细日志输出'
)
parser.add_argument(
'--log-file',
help='日志文件路径'
)
return parser
def load_config_with_overrides(args):
"""加载配置并应用命令行覆盖"""
from config import SCHEMA_TOOLS_CONFIG
config = SCHEMA_TOOLS_CONFIG.copy()
# 命令行参数覆盖配置
if args.output_dir:
config["output_directory"] = args.output_dir
if args.pipeline:
config["default_pipeline"] = args.pipeline
if args.max_concurrent:
config["max_concurrent_tables"] = args.max_concurrent
if args.no_filter_system_tables:
config["filter_system_tables"] = False
if args.log_file:
config["log_file"] = args.log_file
return config
def load_business_context(args):
"""加载业务上下文"""
if args.business_context_file:
try:
with open(args.business_context_file, 'r', encoding='utf-8') as f:
return f.read().strip()
except Exception as e:
print(f"警告: 无法读取业务上下文文件 {args.business_context_file}: {e}")
if args.business_context:
return args.business_context
from config import SCHEMA_TOOLS_CONFIG
return SCHEMA_TOOLS_CONFIG.get("default_business_context", "数据库管理系统")
async def check_permissions_only(db_connection: str):
"""仅检查数据库权限"""
from training_data_agent import SchemaTrainingDataAgent
print("🔍 检查数据库权限...")
try:
agent = SchemaTrainingDataAgent(
db_connection=db_connection,
table_list_file="", # 不需要表清单
business_context="" # 不需要业务上下文
)
# 初始化Agent以建立数据库连接
await agent._initialize()
# 检查权限
permissions = await agent.check_database_permissions()
print("\n📋 权限检查结果:")
print(f" ✅ 数据库连接: {'可用' if permissions['connect'] else '不可用'}")
print(f" ✅ 元数据查询: {'可用' if permissions['select_metadata'] else '不可用'}")
print(f" ✅ 数据查询: {'可用' if permissions['select_data'] else '不可用'}")
print(f" ℹ️ 数据库类型: {'只读' if permissions['is_readonly'] else '读写'}")
if all(permissions[k] for k in ('connect', 'select_metadata', 'select_data')):  # is_readonly 仅描述读写属性,不参与通过性判断
print("\n✅ 数据库权限检查通过,可以开始处理")
return True
else:
print("\n❌ 数据库权限不足,请检查配置")
return False
except Exception as e:
print(f"\n❌ 权限检查失败: {e}")
return False
async def main():
"""主入口函数"""
parser = setup_argument_parser()
args = parser.parse_args()
# 设置日志
from utils.logger import setup_logging
setup_logging(
verbose=args.verbose,
log_file=args.log_file
)
# 仅权限检查模式
if args.check_permissions_only:
success = await check_permissions_only(args.db_connection)
sys.exit(0 if success else 1)
# 验证必需参数
if not args.table_list:
print("错误: 需要指定 --table-list 参数")
parser.print_help()
sys.exit(1)
if not os.path.exists(args.table_list):
print(f"错误: 表清单文件不存在: {args.table_list}")
sys.exit(1)
try:
# 加载配置和业务上下文
config = load_config_with_overrides(args)
business_context = load_business_context(args)
# 创建Agent
from training_data_agent import SchemaTrainingDataAgent
agent = SchemaTrainingDataAgent(
db_connection=args.db_connection,
table_list_file=args.table_list,
business_context=business_context,
output_dir=config["output_directory"],
pipeline=config["default_pipeline"]
)
# 执行生成
print("🚀 开始生成Schema训练数据...")
report = await agent.generate_training_data()
# 输出结果
if report['summary']['failed'] == 0:
print("\n🎉 所有表处理成功!")
else:
print(f"\n⚠️ 处理完成,但有 {report['summary']['failed']} 个表失败")
print(f"📁 输出目录: {config['output_directory']}")
# 如果有失败的表,返回非零退出码
sys.exit(1 if report['summary']['failed'] > 0 else 0)
except KeyboardInterrupt:
print("\n\n⏹️ 用户中断,程序退出")
sys.exit(130)
except Exception as e:
print(f"\n❌ 程序执行失败: {e}")
if args.verbose:
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())
```

`bss_service_area.ddl`

```sql
-- 中文名: 服务区基础信息表
-- 描述: 记录高速公路服务区的基础属性,包括服务区编码、名称、方向、公司归属、地理位置、服务类型和状态,是业务分析与服务区定位的核心表。
create table bss_service_area (
id varchar(32) not null, -- 服务区唯一标识(主键,UUID格式)
version integer not null, -- 数据版本号
create_ts timestamp(3), -- 创建时间
created_by varchar(50), -- 创建人
update_ts timestamp(3), -- 更新时间
updated_by varchar(50), -- 更新人
delete_ts timestamp(3), -- 删除时间
deleted_by varchar(50), -- 删除人
service_area_name varchar(255), -- 服务区名称
service_area_no varchar(255), -- 服务区编码(业务唯一标识)
company_id varchar(32), -- 公司ID(外键关联bss_company.id)
service_position varchar(255), -- 经纬度坐标
service_area_type varchar(50), -- 服务区类型(枚举:信息化服务区、智能化服务区)
service_state varchar(50), -- 服务区状态(枚举:开放、关闭、上传数据)
primary key (id)
);
```

`bss_service_area_detail.md`

```markdown
## bss_service_area(服务区基础信息表)
bss_service_area 表记录高速公路服务区的基础属性,包括服务区编码、名称、方向、公司归属、地理位置、服务类型和状态,是业务分析与服务区定位的核心表。
字段列表:
- id (varchar(32)) - 服务区唯一标识(主键,UUID格式)[示例: 0271d68ef93de9684b7ad8c7aae600b6]
- version (integer) - 数据版本号 [示例: 3]
- create_ts (timestamp(3)) - 创建时间 [示例: 2021-05-21 13:26:40.589]
- created_by (varchar(50)) - 创建人 [示例: admin]
- update_ts (timestamp(3)) - 更新时间 [示例: 2021-07-10 15:41:28.795]
- updated_by (varchar(50)) - 更新人 [示例: admin]
- delete_ts (timestamp(3)) - 删除时间
- deleted_by (varchar(50)) - 删除人
- service_area_name (varchar(255)) - 服务区名称 [示例: 鄱阳湖服务区]
- service_area_no (varchar(255)) - 服务区编码(业务唯一标识)[示例: H0509]
- company_id (varchar(32)) - 公司ID(外键关联bss_company.id)[示例: b1629f07c8d9ac81494fbc1de61f1ea5]
- service_position (varchar(255)) - 经纬度坐标 [示例: 114.574721,26.825584]
- service_area_type (varchar(50)) - 服务区类型(枚举:信息化服务区、智能化服务区)[示例: 信息化服务区]
- service_state (varchar(50)) - 服务区状态(枚举:开放、关闭、上传数据)[示例: 开放]
字段补充说明:
- id 为主键,使用 UUID 编码,唯一标识每个服务区
- company_id 外键关联服务区管理公司表(bss_company.id)
- service_position 经纬度格式为"经度,纬度"
- service_area_type 为枚举字段,包含两个取值:信息化服务区、智能化服务区
- service_state 为枚举字段,包含三个取值:开放、关闭、上传数据
- 本表是多个表(bss_branch, bss_car_day_count等)的核心关联实体
```

`bss_business_day_data.ddl`

```sql
-- 中文名: 档口日营业数据表
-- 描述: 记录每天每个档口的营业情况,包含微信、支付宝、现金、金豆等支付方式的金额与订单数,是核心交易数据表。
create table bss_business_day_data (
id varchar(32) not null, -- 主键ID
version integer not null, -- 数据版本号
create_ts timestamp(3), -- 创建时间
created_by varchar(50), -- 创建人
update_ts timestamp(3), -- 更新时间
updated_by varchar(50), -- 更新人
delete_ts timestamp(3), -- 删除时间
deleted_by varchar(50), -- 删除人
oper_date date, -- 统计日期
service_no varchar(255), -- 服务区编码
service_name varchar(255), -- 服务区名称
branch_no varchar(255), -- 档口编码
branch_name varchar(255), -- 档口名称
wx numeric(19,4), -- 微信支付金额
wx_order integer, -- 微信支付订单数量
zfb numeric(19,4), -- 支付宝支付金额
zf_order integer, -- 支付宝支付订单数量
rmb numeric(19,4), -- 现金支付金额
rmb_order integer, -- 现金支付订单数量
xs numeric(19,4), -- 行吧支付金额
xs_order integer, -- 行吧支付订单数量
jd numeric(19,4), -- 金豆支付金额
jd_order integer, -- 金豆支付订单数量
order_sum integer, -- 订单总数
pay_sum numeric(19,4), -- 支付总金额
source_type integer, -- 数据来源类型ID
primary key (id)
);
```

`config.py`

```python
import os
import sys
# 导入app_config获取数据库等配置
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
import app_config
except ImportError:
app_config = None
# Schema Tools专用配置
SCHEMA_TOOLS_CONFIG = {
# 核心配置
"default_db_connection": None, # 从命令行指定
"default_business_context": "数据库管理系统",
"output_directory": "training/generated_data",
# 处理链配置
"default_pipeline": "full",
"available_pipelines": {
"full": [
"database_inspector",
"data_sampler",
"comment_generator",
"ddl_generator",
"doc_generator"
],
"ddl_only": [
"database_inspector",
"data_sampler",
"comment_generator",
"ddl_generator"
],
"analysis_only": [
"database_inspector",
"data_sampler",
"comment_generator"
]
},
# 数据处理配置
"sample_data_limit": 20, # 用于LLM分析的采样数据量
"enum_detection_sample_limit": 5000, # 枚举检测时的采样限制
"enum_max_distinct_values": 20, # 枚举字段最大不同值数量
"enum_varchar_keywords": [ # VARCHAR枚举关键词
"性别", "gender", "状态", "status", "类型", "type",
"级别", "level", "方向", "direction", "品类", "classify",
"模式", "mode", "格式", "format"
],
"large_table_threshold": 1000000, # 大表阈值(行数)
# 并发配置
"max_concurrent_tables": 1, # 建议保持1,避免LLM并发调用问题 # 最大并发处理表数
# LLM配置
"use_app_config_llm": True, # 是否使用app_config中的LLM配置
"comment_generation_timeout": 30, # LLM调用超时时间(秒)
"max_llm_retries": 3, # LLM调用最大重试次数
# 系统表过滤配置
"filter_system_tables": True, # 是否过滤系统表
"custom_system_prefixes": [], # 用户自定义系统表前缀
"custom_system_schemas": [], # 用户自定义系统schema
# 权限与安全配置
"check_permissions": True, # 是否检查数据库权限
"require_select_permission": True, # 是否要求SELECT权限
"allow_readonly_database": True, # 是否允许只读数据库
# 错误处理配置
"continue_on_error": True, # 遇到错误是否继续
"max_table_failures": 5, # 最大允许失败表数
"skip_large_tables": False, # 是否跳过超大表
"max_table_size": 10000000, # 最大表行数限制
# 文件配置
"ddl_file_suffix": ".ddl",
"doc_file_suffix": "_detail.md",
"log_file": "schema_tools.log",
"create_subdirectories": True, # 是否创建ddl/docs子目录
# 输出格式配置
"include_sample_data_in_comments": True, # 注释中是否包含示例数据
"max_comment_length": 500, # 最大注释长度
"include_field_statistics": True, # 是否包含字段统计信息
# 调试配置
"debug_mode": False, # 调试模式
"save_llm_prompts": False, # 是否保存LLM提示词
"save_llm_responses": False, # 是否保存LLM响应
}
# 从app_config获取相关配置(如果可用)
if app_config:
# 继承数据库配置
if hasattr(app_config, 'PGVECTOR_CONFIG'):
pgvector_config = app_config.PGVECTOR_CONFIG
if not SCHEMA_TOOLS_CONFIG["default_db_connection"]:
SCHEMA_TOOLS_CONFIG["default_db_connection"] = (
f"postgresql://{pgvector_config['user']}:{pgvector_config['password']}"
f"@{pgvector_config['host']}:{pgvector_config['port']}/{pgvector_config['dbname']}"
)
def get_config():
"""获取当前配置"""
return SCHEMA_TOOLS_CONFIG
def update_config(**kwargs):
"""更新配置"""
SCHEMA_TOOLS_CONFIG.update(kwargs)
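
# 用法示意:
# update_config(output_directory="training/output_v2", max_concurrent_tables=1)
# assert get_config()["output_directory"] == "training/output_v2"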
def validate_config():
"""验证配置有效性"""
errors = []
# 检查必要配置
if SCHEMA_TOOLS_CONFIG["max_concurrent_tables"] <= 0:
errors.append("max_concurrent_tables 必须大于0")
if SCHEMA_TOOLS_CONFIG["sample_data_limit"] <= 0:
errors.append("sample_data_limit 必须大于0")
# 检查处理链配置
default_pipeline = SCHEMA_TOOLS_CONFIG["default_pipeline"]
available_pipelines = SCHEMA_TOOLS_CONFIG["available_pipelines"]
if default_pipeline not in available_pipelines:
errors.append(f"default_pipeline '{default_pipeline}' 不在 available_pipelines 中")
if errors:
raise ValueError("配置验证失败:\n" + "\n".join(f" - {error}" for error in errors))
return True
# 启动时验证配置
try:
validate_config()
except ValueError as e:
print(f"警告: {e}")
```

这个详细设计文档涵盖了Schema Tools从数据库元数据采集、注释生成到DDL/MD文档输出的完整实现。
Question-SQL生成功能是Schema Tools的扩展模块,用于从已生成的DDL和MD文件自动生成高质量的Question-SQL训练数据对。该功能可以独立运行,支持人工检查DDL/MD文件后再执行。
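
下面给出一个以编程方式调用该功能的最小示意(构造参数与调用方式基于后文 qs_agent.py 节选,具体取值仅为示例):

```python
import asyncio
from qs_agent import QuestionSQLGenerationAgent

async def run_qs_generation():
    # output_dir 指向已生成DDL/MD文件的目录,便于人工检查后再执行
    agent = QuestionSQLGenerationAgent(
        output_dir="training/generated_data",
        table_list_file="tables.txt",
        business_context="高速公路服务区管理系统",
        db_name="highway_db"
    )
    report = await agent.generate()
    print(report)

asyncio.run(run_qs_generation())
```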
`qs_agent.py`

```python
class QuestionSQLGenerationAgent:
"""Question-SQL生成Agent"""
def __init__(self,
output_dir: str,
table_list_file: str,
business_context: str,
db_name: str = None):
"""
初始化Agent
Args:
output_dir: 输出目录(包含DDL和MD文件)
table_list_file: 表清单文件路径
business_context: 业务上下文
db_name: 数据库名称(用于输出文件命名)
"""
self.output_dir = Path(output_dir)
self.table_list_file = table_list_file
self.business_context = business_context
self.db_name = db_name or "db"
# 初始化组件
self.validator = FileCountValidator()
self.md_analyzer = MDFileAnalyzer(output_dir)
self.theme_extractor = None # 延迟初始化
# 中间结果存储
self.intermediate_results = []
self.intermediate_file = None
```

`validators/file_count_validator.py`

```python
@dataclass
class ValidationResult:
"""验证结果"""
is_valid: bool
table_count: int
ddl_count: int
md_count: int
error: str = ""
missing_ddl: List[str] = field(default_factory=list)
missing_md: List[str] = field(default_factory=list)
class FileCountValidator:
"""文件数量验证器"""
def validate(self, table_list_file: str, output_dir: str) -> ValidationResult:
"""
验证生成的文件数量是否与表数量一致
主要验证:
1. 表数量是否超过20个限制
2. DDL文件数量是否与表数量一致
3. MD文件数量是否与表数量一致
"""
# 解析表清单
tables = self.table_parser.parse_file(table_list_file)
table_count = len(tables)
# 检查表数量限制
max_tables = self.config['qs_generation']['max_tables']
if table_count > max_tables:
return ValidationResult(
is_valid=False,
table_count=table_count,
ddl_count=0,
md_count=0,
error=f"表数量({table_count})超过限制({max_tables})"
)
```

`analyzers/md_analyzer.py`

```python
class MDFileAnalyzer:
"""MD文件分析器"""
async def read_all_md_files(self) -> str:
"""
读取所有MD文件的完整内容
Returns:
所有MD文件内容的组合字符串
"""
md_files = sorted(self.output_dir.glob("*_detail.md"))
all_contents = []
all_contents.append(f"# 数据库表结构文档汇总\n")
all_contents.append(f"共包含 {len(md_files)} 个表\n\n")
for md_file in md_files:
content = md_file.read_text(encoding='utf-8')
# 添加分隔符,便于LLM区分不同表
all_contents.append("=" * 80)
all_contents.append(f"# 文件: {md_file.name}")
all_contents.append("=" * 80)
all_contents.append(content)
all_contents.append("\n")
combined_content = "\n".join(all_contents)
# 检查内容大小(预估token数)
estimated_tokens = len(combined_content) / 4
if estimated_tokens > 100000:
self.logger.warning(f"MD内容可能过大,预估tokens: {estimated_tokens:.0f}")
return combined_content
```

`analyzers/theme_extractor.py`

````python
class ThemeExtractor:
"""主题提取器"""
async def extract_themes(self, md_contents: str) -> List[Dict[str, Any]]:
"""
从MD内容中提取分析主题
"""
prompt = f"""你是一位经验丰富的业务数据分析师,正在分析{self.business_context}的数据库。
以下是数据库中所有表的详细结构说明:
{md_contents}
基于对这些表结构的理解,请从业务分析的角度提出 {theme_count} 个数据查询分析主题。
要求:
1. 每个主题应该有明确的业务价值和分析目标
2. 主题之间应该有所区别,覆盖不同的业务领域
3. 你需要自行决定每个主题应该涉及哪些表
4. 主题应该体现实际业务场景的数据分析需求
5. 考虑时间维度、对比分析、排名统计等多种分析角度
请以JSON格式输出:
```json
{{
"themes": [
{{
"name": "经营收入分析",
"description": "分析服务区的营业收入情况,包括日收入趋势、月度对比、服务区排名等",
"focus_areas": ["收入趋势", "服务区对比", "时间维度分析"],
"related_tables": ["bss_business_day_data", "其他相关表名"]
}}
]
}}
```"""
response = await self._call_llm(prompt)
themes = self._parse_theme_response(response)
return themes
````

`qs_agent.py`(的一部分:QuestionSQLGenerationAgent 的核心流程方法)

````python
async def generate(self) -> Dict[str, Any]:
"""生成Question-SQL对"""
# 1. 验证文件数量
validation_result = self.validator.validate(self.table_list_file, str(self.output_dir))
if not validation_result.is_valid:
raise ValueError(f"文件验证失败: {validation_result.error}")
# 2. 读取所有MD文件内容
md_contents = await self.md_analyzer.read_all_md_files()
# 3. 初始化LLM组件
self._initialize_llm_components()
# 4. 提取分析主题
themes = await self.theme_extractor.extract_themes(md_contents)
# 5. 初始化中间结果文件
self._init_intermediate_file()
# 6. 处理每个主题
if self.config['qs_generation']['max_concurrent_themes'] > 1:
results = await self._process_themes_parallel(themes, md_contents)
else:
results = await self._process_themes_serial(themes, md_contents)
# 7. 保存最终结果
# 假设 results 为各主题处理结果(_process_single_theme 返回值)的列表
all_qs_pairs = [pair for r in results if isinstance(r, dict) and r.get('success') for pair in r.get('qs_pairs', [])]
output_file = await self._save_final_results(all_qs_pairs)
# 8. 返回生成报告(字段仅为示意,实际以实现为准)
report = {'total_themes': len(themes), 'total_questions': len(all_qs_pairs), 'output_file': str(output_file)}
return report
async def _process_single_theme(self, theme: Dict, md_contents: str) -> Dict:
"""处理单个主题"""
prompt = f"""你是一位业务数据分析师,正在为{self.business_context}设计数据查询。
当前分析主题:{theme['name']}
主题描述:{theme['description']}
关注领域:{', '.join(theme['focus_areas'])}
相关表:{', '.join(theme['related_tables'])}
数据库表结构信息:
{md_contents}
请为这个主题生成 {questions_count} 个业务问题和对应的SQL查询。
要求:
1. 问题应该从业务角度出发,贴合主题要求,具有实际分析价值
2. SQL必须使用PostgreSQL语法
3. 考虑实际业务逻辑(如软删除使用 delete_ts IS NULL 条件)
4. 使用中文别名提高可读性(使用 AS 指定列别名)
5. 问题应该多样化,覆盖不同的分析角度
6. 包含时间筛选、分组统计、排序、限制等不同类型的查询
7. SQL语句末尾必须以分号结束
输出JSON格式:
```json
[
{{
"question": "具体的业务问题?",
"sql": "SELECT column AS 中文名 FROM table WHERE condition;"
}}
]
```"""
response = await self._call_llm(prompt)
qs_pairs = self._parse_qs_response(response)
validated_pairs = self._validate_qs_pairs(qs_pairs, theme['name'])
# 保存中间结果
await self._save_theme_results(theme['name'], validated_pairs)
return {
'success': True,
'theme_name': theme['name'],
'qs_pairs': validated_pairs
}
def _init_intermediate_file(self):
"""初始化中间结果文件"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
self.intermediate_file = self.output_dir / f"qs_intermediate_{timestamp}.json"
self.intermediate_results = []
async def _save_theme_results(self, theme_name: str, qs_pairs: List[Dict]):
"""保存单个主题的结果"""
theme_result = {
"theme": theme_name,
"timestamp": datetime.now().isoformat(),
"questions_count": len(qs_pairs),
"questions": qs_pairs
}
self.intermediate_results.append(theme_result)
# 立即保存到中间文件
if self.config['qs_generation']['save_intermediate']:
with open(self.intermediate_file, 'w', encoding='utf-8') as f:
json.dump(self.intermediate_results, f, ensure_ascii=False, indent=2)
````

`validators/sql_validator.py`

```python
@dataclass
class SQLValidationResult:
"""SQL验证结果"""
sql: str
valid: bool
error_message: str = ""
execution_time: float = 0.0
retry_count: int = 0
# SQL修复相关字段
repair_attempted: bool = False
repair_successful: bool = False
repaired_sql: str = ""
repair_error: str = ""
@dataclass
class ValidationStats:
"""验证统计信息"""
total_sqls: int = 0
valid_sqls: int = 0
invalid_sqls: int = 0
total_time: float = 0.0
avg_time_per_sql: float = 0.0
retry_count: int = 0
# SQL修复统计
repair_attempted: int = 0
repair_successful: int = 0
repair_failed: int = 0
class SQLValidator:
"""SQL验证器核心类"""
def __init__(self, db_connection: str = None):
self.db_connection = db_connection
self.connection_pool = None
self.config = SCHEMA_TOOLS_CONFIG['sql_validation']
async def validate_sql(self, sql: str, retry_count: int = 0) -> SQLValidationResult:
"""验证单个SQL语句"""
start_time = time.time()
try:
if not self.connection_pool:
await self._get_connection_pool()
# 使用EXPLAIN验证SQL语法和表结构
explain_sql = f"EXPLAIN {sql}"
async with self.connection_pool.acquire() as conn:
# 设置只读模式
if self.config['readonly_mode']:
await conn.execute("SET default_transaction_read_only = on")
# 执行EXPLAIN
await asyncio.wait_for(
conn.fetch(explain_sql),
timeout=self.config['validation_timeout']
)
execution_time = time.time() - start_time
return SQLValidationResult(
sql=sql,
valid=True,
execution_time=execution_time,
retry_count=retry_count
)
except asyncio.TimeoutError:
execution_time = time.time() - start_time
error_msg = f"SQL验证超时({self.config['validation_timeout']}秒)"
return SQLValidationResult(
sql=sql,
valid=False,
error_message=error_msg,
execution_time=execution_time,
retry_count=retry_count
)
except Exception as e:
execution_time = time.time() - start_time
error_msg = str(e)
# 检查是否需要重试
if retry_count < self.config['max_retry_count'] and self._should_retry(e):
await asyncio.sleep(0.5) # 短暂延迟
return await self.validate_sql(sql, retry_count + 1)
return SQLValidationResult(
sql=sql,
valid=False,
error_message=error_msg,
execution_time=execution_time,
retry_count=retry_count
)
async def validate_sqls_batch(self, sqls: List[str]) -> List[SQLValidationResult]:
"""批量验证SQL语句"""
max_concurrent = self.config['max_concurrent_validations']
semaphore = asyncio.Semaphore(max_concurrent)
async def validate_with_semaphore(sql):
async with semaphore:
return await self.validate_sql(sql)
# 并发执行验证
tasks = [validate_with_semaphore(sql) for sql in sqls]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理异常结果
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
processed_results.append(SQLValidationResult(
sql=sqls[i],
valid=False,
error_message=f"验证异常: {str(result)}"
))
else:
processed_results.append(result)
return processed_results
def calculate_stats(self, results: List[SQLValidationResult]) -> ValidationStats:
"""计算验证统计信息"""
stats = ValidationStats()
stats.total_sqls = len(results)
stats.valid_sqls = sum(1 for r in results if r.valid)
stats.invalid_sqls = stats.total_sqls - stats.valid_sqls
stats.total_time = sum(r.execution_time for r in results)
stats.avg_time_per_sql = stats.total_time / stats.total_sqls if stats.total_sqls > 0 else 0.0
stats.retry_count = sum(r.retry_count for r in results)
# 修复统计
stats.repair_attempted = sum(1 for r in results if r.repair_attempted)
stats.repair_successful = sum(1 for r in results if r.repair_successful)
stats.repair_failed = stats.repair_attempted - stats.repair_successful
return stats
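
# 用法示意(连接串为示例值,需在 SCHEMA_TOOLS_CONFIG['sql_validation'] 中提供相应配置):
# validator = SQLValidator(db_connection="postgresql://user:pass@localhost:5432/dbname")
# results = await validator.validate_sqls_batch(["SELECT 1;", "SELECT * FROM bss_service_area LIMIT 10;"])
# stats = validator.calculate_stats(results)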
```

`sql_validation_agent.py`

```python
class SQLValidationAgent:
"""SQL验证Agent - 管理SQL验证的完整流程"""
async def validate(self) -> Dict[str, Any]:
"""执行SQL验证流程"""
# 1. 读取输入文件
questions_sqls = await self._load_questions_sqls()
# 2. 提取SQL语句
sqls = [item['sql'] for item in questions_sqls]
# 3. 执行验证
validation_results = await self._validate_sqls_with_batching(sqls)
# 4. 计算统计信息
stats = self.validator.calculate_stats(validation_results)
# 5. 尝试修复失败的SQL(如果启用LLM修复)
if self.config.get('enable_sql_repair', False) and self.vn:
validation_results = await self._attempt_sql_repair(questions_sqls, validation_results)
stats = self.validator.calculate_stats(validation_results)
# 6. 修改原始JSON文件(如果启用文件修改)
file_modification_stats = {'modified': 0, 'deleted': 0, 'failed_modifications': 0}
if self.config.get('modify_original_file', False):
file_modification_stats = await self._modify_original_json_file(questions_sqls, validation_results)
# 7. 生成详细报告
report = await self._generate_report(questions_sqls, validation_results, stats, file_modification_stats)
# 8. 保存验证报告
if self.config['save_validation_report']:
await self._save_validation_report(report)
return report
async def _attempt_sql_repair(self, questions_sqls: List[Dict], validation_results: List[SQLValidationResult]) -> List[SQLValidationResult]:
"""尝试修复失败的SQL"""
failed_indices = [i for i, result in enumerate(validation_results) if not result.valid]
if not failed_indices:
return validation_results
# 批量修复
batch_size = self.config.get('repair_batch_size', 5)
updated_results = validation_results.copy()
for i in range(0, len(failed_indices), batch_size):
batch_indices = failed_indices[i:i + batch_size]
# 准备批次数据
batch_data = []
for idx in batch_indices:
batch_data.append({
'index': idx,
'question': questions_sqls[idx]['question'],
'sql': validation_results[idx].sql,
'error': validation_results[idx].error_message
})
# 调用LLM修复
repaired_sqls = await self._repair_sqls_with_llm(batch_data)
# 验证修复后的SQL
for j, idx in enumerate(batch_indices):
original_result = updated_results[idx]
original_result.repair_attempted = True
if j < len(repaired_sqls) and repaired_sqls[j]:
repaired_sql = repaired_sqls[j]
# 验证修复后的SQL
repair_result = await self.validator.validate_sql(repaired_sql)
if repair_result.valid:
# 修复成功
original_result.repair_successful = True
original_result.repaired_sql = repaired_sql
original_result.valid = True # 更新为有效
else:
# 修复失败
original_result.repair_successful = False
original_result.repair_error = repair_result.error_message
else:
# LLM修复失败
original_result.repair_successful = False
original_result.repair_error = "LLM修复失败或返回空结果"
return updated_results
async def _modify_original_json_file(self, questions_sqls: List[Dict], validation_results: List[SQLValidationResult]) -> Dict[str, int]:
"""修改原始JSON文件"""
stats = {'modified': 0, 'deleted': 0, 'failed_modifications': 0}
try:
# 读取原始JSON文件
with open(self.input_file, 'r', encoding='utf-8') as f:
original_data = json.load(f)
# 创建备份文件
backup_file = Path(str(self.input_file) + '.backup')
with open(backup_file, 'w', encoding='utf-8') as f:
json.dump(original_data, f, ensure_ascii=False, indent=2)
# 构建修改计划
modifications = []
deletions = []
for i, (qs, result) in enumerate(zip(questions_sqls, validation_results)):
if result.repair_successful and result.repaired_sql:
# 修复成功的SQL
modifications.append({
'index': i,
'original_sql': result.sql,
'repaired_sql': result.repaired_sql,
'question': qs['question']
})
elif not result.valid and not result.repair_successful:
# 无法修复的SQL,标记删除
deletions.append({
'index': i,
'question': qs['question'],
'sql': result.sql,
'error': result.error_message
})
# 执行修改(从后往前,避免索引变化)
new_data = original_data.copy()
# 先删除无效项(从后往前删除)
for deletion in sorted(deletions, key=lambda x: x['index'], reverse=True):
if deletion['index'] < len(new_data):
new_data.pop(deletion['index'])
stats['deleted'] += 1
# 再修改SQL(需要重新计算索引)
for modification in sorted(modifications, key=lambda x: x['index']):
# 计算删除后的新索引
new_index = modification['index']
for deletion in deletions:
if deletion['index'] < modification['index']:
new_index -= 1
if new_index < len(new_data):
new_data[new_index]['sql'] = modification['repaired_sql']
stats['modified'] += 1
# 写入修改后的文件
with open(self.input_file, 'w', encoding='utf-8') as f:
json.dump(new_data, f, ensure_ascii=False, indent=2)
# 记录详细修改信息到日志文件
await self._write_modification_log(modifications, deletions)
except Exception as e:
stats['failed_modifications'] = 1
return stats
```

`schema_workflow_orchestrator.py`

```python
class SchemaWorkflowOrchestrator:
"""端到端的Schema处理编排器"""
async def execute_complete_workflow(self) -> Dict[str, Any]:
"""执行完整的Schema处理工作流程"""
# 步骤1: 生成DDL和MD文件
await self._execute_step_1_ddl_md_generation()
# 步骤2: 生成Question-SQL对
await self._execute_step_2_question_sql_generation()
# 步骤3: 验证和修正SQL(可选)
if self.enable_sql_validation:
await self._execute_step_3_sql_validation()
# 生成最终报告
final_report = await self._generate_final_report()
return final_report
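
# 用法示意(构造参数未在本节列出,以 schema_workflow_orchestrator.py 实际实现为准):
# orchestrator = SchemaWorkflowOrchestrator(...)
# report = await orchestrator.execute_complete_workflow()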
```

这样,文档就与当前代码完全一致了,包含了所有新增的SQL验证、LLM修复、文件修改等功能的详细设计说明。