import os
import re
from typing import List, Dict, Any

from schema_tools.tools.base import BaseTool, ToolRegistry
from schema_tools.utils.data_structures import (
    ProcessingResult,
    TableProcessingContext,
    FieldInfo,
    TableMetadata,
)
from schema_tools.config import SCHEMA_TOOLS_CONFIG

# Matches a trailing "(NN字)" character-count marker in a table comment,
# written with either ASCII or full-width parentheses.
_CHAR_COUNT_PATTERN = re.compile(r'[(（]\d+字[)）]')

# Matches an ASCII or a full-width comma.
_COMMA_PATTERN = re.compile(r'[,，]')

# Default number of distinct sample values collected per field.
_DEFAULT_SAMPLE_LIMIT = 3
# Number of sample values shown on a field line for non-enum fields.
_NON_ENUM_SAMPLE_DISPLAY = 2
# Upper bound on values shown for enum fields, and on the number of distinct
# values a field may have to still be reported as a suspected enum.
_ENUM_SAMPLE_LIMIT = 10

# Keywords that hint a field is a "logical" enum.
# NOTE(review): these are Chinese words but are matched against the lowercased
# field *name*; confirm whether the field comment was meant to be checked.
_LOGICAL_ENUM_KEYWORDS = ("状态", "类型", "级别", "方向", "品类", "模式", "格式", "性别")


