doc_generator.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
import os
import re
from typing import List, Dict, Any

from schema_tools.config import SCHEMA_TOOLS_CONFIG
from schema_tools.tools.base import BaseTool, ToolRegistry
from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo, TableMetadata
  6. @ToolRegistry.register("doc_generator")
  7. class DocGeneratorTool(BaseTool):
  8. """MD文档生成工具"""
  9. needs_llm = False
  10. tool_name = "文档生成器"
  11. def __init__(self, **kwargs):
  12. super().__init__(**kwargs)
  13. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  14. """执行MD文档生成"""
  15. try:
  16. table_metadata = context.table_metadata
  17. # 获取DDL生成结果(如果有)
  18. ddl_result = context.step_results.get('ddl_generator')
  19. ddl_content = ddl_result.data.get('ddl_content', '') if ddl_result and ddl_result.success else ''
  20. # 生成MD内容
  21. md_content = self._generate_md_content(table_metadata, ddl_content)
  22. # 确定文件名和路径
  23. filename = context.file_manager.get_safe_filename(
  24. table_metadata.schema_name,
  25. table_metadata.table_name,
  26. SCHEMA_TOOLS_CONFIG["doc_file_suffix"]
  27. )
  28. # 确定子目录
  29. subdirectory = "docs" if SCHEMA_TOOLS_CONFIG["create_subdirectories"] else None
  30. filepath = context.file_manager.get_full_path(filename, subdirectory)
  31. # 写入文件
  32. with open(filepath, 'w', encoding='utf-8') as f:
  33. f.write(md_content)
  34. self.logger.info(f"MD文档已生成: {filepath}")
  35. return ProcessingResult(
  36. success=True,
  37. data={
  38. 'filename': filename,
  39. 'filepath': filepath,
  40. 'content_length': len(md_content)
  41. },
  42. metadata={'tool': self.tool_name}
  43. )
  44. except Exception as e:
  45. self.logger.exception(f"MD文档生成失败")
  46. return ProcessingResult(
  47. success=False,
  48. error_message=f"MD文档生成失败: {str(e)}"
  49. )
  50. def _generate_md_content(self, table_metadata: TableMetadata, ddl_content: str) -> str:
  51. """生成MD文档内容"""
  52. lines = []
  53. # 标题 - 只显示表名,不加解释和字数统计
  54. if table_metadata.comment:
  55. # 提取表名部分(去掉解释和字数统计)
  56. comment = table_metadata.comment
  57. # 去掉可能的字数统计 (XX字)
  58. import re
  59. comment = re.sub(r'[((]\d+字[))]', '', comment)
  60. # 只取第一句话或逗号前的部分
  61. if ',' in comment:
  62. table_name_part = comment.split(',')[0]
  63. elif '。' in comment:
  64. table_name_part = comment.split('。')[0]
  65. else:
  66. table_name_part = comment.strip()
  67. lines.append(f"## {table_metadata.table_name}({table_name_part})")
  68. # 表描述
  69. lines.append(f"{table_metadata.table_name} 表{comment}")
  70. else:
  71. lines.append(f"## {table_metadata.table_name}(数据表)")
  72. lines.append(f"{table_metadata.table_name} 表")
  73. # 字段列表(去掉前面的空行)
  74. lines.append("字段列表:")
  75. for field in table_metadata.fields:
  76. field_line = self._generate_field_line(field, table_metadata)
  77. lines.append(field_line)
  78. # 字段补充说明(去掉前面的空行)
  79. supplementary_info = self._generate_supplementary_info(table_metadata)
  80. if supplementary_info:
  81. lines.append("字段补充说明:")
  82. lines.extend(supplementary_info)
  83. # DDL语句(可选)
  84. if ddl_content and SCHEMA_TOOLS_CONFIG.get("include_ddl_in_doc", False):
  85. lines.append("### DDL语句")
  86. lines.append("```sql")
  87. lines.append(ddl_content)
  88. lines.append("```")
  89. lines.append("")
  90. # 删除表统计信息部分
  91. return '\n'.join(lines)
  92. def _generate_field_line(self, field: FieldInfo, table_metadata: TableMetadata) -> str:
  93. """生成字段说明行"""
  94. # 基础信息
  95. parts = [f"- {field.name}"]
  96. # 类型信息
  97. type_info = self._format_field_type_for_doc(field)
  98. parts.append(f"({type_info})")
  99. # 注释
  100. if field.comment:
  101. parts.append(f"- {field.comment}")
  102. # 约束信息
  103. constraints = []
  104. if field.is_primary_key:
  105. constraints.append("主键")
  106. if field.is_foreign_key:
  107. constraints.append("外键")
  108. if not field.nullable:
  109. constraints.append("非空")
  110. if constraints:
  111. parts.append(f"[{', '.join(constraints)}]")
  112. # 示例值(枚举类型显示更多,其他类型只显示2个)
  113. sample_values = self._get_field_sample_values(field.name, table_metadata)
  114. if sample_values:
  115. if field.is_enum:
  116. # 枚举类型最多显示10个
  117. display_values = sample_values[:10]
  118. else:
  119. # 其他类型只显示2个
  120. display_values = sample_values[:2]
  121. sample_str = f"[示例: {', '.join(display_values)}]"
  122. parts.append(sample_str)
  123. return ' '.join(parts)
  124. def _format_field_type_for_doc(self, field: FieldInfo) -> str:
  125. """为文档格式化字段类型"""
  126. if field.type.lower() in ['character varying', 'varchar'] and field.max_length:
  127. return f"varchar({field.max_length})"
  128. elif field.type.lower() == 'numeric' and field.precision:
  129. if field.scale:
  130. return f"numeric({field.precision},{field.scale})"
  131. else:
  132. return f"numeric({field.precision})"
  133. elif 'timestamp' in field.type.lower():
  134. if '(' in field.type:
  135. # 提取精度
  136. precision = field.type.split('(')[1].split(')')[0]
  137. return f"timestamp({precision})"
  138. return "timestamp"
  139. else:
  140. return field.type
  141. def _get_field_sample_values(self, field_name: str, table_metadata: TableMetadata) -> List[str]:
  142. """获取字段的示例值"""
  143. sample_values = []
  144. seen_values = set()
  145. for sample in table_metadata.sample_data:
  146. if field_name in sample:
  147. value = sample[field_name]
  148. if value is not None:
  149. str_value = str(value)
  150. if str_value not in seen_values:
  151. seen_values.add(str_value)
  152. sample_values.append(str_value)
  153. if len(sample_values) >= 3:
  154. break
  155. return sample_values
  156. def _generate_supplementary_info(self, table_metadata: TableMetadata) -> List[str]:
  157. """生成字段补充说明"""
  158. info_lines = []
  159. # 主键信息
  160. primary_keys = [f.name for f in table_metadata.fields if f.is_primary_key]
  161. if primary_keys:
  162. if len(primary_keys) == 1:
  163. info_lines.append(f"- {primary_keys[0]} 为主键")
  164. else:
  165. info_lines.append(f"- 复合主键:{', '.join(primary_keys)}")
  166. # 外键信息
  167. foreign_keys = [(f.name, f.comment) for f in table_metadata.fields if f.is_foreign_key]
  168. for fk_name, fk_comment in foreign_keys:
  169. if fk_comment and '关联' in fk_comment:
  170. info_lines.append(f"- {fk_name} {fk_comment}")
  171. else:
  172. info_lines.append(f"- {fk_name} 为外键")
  173. # 枚举字段信息(包括逻辑枚举类型)
  174. enum_fields = [f for f in table_metadata.fields if f.is_enum and f.enum_values]
  175. for field in enum_fields:
  176. values_str = '、'.join(field.enum_values)
  177. # 不显示取值数量,因为可能不完整
  178. info_lines.append(f"- {field.name} 为枚举字段,包含取值:{values_str}")
  179. # 不显示enum_description,因为它通常是重复的描述
  180. # 检查逻辑枚举(字段名暗示但未被识别为枚举的字段)
  181. logical_enum_keywords = ["状态", "类型", "级别", "方向", "品类", "模式", "格式", "性别"]
  182. for field in table_metadata.fields:
  183. if not field.is_enum: # 只检查未被识别为枚举的字段
  184. field_name_lower = field.name.lower()
  185. if any(keyword in field_name_lower for keyword in logical_enum_keywords):
  186. # 获取该字段的示例值来判断是否可能是逻辑枚举
  187. sample_values = self._get_field_sample_values(field.name, table_metadata)
  188. if sample_values and len(sample_values) <= 10: # 如果样例值数量较少,可能是逻辑枚举
  189. values_str = '、'.join(sample_values[:10])
  190. info_lines.append(f"- {field.name} 疑似枚举字段,当前取值:{values_str}")
  191. # 特殊字段说明
  192. for field in table_metadata.fields:
  193. # UUID字段
  194. if field.type.lower() == 'uuid':
  195. info_lines.append(f"- {field.name} 使用 UUID 编码")
  196. # 时间戳字段
  197. elif 'timestamp' in field.type.lower() and field.default_value:
  198. if 'now()' in field.default_value.lower() or 'current_timestamp' in field.default_value.lower():
  199. info_lines.append(f"- {field.name} 自动记录当前时间")
  200. # JSON字段
  201. elif field.type.lower() in ['json', 'jsonb']:
  202. info_lines.append(f"- {field.name} 存储JSON格式数据")
  203. # 表关联说明
  204. if table_metadata.table_name.endswith('_rel') or table_metadata.table_name.endswith('_relation'):
  205. info_lines.append(f"- 本表是关联表,用于多对多关系映射")
  206. return info_lines
  207. def _generate_statistics_info(self, table_metadata: TableMetadata) -> List[str]:
  208. """生成表统计信息"""
  209. stats_lines = []
  210. if table_metadata.row_count is not None:
  211. stats_lines.append(f"- 数据行数:{table_metadata.row_count:,}")
  212. if table_metadata.table_size:
  213. stats_lines.append(f"- 表大小:{table_metadata.table_size}")
  214. # 字段统计
  215. total_fields = len(table_metadata.fields)
  216. nullable_fields = sum(1 for f in table_metadata.fields if f.nullable)
  217. enum_fields = sum(1 for f in table_metadata.fields if f.is_enum)
  218. stats_lines.append(f"- 字段总数:{total_fields}")
  219. if nullable_fields > 0:
  220. stats_lines.append(f"- 可空字段:{nullable_fields}")
  221. if enum_fields > 0:
  222. stats_lines.append(f"- 枚举字段:{enum_fields}")
  223. return stats_lines