ddl_generator.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
import os
import re
from typing import List, Dict, Any

from data_pipeline.config import SCHEMA_TOOLS_CONFIG
from data_pipeline.tools.base import BaseTool, ToolRegistry
from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo, TableMetadata


  6. @ToolRegistry.register("ddl_generator")
  7. class DDLGeneratorTool(BaseTool):
  8. """DDL格式生成工具"""
  9. needs_llm = False
  10. tool_name = "DDL生成器"
  11. def __init__(self, **kwargs):
  12. super().__init__(**kwargs)
  13. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  14. """执行DDL生成"""
  15. try:
  16. table_metadata = context.table_metadata
  17. # 生成DDL内容
  18. ddl_content = self._generate_ddl_content(table_metadata)
  19. # 如果有file_manager,则写入文件(正常的data_pipeline流程)
  20. if context.file_manager:
  21. # 确定文件名和路径
  22. filename = context.file_manager.get_safe_filename(
  23. table_metadata.schema_name,
  24. table_metadata.table_name,
  25. SCHEMA_TOOLS_CONFIG["ddl_file_suffix"]
  26. )
  27. # 确定子目录
  28. subdirectory = "ddl" if SCHEMA_TOOLS_CONFIG["create_subdirectories"] else None
  29. filepath = context.file_manager.get_full_path(filename, subdirectory)
  30. # 写入文件
  31. with open(filepath, 'w', encoding='utf-8') as f:
  32. f.write(ddl_content)
  33. self.logger.info(f"DDL文件已生成: {filepath}")
  34. return ProcessingResult(
  35. success=True,
  36. data={
  37. 'filename': filename,
  38. 'filepath': filepath,
  39. 'content_length': len(ddl_content),
  40. 'ddl_content': ddl_content # 保存内容供后续工具使用
  41. },
  42. metadata={'tool': self.tool_name}
  43. )
  44. else:
  45. # 如果没有file_manager,只返回DDL内容(API调用场景)
  46. self.logger.info("DDL内容已生成(API调用模式,不写入文件)")
  47. return ProcessingResult(
  48. success=True,
  49. data={
  50. 'filename': f"{table_metadata.schema_name}_{table_metadata.table_name}.ddl",
  51. 'filepath': None, # 不写入文件
  52. 'content_length': len(ddl_content),
  53. 'ddl_content': ddl_content # 保存内容供后续工具使用
  54. },
  55. metadata={'tool': self.tool_name}
  56. )
  57. except Exception as e:
  58. self.logger.exception(f"DDL生成失败")
  59. return ProcessingResult(
  60. success=False,
  61. error_message=f"DDL生成失败: {str(e)}"
  62. )
  63. def _generate_ddl_content(self, table_metadata: TableMetadata) -> str:
  64. """生成DDL内容"""
  65. lines = []
  66. # 表头注释 - 只显示表名,不加解释和字数统计
  67. if table_metadata.comment:
  68. # 提取表名部分(去掉解释和字数统计)
  69. comment = table_metadata.comment
  70. # 去掉可能的字数统计 (XX字)
  71. import re
  72. comment = re.sub(r'[((]\d+字[))]', '', comment)
  73. # 只取第一句话或逗号前的部分
  74. if ',' in comment:
  75. table_name_part = comment.split(',')[0]
  76. elif '。' in comment:
  77. table_name_part = comment.split('。')[0]
  78. else:
  79. table_name_part = comment.strip()
  80. lines.append(f"-- 中文名: {table_name_part}")
  81. lines.append(f"-- 描述: {comment}")
  82. else:
  83. lines.append(f"-- 中文名: {table_metadata.table_name}")
  84. # CREATE TABLE语句
  85. lines.append(f"create table {table_metadata.full_name} (")
  86. # 字段定义
  87. field_lines = []
  88. for field in table_metadata.fields:
  89. field_line = self._generate_field_line(field)
  90. field_lines.append(field_line)
  91. # 主键定义
  92. primary_keys = [f.name for f in table_metadata.fields if f.is_primary_key]
  93. if primary_keys:
  94. field_lines.append(f" primary key ({', '.join(primary_keys)})")
  95. # 组合字段行
  96. lines.extend([line if i == len(field_lines) - 1 else line + ","
  97. for i, line in enumerate(field_lines)])
  98. lines.append(");")
  99. return '\n'.join(lines)
  100. def _generate_field_line(self, field: FieldInfo) -> str:
  101. """生成字段定义行"""
  102. parts = [f" {field.name}"]
  103. # 字段类型
  104. field_type = self._format_field_type(field)
  105. parts.append(field_type)
  106. # NOT NULL约束
  107. if not field.nullable:
  108. parts.append("not null")
  109. # 默认值
  110. if field.default_value and not self._should_skip_default(field.default_value):
  111. parts.append(f"default {self._format_default_value(field.default_value)}")
  112. # 组合字段定义
  113. field_def = ' '.join(parts)
  114. # 添加注释
  115. comment = self._format_field_comment(field)
  116. if comment:
  117. # 计算对齐空格(减少到30个字符对齐)
  118. padding = max(1, 30 - len(field_def))
  119. field_line = f"{field_def}{' ' * padding}-- {comment}"
  120. else:
  121. field_line = field_def
  122. return field_line
  123. def _format_field_type(self, field: FieldInfo) -> str:
  124. """格式化字段类型"""
  125. field_type = field.type.lower()
  126. # 处理带长度的类型
  127. if field_type in ['character varying', 'varchar'] and field.max_length:
  128. return f"varchar({field.max_length})"
  129. elif field_type == 'character' and field.max_length:
  130. return f"char({field.max_length})"
  131. elif field_type == 'numeric' and field.precision:
  132. if field.scale:
  133. return f"numeric({field.precision},{field.scale})"
  134. else:
  135. return f"numeric({field.precision})"
  136. elif field_type == 'timestamp without time zone':
  137. return "timestamp"
  138. elif field_type == 'timestamp with time zone':
  139. return "timestamptz"
  140. elif field_type in ['integer', 'int']:
  141. return "integer"
  142. elif field_type in ['bigint', 'int8']:
  143. return "bigint"
  144. elif field_type in ['smallint', 'int2']:
  145. return "smallint"
  146. elif field_type in ['double precision', 'float8']:
  147. return "double precision"
  148. elif field_type in ['real', 'float4']:
  149. return "real"
  150. elif field_type == 'boolean':
  151. return "boolean"
  152. elif field_type == 'text':
  153. return "text"
  154. elif field_type == 'date':
  155. return "date"
  156. elif field_type == 'time without time zone':
  157. return "time"
  158. elif field_type == 'time with time zone':
  159. return "timetz"
  160. elif field_type == 'json':
  161. return "json"
  162. elif field_type == 'jsonb':
  163. return "jsonb"
  164. elif field_type == 'uuid':
  165. return "uuid"
  166. elif field_type.startswith('timestamp(') and 'without time zone' in field_type:
  167. # 处理 timestamp(3) without time zone
  168. precision = field_type.split('(')[1].split(')')[0]
  169. return f"timestamp({precision})"
  170. else:
  171. return field_type
  172. def _format_default_value(self, default_value: str) -> str:
  173. """格式化默认值"""
  174. # 移除可能的类型转换
  175. if '::' in default_value:
  176. default_value = default_value.split('::')[0]
  177. # 处理函数调用
  178. if default_value.lower() in ['now()', 'current_timestamp']:
  179. return 'current_timestamp'
  180. elif default_value.lower() == 'current_date':
  181. return 'current_date'
  182. # 处理字符串值
  183. if not (default_value.startswith("'") and default_value.endswith("'")):
  184. # 检查是否为数字或布尔值
  185. if default_value.lower() in ['true', 'false']:
  186. return default_value.lower()
  187. elif default_value.replace('.', '').replace('-', '').isdigit():
  188. return default_value
  189. else:
  190. # 其他情况加引号
  191. return f"'{default_value}'"
  192. return default_value
  193. def _should_skip_default(self, default_value: str) -> bool:
  194. """判断是否应跳过默认值"""
  195. # 跳过序列默认值
  196. if 'nextval(' in default_value.lower():
  197. return True
  198. # 跳过空字符串
  199. if default_value.strip() in ['', "''", '""']:
  200. return True
  201. return False
  202. def _format_field_comment(self, field: FieldInfo) -> str:
  203. """格式化字段注释"""
  204. comment_parts = []
  205. # 基础注释
  206. if field.comment:
  207. comment_parts.append(field.comment)
  208. # 主键标识
  209. if field.is_primary_key:
  210. comment_parts.append("主键")
  211. # 外键标识
  212. if field.is_foreign_key:
  213. comment_parts.append("外键")
  214. # 去掉小括号,直接返回注释内容
  215. return ','.join(comment_parts) if comment_parts else ""