@ToolRegistry.register("doc_generator")
class DocGeneratorTool(BaseTool):
    """Tool that renders a Markdown description of a table and writes it to disk."""

    needs_llm = False
    tool_name = "文档生成器"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """Generate the Markdown document for ``context.table_metadata``.

        The DDL text is taken from the ``'ddl_generator'`` entry of
        ``context.step_results`` when that step exists and succeeded. The
        output path is obtained from ``context.file_manager``.

        Returns:
            A successful ``ProcessingResult`` carrying ``filename``,
            ``filepath`` and ``content_length``; or a failed one whose
            ``error_message`` holds the exception text. Exceptions are logged
            and converted into a failed result instead of propagating.
        """
        try:
            table_metadata = context.table_metadata

            # Reuse the DDL produced by an earlier step, if there is one.
            ddl_result = context.step_results.get('ddl_generator')
            if ddl_result and ddl_result.success:
                ddl_content = ddl_result.data.get('ddl_content', '')
            else:
                ddl_content = ''

            md_content = self._generate_md_content(table_metadata, ddl_content)

            filename = context.file_manager.get_safe_filename(
                table_metadata.schema_name,
                table_metadata.table_name,
                SCHEMA_TOOLS_CONFIG["doc_file_suffix"],
            )

            # Documents go into a "docs" subdirectory only when configured.
            subdirectory = "docs" if SCHEMA_TOOLS_CONFIG["create_subdirectories"] else None
            filepath = context.file_manager.get_full_path(filename, subdirectory)

            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(md_content)

            self.logger.info(f"MD文档已生成: {filepath}")

            return ProcessingResult(
                success=True,
                data={
                    'filename': filename,
                    'filepath': filepath,
                    'content_length': len(md_content),
                },
                metadata={'tool': self.tool_name},
            )

        except Exception as e:
            # Tool boundary: report the failure through the result object.
            self.logger.exception("MD文档生成失败")
            return ProcessingResult(
                success=False,
                error_message=f"MD文档生成失败: {e}",
            )

    def _generate_md_content(self, table_metadata: TableMetadata, ddl_content: str) -> str:
        """Build the Markdown text for one table.

        Sections, in order: title, one-line table description, field list,
        optional supplementary notes, and (only when the
        ``include_ddl_in_doc`` config flag is truthy and ``ddl_content`` is
        non-empty) the DDL statement in a fenced ``sql`` block.

        Returns:
            The document as a single newline-joined string.
        """
        lines = []

        if table_metadata.comment:
            # Drop a "(NN字)" character-count marker if the comment has one.
            comment = _CHAR_COUNT_PATTERN.sub('', table_metadata.comment)

            # The title uses only the text before the first comma, or failing
            # that, before the first full stop.
            comma_parts = _COMMA_PATTERN.split(comment, maxsplit=1)
            if len(comma_parts) > 1:
                table_name_part = comma_parts[0]
            elif '。' in comment:
                table_name_part = comment.split('。')[0]
            else:
                table_name_part = comment
            table_name_part = table_name_part.strip()

            lines.append(f"## {table_metadata.table_name}({table_name_part})")
            lines.append(f"{table_metadata.table_name} 表{comment}")
        else:
            lines.append(f"## {table_metadata.table_name}(数据表)")
            lines.append(f"{table_metadata.table_name} 表")

        # Field list (no blank line before it).
        lines.append("字段列表:")
        lines.extend(
            self._generate_field_line(field, table_metadata)
            for field in table_metadata.fields
        )

        # Supplementary notes (no blank line before them).
        supplementary_info = self._generate_supplementary_info(table_metadata)
        if supplementary_info:
            lines.append("字段补充说明:")
            lines.extend(supplementary_info)

        # Optional DDL section.
        if ddl_content and SCHEMA_TOOLS_CONFIG.get("include_ddl_in_doc", False):
            lines.append("### DDL语句")
            lines.append("```sql")
            lines.append(ddl_content)
            lines.append("```")
            lines.append("")

        # The table statistics section is intentionally not emitted.
        return '\n'.join(lines)

    def _generate_field_line(self, field: FieldInfo, table_metadata: TableMetadata) -> str:
        """Build the bullet line describing one field.

        The line is made of: name, formatted type, comment (if any),
        constraint tags (if any) and sample values (if any), joined by single
        spaces. Enum fields show up to 10 sample values, other fields up to 2.
        """
        parts = [f"- {field.name}"]

        type_info = self._format_field_type_for_doc(field)
        parts.append(f"({type_info})")

        if field.comment:
            parts.append(f"- {field.comment}")

        constraints = []
        if field.is_primary_key:
            constraints.append("主键")
        if field.is_foreign_key:
            constraints.append("外键")
        if not field.nullable:
            constraints.append("非空")
        if constraints:
            parts.append(f"[{', '.join(constraints)}]")

        # Ask for exactly as many samples as will be displayed, so the enum
        # limit is not silently capped by the collector's default of 3.
        display_limit = _ENUM_SAMPLE_LIMIT if field.is_enum else _NON_ENUM_SAMPLE_DISPLAY
        display_values = self._get_field_sample_values(
            field.name, table_metadata, max_values=display_limit
        )
        if display_values:
            parts.append(f"[示例: {', '.join(display_values)}]")

        return ' '.join(parts)

    def _format_field_type_for_doc(self, field: FieldInfo) -> str:
        """Return the type string shown in the document for ``field``.

        ``character varying``/``varchar`` with a max length becomes
        ``varchar(n)``; ``numeric`` with a precision becomes ``numeric(p)`` or
        ``numeric(p,s)``; any type containing ``timestamp`` becomes
        ``timestamp`` or ``timestamp(p)``. Everything else is returned as is.
        """
        type_lower = field.type.lower()

        if type_lower in ('character varying', 'varchar') and field.max_length:
            return f"varchar({field.max_length})"

        if type_lower == 'numeric' and field.precision:
            if field.scale:
                return f"numeric({field.precision},{field.scale})"
            return f"numeric({field.precision})"

        if 'timestamp' in type_lower:
            if '(' in field.type:
                # Keep the precision written between the first parentheses.
                precision = field.type.split('(')[1].split(')')[0]
                return f"timestamp({precision})"
            return "timestamp"

        return field.type

    def _get_field_sample_values(
        self,
        field_name: str,
        table_metadata: TableMetadata,
        max_values: int = _DEFAULT_SAMPLE_LIMIT,
    ) -> List[str]:
        """Collect distinct, non-``None`` sample values of a field as strings.

        Values are read from ``table_metadata.sample_data`` in order and
        de-duplicated on their string form.

        Args:
            field_name: Key looked up in each sample row.
            table_metadata: Source of the sample rows.
            max_values: Stop after this many distinct values (default 3).

        Returns:
            Up to ``max_values`` strings, in first-seen order; empty when
            there is no sample data or ``max_values`` is not positive.
        """
        sample_values: List[str] = []
        if max_values <= 0:
            return sample_values

        seen_values = set()
        # Tolerate metadata without any sample data.
        for sample in table_metadata.sample_data or []:
            if field_name not in sample:
                continue
            value = sample[field_name]
            if value is None:
                continue
            str_value = str(value)
            if str_value in seen_values:
                continue
            seen_values.add(str_value)
            sample_values.append(str_value)
            if len(sample_values) >= max_values:
                break

        return sample_values

    def _generate_supplementary_info(self, table_metadata: TableMetadata) -> List[str]:
        """Build the bullet lines of the supplementary notes section.

        Covers, in order: primary key(s), foreign keys, recognised enum
        fields, suspected ("logical") enum fields, special field types
        (UUID, auto-filled timestamps, JSON) and a note for tables whose name
        ends in ``_rel`` or ``_relation``.
        """
        info_lines = []

        # Primary key(s).
        primary_keys = [f.name for f in table_metadata.fields if f.is_primary_key]
        if len(primary_keys) == 1:
            info_lines.append(f"- {primary_keys[0]} 为主键")
        elif primary_keys:
            info_lines.append(f"- 复合主键:{', '.join(primary_keys)}")

        # Foreign keys: reuse the comment when it already describes the link.
        for field in table_metadata.fields:
            if not field.is_foreign_key:
                continue
            if field.comment and '关联' in field.comment:
                info_lines.append(f"- {field.name} {field.comment}")
            else:
                info_lines.append(f"- {field.name} 为外键")

        # Recognised enum fields. The number of values is not shown because
        # the list may be incomplete.
        for field in table_metadata.fields:
            if field.is_enum and field.enum_values:
                values_str = '、'.join(str(v) for v in field.enum_values)
                info_lines.append(f"- {field.name} 为枚举字段,包含取值:{values_str}")

        # Suspected enums: not flagged as enum, but the name contains a hint
        # keyword and the samples show only a few distinct values.
        for field in table_metadata.fields:
            if field.is_enum:
                continue
            field_name_lower = field.name.lower()
            if not any(keyword in field_name_lower for keyword in _LOGICAL_ENUM_KEYWORDS):
                continue
            # Collect one more than the limit so "more than 10 distinct
            # values" can actually be detected.
            sample_values = self._get_field_sample_values(
                field.name, table_metadata, max_values=_ENUM_SAMPLE_LIMIT + 1
            )
            if sample_values and len(sample_values) <= _ENUM_SAMPLE_LIMIT:
                values_str = '、'.join(sample_values)
                info_lines.append(f"- {field.name} 疑似枚举字段,当前取值:{values_str}")

        # Special field types.
        for field in table_metadata.fields:
            type_lower = field.type.lower()
            if type_lower == 'uuid':
                info_lines.append(f"- {field.name} 使用 UUID 编码")
            elif 'timestamp' in type_lower and field.default_value:
                default_lower = field.default_value.lower()
                if 'now()' in default_lower or 'current_timestamp' in default_lower:
                    info_lines.append(f"- {field.name} 自动记录当前时间")
            elif type_lower in ('json', 'jsonb'):
                info_lines.append(f"- {field.name} 存储JSON格式数据")

        # Association tables, recognised by name suffix.
        if table_metadata.table_name.endswith(('_rel', '_relation')):
            info_lines.append("- 本表是关联表,用于多对多关系映射")

        return info_lines

    def _generate_statistics_info(self, table_metadata: TableMetadata) -> List[str]:
        """Build the bullet lines of the table statistics section.

        Reports the row count (when not ``None``), the table size (when
        truthy), the total number of fields, and the nullable / enum field
        counts when they are greater than zero. No call to this method is
        visible in this file; the statistics section is not part of the
        generated document.
        """
        stats_lines = []

        if table_metadata.row_count is not None:
            stats_lines.append(f"- 数据行数:{table_metadata.row_count:,}")

        if table_metadata.table_size:
            stats_lines.append(f"- 表大小:{table_metadata.table_size}")

        total_fields = len(table_metadata.fields)
        nullable_fields = sum(1 for f in table_metadata.fields if f.nullable)
        enum_fields = sum(1 for f in table_metadata.fields if f.is_enum)

        stats_lines.append(f"- 字段总数:{total_fields}")
        if nullable_fields > 0:
            stats_lines.append(f"- 可空字段:{nullable_fields}")
        if enum_fields > 0:
            stats_lines.append(f"- 枚举字段:{enum_fields}")

        return stats_lines