ddl_generator.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
import os
import re
from typing import List, Dict, Any

from schema_tools.tools.base import BaseTool, ToolRegistry
from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo, TableMetadata
from schema_tools.config import SCHEMA_TOOLS_CONFIG
  6. @ToolRegistry.register("ddl_generator")
  7. class DDLGeneratorTool(BaseTool):
  8. """DDL格式生成工具"""
  9. needs_llm = False
  10. tool_name = "DDL生成器"
  11. def __init__(self, **kwargs):
  12. super().__init__(**kwargs)
  13. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  14. """执行DDL生成"""
  15. try:
  16. table_metadata = context.table_metadata
  17. # 生成DDL内容
  18. ddl_content = self._generate_ddl_content(table_metadata)
  19. # 确定文件名和路径
  20. filename = context.file_manager.get_safe_filename(
  21. table_metadata.schema_name,
  22. table_metadata.table_name,
  23. SCHEMA_TOOLS_CONFIG["ddl_file_suffix"]
  24. )
  25. # 确定子目录
  26. subdirectory = "ddl" if SCHEMA_TOOLS_CONFIG["create_subdirectories"] else None
  27. filepath = context.file_manager.get_full_path(filename, subdirectory)
  28. # 写入文件
  29. with open(filepath, 'w', encoding='utf-8') as f:
  30. f.write(ddl_content)
  31. self.logger.info(f"DDL文件已生成: {filepath}")
  32. return ProcessingResult(
  33. success=True,
  34. data={
  35. 'filename': filename,
  36. 'filepath': filepath,
  37. 'content_length': len(ddl_content),
  38. 'ddl_content': ddl_content # 保存内容供后续工具使用
  39. },
  40. metadata={'tool': self.tool_name}
  41. )
  42. except Exception as e:
  43. self.logger.exception(f"DDL生成失败")
  44. return ProcessingResult(
  45. success=False,
  46. error_message=f"DDL生成失败: {str(e)}"
  47. )
  48. def _generate_ddl_content(self, table_metadata: TableMetadata) -> str:
  49. """生成DDL内容"""
  50. lines = []
  51. # 表头注释 - 只显示表名,不加解释和字数统计
  52. if table_metadata.comment:
  53. # 提取表名部分(去掉解释和字数统计)
  54. comment = table_metadata.comment
  55. # 去掉可能的字数统计 (XX字)
  56. import re
  57. comment = re.sub(r'[((]\d+字[))]', '', comment)
  58. # 只取第一句话或逗号前的部分
  59. if ',' in comment:
  60. table_name_part = comment.split(',')[0]
  61. elif '。' in comment:
  62. table_name_part = comment.split('。')[0]
  63. else:
  64. table_name_part = comment.strip()
  65. lines.append(f"-- 中文名: {table_name_part}")
  66. lines.append(f"-- 描述: {comment}")
  67. else:
  68. lines.append(f"-- 中文名: {table_metadata.table_name}")
  69. # CREATE TABLE语句
  70. lines.append(f"create table {table_metadata.full_name} (")
  71. # 字段定义
  72. field_lines = []
  73. for field in table_metadata.fields:
  74. field_line = self._generate_field_line(field)
  75. field_lines.append(field_line)
  76. # 主键定义
  77. primary_keys = [f.name for f in table_metadata.fields if f.is_primary_key]
  78. if primary_keys:
  79. field_lines.append(f" primary key ({', '.join(primary_keys)})")
  80. # 组合字段行
  81. lines.extend([line if i == len(field_lines) - 1 else line + ","
  82. for i, line in enumerate(field_lines)])
  83. lines.append(");")
  84. return '\n'.join(lines)
  85. def _generate_field_line(self, field: FieldInfo) -> str:
  86. """生成字段定义行"""
  87. parts = [f" {field.name}"]
  88. # 字段类型
  89. field_type = self._format_field_type(field)
  90. parts.append(field_type)
  91. # NOT NULL约束
  92. if not field.nullable:
  93. parts.append("not null")
  94. # 默认值
  95. if field.default_value and not self._should_skip_default(field.default_value):
  96. parts.append(f"default {self._format_default_value(field.default_value)}")
  97. # 组合字段定义
  98. field_def = ' '.join(parts)
  99. # 添加注释
  100. comment = self._format_field_comment(field)
  101. if comment:
  102. # 计算对齐空格(减少到30个字符对齐)
  103. padding = max(1, 30 - len(field_def))
  104. field_line = f"{field_def}{' ' * padding}-- {comment}"
  105. else:
  106. field_line = field_def
  107. return field_line
  108. def _format_field_type(self, field: FieldInfo) -> str:
  109. """格式化字段类型"""
  110. field_type = field.type.lower()
  111. # 处理带长度的类型
  112. if field_type in ['character varying', 'varchar'] and field.max_length:
  113. return f"varchar({field.max_length})"
  114. elif field_type == 'character' and field.max_length:
  115. return f"char({field.max_length})"
  116. elif field_type == 'numeric' and field.precision:
  117. if field.scale:
  118. return f"numeric({field.precision},{field.scale})"
  119. else:
  120. return f"numeric({field.precision})"
  121. elif field_type == 'timestamp without time zone':
  122. return "timestamp"
  123. elif field_type == 'timestamp with time zone':
  124. return "timestamptz"
  125. elif field_type in ['integer', 'int']:
  126. return "integer"
  127. elif field_type in ['bigint', 'int8']:
  128. return "bigint"
  129. elif field_type in ['smallint', 'int2']:
  130. return "smallint"
  131. elif field_type in ['double precision', 'float8']:
  132. return "double precision"
  133. elif field_type in ['real', 'float4']:
  134. return "real"
  135. elif field_type == 'boolean':
  136. return "boolean"
  137. elif field_type == 'text':
  138. return "text"
  139. elif field_type == 'date':
  140. return "date"
  141. elif field_type == 'time without time zone':
  142. return "time"
  143. elif field_type == 'time with time zone':
  144. return "timetz"
  145. elif field_type == 'json':
  146. return "json"
  147. elif field_type == 'jsonb':
  148. return "jsonb"
  149. elif field_type == 'uuid':
  150. return "uuid"
  151. elif field_type.startswith('timestamp(') and 'without time zone' in field_type:
  152. # 处理 timestamp(3) without time zone
  153. precision = field_type.split('(')[1].split(')')[0]
  154. return f"timestamp({precision})"
  155. else:
  156. return field_type
  157. def _format_default_value(self, default_value: str) -> str:
  158. """格式化默认值"""
  159. # 移除可能的类型转换
  160. if '::' in default_value:
  161. default_value = default_value.split('::')[0]
  162. # 处理函数调用
  163. if default_value.lower() in ['now()', 'current_timestamp']:
  164. return 'current_timestamp'
  165. elif default_value.lower() == 'current_date':
  166. return 'current_date'
  167. # 处理字符串值
  168. if not (default_value.startswith("'") and default_value.endswith("'")):
  169. # 检查是否为数字或布尔值
  170. if default_value.lower() in ['true', 'false']:
  171. return default_value.lower()
  172. elif default_value.replace('.', '').replace('-', '').isdigit():
  173. return default_value
  174. else:
  175. # 其他情况加引号
  176. return f"'{default_value}'"
  177. return default_value
  178. def _should_skip_default(self, default_value: str) -> bool:
  179. """判断是否应跳过默认值"""
  180. # 跳过序列默认值
  181. if 'nextval(' in default_value.lower():
  182. return True
  183. # 跳过空字符串
  184. if default_value.strip() in ['', "''", '""']:
  185. return True
  186. return False
  187. def _format_field_comment(self, field: FieldInfo) -> str:
  188. """格式化字段注释"""
  189. comment_parts = []
  190. # 基础注释
  191. if field.comment:
  192. comment_parts.append(field.comment)
  193. # 主键标识
  194. if field.is_primary_key:
  195. comment_parts.append("主键")
  196. # 外键标识
  197. if field.is_foreign_key:
  198. comment_parts.append("外键")
  199. # 去掉小括号,直接返回注释内容
  200. return ','.join(comment_parts) if comment_parts else ""