script_utils.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File in the dataops_scripts directory - used to verify the path change succeeded
import logging
import sys
import os
import traceback

# Add the parent directory to the Python path so the config module under dags/ can be imported
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import importlib.util
from datetime import datetime, timedelta
import pytz
import re  # re module for regular-expression support
# Import of SCHEDULE_TABLE_SCHEMA (kept disabled; the value is read dynamically instead)
#from dags.config import SCHEDULE_TABLE_SCHEMA

# Import Airflow-related packages
try:
    from airflow.models import Variable
except ImportError:
    # Fallback when running outside an Airflow environment:
    # a stub that always returns the supplied default.
    class Variable:
        @staticmethod
        def get(key, default_var=None):
            return default_var

# Configure the module logger (stdout handler, INFO level)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

logger = logging.getLogger("script_utils")
  36. def get_config_path():
  37. """
  38. 从Airflow变量中获取DATAOPS_DAGS_PATH
  39. 返回:
  40. str: config.py的完整路径
  41. """
  42. try:
  43. # 从Airflow变量中获取DATAOPS_DAGS_PATH
  44. dags_path = Variable.get("DATAOPS_DAGS_PATH", "/opt/airflow/dags")
  45. logger.info(f"从Airflow变量获取到DATAOPS_DAGS_PATH: {dags_path}")
  46. # 构建config.py的完整路径
  47. config_path = os.path.join(dags_path, "config.py")
  48. if not os.path.exists(config_path):
  49. logger.warning(f"配置文件路径不存在: {config_path}, 将使用默认路径")
  50. # 尝试使用相对路径
  51. alt_config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dags/config.py"))
  52. if os.path.exists(alt_config_path):
  53. logger.info(f"使用替代配置路径: {alt_config_path}")
  54. return alt_config_path
  55. return config_path
  56. except Exception as e:
  57. logger.error(f"获取配置路径时出错: {str(e)}")
  58. # 使用默认路径
  59. return os.path.abspath(os.path.join(os.path.dirname(__file__), "../dags/config.py"))
  60. def load_config_module():
  61. """
  62. 动态加载config.py模块
  63. 返回:
  64. module: 加载的config模块
  65. """
  66. try:
  67. config_path = get_config_path()
  68. logger.info(f"正在加载配置文件: {config_path}")
  69. # 动态加载config.py模块
  70. spec = importlib.util.spec_from_file_location("config", config_path)
  71. config_module = importlib.util.module_from_spec(spec)
  72. spec.loader.exec_module(config_module)
  73. return config_module
  74. except Exception as e:
  75. logger.error(f"加载配置模块时出错: {str(e)}")
  76. raise ImportError(f"无法加载配置模块: {str(e)}")
  77. def get_neo4j_driver():
  78. """获取Neo4j连接驱动"""
  79. try:
  80. # 使用get_config_path获取config路径
  81. config_path = get_config_path()
  82. if not os.path.exists(config_path):
  83. raise FileNotFoundError(f"配置文件不存在: {config_path}")
  84. logger.info(f"使用配置文件路径: {config_path}")
  85. # 动态加载config模块
  86. spec = importlib.util.spec_from_file_location("config", config_path)
  87. config_module = importlib.util.module_from_spec(spec)
  88. spec.loader.exec_module(config_module)
  89. # 从模块中获取NEO4J_CONFIG
  90. NEO4J_CONFIG = getattr(config_module, "NEO4J_CONFIG", None)
  91. if not NEO4J_CONFIG:
  92. raise ValueError(f"配置文件 {config_path} 中未找到NEO4J_CONFIG配置项")
  93. # 验证NEO4J_CONFIG中包含必要的配置项
  94. required_keys = ["uri", "user", "password"]
  95. missing_keys = [key for key in required_keys if key not in NEO4J_CONFIG]
  96. if missing_keys:
  97. raise ValueError(f"NEO4J_CONFIG缺少必要的配置项: {', '.join(missing_keys)}")
  98. # 创建Neo4j驱动
  99. from neo4j import GraphDatabase
  100. logger.info(f"使用配置创建Neo4j驱动: {NEO4J_CONFIG['uri']}")
  101. return GraphDatabase.driver(
  102. NEO4J_CONFIG['uri'],
  103. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  104. )
  105. except Exception as e:
  106. logger.error(f"创建Neo4j驱动失败: {str(e)}")
  107. logger.error(traceback.format_exc())
  108. raise
  109. def get_pg_config():
  110. """
  111. 从config.py获取PostgreSQL数据库配置
  112. 返回:
  113. dict: PostgreSQL配置字典
  114. """
  115. try:
  116. config_module = load_config_module()
  117. pg_config = getattr(config_module, "PG_CONFIG", None)
  118. if pg_config is None:
  119. logger.warning("配置模块中未找到PG_CONFIG")
  120. logger.info(f"已获取PostgreSQL配置: {pg_config}")
  121. return pg_config
  122. except Exception as e:
  123. logger.error(f"获取PostgreSQL配置时出错: {str(e)}")
  124. # 返回默认配置
  125. return {
  126. "host": "localhost",
  127. "port": 5432,
  128. "user": "postgres",
  129. "password": "postgres",
  130. "database": "dataops"
  131. }
  132. def get_upload_paths():
  133. """
  134. 从config.py获取文件上传和归档路径
  135. 返回:
  136. tuple: (上传路径, 归档路径)
  137. """
  138. try:
  139. config_module = load_config_module()
  140. upload_path = getattr(config_module, "STRUCTURE_UPLOAD_BASE_PATH")
  141. archive_path = getattr(config_module, "STRUCTURE_UPLOAD_ARCHIVE_BASE_PATH")
  142. logger.info(f"获取上传路径: {upload_path}, 归档路径: {archive_path}")
  143. return upload_path, archive_path
  144. except Exception as e:
  145. logger.error(f"获取上传路径时出错: {str(e)}")
  146. # 返回默认路径
  147. return "/data/upload", "/data/archive"
  148. def get_date_range(exec_date, frequency):
  149. """
  150. 根据执行日期和频率,计算开始日期和结束日期
  151. 参数:
  152. exec_date (str): 执行日期,格式为 YYYY-MM-DD
  153. frequency (str): 频率,可选值为 daily, weekly, monthly, quarterly, yearly
  154. 返回:
  155. tuple: (start_date, end_date) 格式为 YYYY-MM-DD 的字符串
  156. """
  157. logger.info(f"计算日期范围 - 执行日期: {exec_date}, 频率: {frequency}")
  158. # 将输入的日期转换为上海时区的datetime对象
  159. shanghai_tz = pytz.timezone('Asia/Shanghai')
  160. try:
  161. # 解析输入的exec_date
  162. if isinstance(exec_date, str):
  163. date_obj = datetime.strptime(exec_date, '%Y-%m-%d')
  164. elif isinstance(exec_date, datetime):
  165. date_obj = exec_date
  166. else:
  167. raise ValueError(f"不支持的exec_date类型: {type(exec_date)}")
  168. # 转换为上海时区
  169. date_obj = shanghai_tz.localize(date_obj)
  170. logger.info(f"上海时区的执行日期: {date_obj}")
  171. # 根据不同频率计算日期范围
  172. if frequency.lower() == 'daily':
  173. # 每日: start_date = exec_date, end_date = exec_date + 1 day
  174. start_date = date_obj.strftime('%Y-%m-%d')
  175. end_date = (date_obj + timedelta(days=1)).strftime('%Y-%m-%d')
  176. elif frequency.lower() == 'weekly':
  177. # 每周: start_date = 本周一, end_date = 下周一
  178. days_since_monday = date_obj.weekday() # 0=周一, 6=周日
  179. monday = date_obj - timedelta(days=days_since_monday)
  180. next_monday = monday + timedelta(days=7)
  181. start_date = monday.strftime('%Y-%m-%d')
  182. end_date = next_monday.strftime('%Y-%m-%d')
  183. elif frequency.lower() == 'monthly':
  184. # 每月: start_date = 本月第一天, end_date = 下月第一天
  185. first_day = date_obj.replace(day=1)
  186. # 计算下个月的第一天
  187. if first_day.month == 12:
  188. next_month_first_day = first_day.replace(year=first_day.year + 1, month=1)
  189. else:
  190. next_month_first_day = first_day.replace(month=first_day.month + 1)
  191. start_date = first_day.strftime('%Y-%m-%d')
  192. end_date = next_month_first_day.strftime('%Y-%m-%d')
  193. elif frequency.lower() == 'quarterly':
  194. # 每季度: start_date = 本季度第一天, end_date = 下季度第一天
  195. quarter = (date_obj.month - 1) // 3 + 1 # 1-4季度
  196. first_month_of_quarter = (quarter - 1) * 3 + 1 # 季度的第一个月
  197. quarter_first_day = date_obj.replace(month=first_month_of_quarter, day=1)
  198. # 计算下个季度的第一天
  199. if quarter == 4:
  200. next_quarter_first_day = quarter_first_day.replace(year=quarter_first_day.year + 1, month=1)
  201. else:
  202. next_quarter_first_day = quarter_first_day.replace(month=first_month_of_quarter + 3)
  203. start_date = quarter_first_day.strftime('%Y-%m-%d')
  204. end_date = next_quarter_first_day.strftime('%Y-%m-%d')
  205. elif frequency.lower() == 'yearly':
  206. # 每年: start_date = 本年第一天, end_date = 下年第一天
  207. year_first_day = date_obj.replace(month=1, day=1)
  208. next_year_first_day = date_obj.replace(year=date_obj.year + 1, month=1, day=1)
  209. start_date = year_first_day.strftime('%Y-%m-%d')
  210. end_date = next_year_first_day.strftime('%Y-%m-%d')
  211. else:
  212. logger.error(f"不支持的频率: {frequency}")
  213. raise ValueError(f"不支持的频率: {frequency}")
  214. logger.info(f"计算结果 - 开始日期: {start_date}, 结束日期: {end_date}")
  215. return start_date, end_date
  216. except Exception as e:
  217. logger.error(f"计算日期范围时出错: {str(e)}", exc_info=True)
  218. raise
  219. import re
  220. from typing import Dict, List, Optional, Set
  221. def extract_source_fields_linked_to_template(sql: str, jinja_vars: List[str]) -> Set[str]:
  222. """
  223. 从 SQL 中提取和 jinja 模板变量绑定的源字段(支持各种形式)
  224. """
  225. fields = set()
  226. sql = re.sub(r"\s+", " ", sql)
  227. for var in jinja_vars:
  228. # 普通比较、函数包裹
  229. pattern = re.compile(
  230. r"""
  231. (?P<field>
  232. (?:\w+\s*\(\s*)? # 可选函数开始(如 DATE(
  233. [\w\.]+ # 字段名
  234. (?:\s+AS\s+\w+)? # 可选 CAST 形式
  235. \)? # 可选右括号
  236. )
  237. \s*(=|<|>|<=|>=)\s*['"]?\{\{\s*""" + var + r"""\s*\}\}['"]?
  238. """, re.IGNORECASE | re.VERBOSE
  239. )
  240. fields.update(match.group("field").strip() for match in pattern.finditer(sql))
  241. # BETWEEN '{{ start_date }}' AND '{{ end_date }}'
  242. if var == "start_date":
  243. pattern_between = re.compile(
  244. r"""(?P<field>
  245. (?:\w+\s*\(\s*)?[\w\.]+(?:\s+AS\s+\w+)?\)? # 字段(函数包裹可选)
  246. )
  247. \s+BETWEEN\s+['"]?\{\{\s*start_date\s*\}\}['"]?\s+AND\s+['"]?\{\{\s*end_date\s*\}\}
  248. """, re.IGNORECASE | re.VERBOSE
  249. )
  250. fields.update(match.group("field").strip() for match in pattern_between.finditer(sql))
  251. return {extract_core_field(f) for f in fields}
  252. def extract_core_field(expr: str) -> str:
  253. """
  254. 清洗函数包裹的字段表达式:DATE(sd.sale_date) -> sd.sale_date, CAST(...) -> ...
  255. """
  256. expr = re.sub(r"CAST\s*\(\s*([\w\.]+)\s+AS\s+\w+\s*\)", r"\1", expr, flags=re.IGNORECASE)
  257. expr = re.sub(r"\b\w+\s*\(\s*([\w\.]+)\s*\)", r"\1", expr)
  258. return expr.strip()
  259. def parse_select_aliases(sql: str) -> Dict[str, str]:
  260. """
  261. 提取 SELECT 中的字段别名映射:原字段 -> 目标别名
  262. """
  263. sql = re.sub(r"\s+", " ", sql)
  264. select_clause_match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)
  265. if not select_clause_match:
  266. return {}
  267. select_clause = select_clause_match.group(1)
  268. mappings = {}
  269. for expr in select_clause.split(","):
  270. expr = expr.strip()
  271. alias_match = re.match(r"([\w\.]+)\s+AS\s+([\w]+)", expr, re.IGNORECASE)
  272. if alias_match:
  273. source, alias = alias_match.groups()
  274. mappings[source.strip()] = alias.strip()
  275. return mappings
  276. def find_target_date_field(sql: str, jinja_vars: List[str] = ["start_date", "end_date"]) -> Optional[str]:
  277. """
  278. 从 SQL 中找出与模板时间变量绑定的目标表字段(只返回一个)
  279. """
  280. source_fields = extract_source_fields_linked_to_template(sql, jinja_vars)
  281. alias_map = parse_select_aliases(sql)
  282. # 匹配 SELECT 中的映射字段
  283. for src_field in source_fields:
  284. if src_field in alias_map:
  285. return alias_map[src_field] # 源字段映射的目标字段
  286. # 若未通过 AS 映射,可能直接 SELECT sd.sale_date(裸字段)
  287. for src_field in source_fields:
  288. if '.' not in src_field:
  289. return src_field # 裸字段直接作为目标字段名
  290. return None
  291. def generate_delete_sql(sql_content, target_table=None):
  292. """
  293. 根据SQL脚本内容生成用于清理数据的DELETE语句
  294. 参数:
  295. sql_content (str): 原始SQL脚本内容
  296. target_table (str, optional): 目标表名,如果SQL脚本中无法解析出表名时使用
  297. 返回:
  298. str: DELETE语句,用于清理数据
  299. """
  300. logger.info("生成清理SQL语句,实现ETL作业幂等性")
  301. # 如果提供了目标表名,直接使用
  302. if target_table:
  303. logger.info(f"使用提供的目标表名: {target_table}")
  304. delete_stmt = f"""DELETE FROM {target_table}
  305. WHERE summary_date >= '{{{{ start_date }}}}'
  306. AND summary_date < '{{{{ end_date }}}}';"""
  307. logger.info(f"生成的清理SQL: {delete_stmt}")
  308. return delete_stmt
  309. # 尝试从SQL内容中解析出目标表名
  310. try:
  311. # 简单解析,尝试找出INSERT语句的目标表
  312. # 匹配 INSERT INTO xxx 或 INSERT INTO "xxx" 或 INSERT INTO `xxx` 或 INSERT INTO [xxx]
  313. insert_match = re.search(r'INSERT\s+INTO\s+(?:["\[`])?([a-zA-Z0-9_\.]+)(?:["\]`])?', sql_content, re.IGNORECASE)
  314. if insert_match:
  315. table_name = insert_match.group(1)
  316. logger.info(f"从SQL中解析出目标表名: {table_name}")
  317. delete_stmt = f"""DELETE FROM {table_name}
  318. WHERE summary_date >= '{{{{ start_date }}}}'
  319. AND summary_date < '{{{{ end_date }}}}';"""
  320. logger.info(f"生成的清理SQL: {delete_stmt}")
  321. return delete_stmt
  322. else:
  323. logger.warning("无法从SQL中解析出目标表名,无法生成清理SQL")
  324. return None
  325. except Exception as e:
  326. logger.error(f"解析SQL生成清理语句时出错: {str(e)}", exc_info=True)
  327. return None
  328. def get_one_day_range(exec_date):
  329. """
  330. 根据exec_date返回当天的00:00:00和次日00:00:00,均为datetime对象
  331. 参数:
  332. exec_date (str 或 datetime): 执行日期,格式为YYYY-MM-DD或datetime对象
  333. 返回:
  334. tuple(datetime, datetime): (start_datetime, end_datetime)
  335. """
  336. shanghai_tz = pytz.timezone('Asia/Shanghai')
  337. if isinstance(exec_date, str):
  338. date_obj = datetime.strptime(exec_date, '%Y-%m-%d')
  339. elif isinstance(exec_date, datetime):
  340. date_obj = exec_date
  341. else:
  342. raise ValueError(f"不支持的exec_date类型: {type(exec_date)}")
  343. # 当天00:00:00
  344. start_datetime = shanghai_tz.localize(datetime(date_obj.year, date_obj.month, date_obj.day, 0, 0, 0))
  345. # 次日00:00:00
  346. end_datetime = start_datetime + timedelta(days=1)
  347. return start_datetime, end_datetime
  348. def get_target_dt_column(table_name, script_name=None):
  349. """
  350. 从Neo4j或data_transform_scripts表获取目标日期列
  351. 参数:
  352. table_name (str): 表名
  353. script_name (str, optional): 脚本名称
  354. 返回:
  355. str: 目标日期列名
  356. """
  357. logger.info(f"获取表 {table_name} 的目标日期列")
  358. try:
  359. # 首先从Neo4j获取
  360. driver = get_neo4j_driver()
  361. with driver.session() as session:
  362. # 尝试从DataModel节点的relations关系属性中获取
  363. query = """
  364. MATCH (n {en_name: $table_name})
  365. RETURN n.target_dt_column AS target_dt_column
  366. """
  367. result = session.run(query, table_name=table_name)
  368. record = result.single()
  369. if record and record.get("target_dt_column"):
  370. target_dt_column = record.get("target_dt_column")
  371. logger.info(f"从Neo4j获取到表 {table_name} 的目标日期列: {target_dt_column}")
  372. return target_dt_column
  373. # 导入需要的模块以连接数据库
  374. import sqlalchemy
  375. from sqlalchemy import create_engine, text
  376. # Neo4j中找不到,尝试从data_transform_scripts表获取
  377. # 获取目标数据库连接
  378. pg_config = get_pg_config()
  379. if not pg_config:
  380. logger.error("无法获取PG_CONFIG配置,无法连接数据库查询目标日期列")
  381. return None
  382. # 创建数据库引擎
  383. db_url = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['database']}"
  384. engine = create_engine(db_url)
  385. if not engine:
  386. logger.error("无法创建数据库引擎,无法获取目标日期列")
  387. return None
  388. # 查询data_transform_scripts表
  389. schema = get_config_param("SCHEDULE_TABLE_SCHEMA")
  390. try:
  391. query = f"""
  392. SELECT target_dt_column
  393. FROM {schema}.data_transform_scripts
  394. WHERE target_table = '{table_name}'
  395. """
  396. if script_name:
  397. query += f" AND script_name = '{script_name}'"
  398. query += " LIMIT 1"
  399. with engine.connect() as conn:
  400. result = conn.execute(text(query))
  401. row = result.fetchone()
  402. if row and row[0]:
  403. target_dt_column = row[0]
  404. logger.info(f"从data_transform_scripts表获取到表 {table_name} 的目标日期列: {target_dt_column}")
  405. return target_dt_column
  406. except Exception as db_err:
  407. logger.error(f"从data_transform_scripts表获取目标日期列时出错: {str(db_err)}")
  408. logger.error(traceback.format_exc())
  409. # 都找不到,使用默认值
  410. logger.warning(f"未找到表 {table_name} 的目标日期列,将使用默认值 'data_date'")
  411. return "data_date"
  412. except Exception as e:
  413. logger.error(f"获取目标日期列时出错: {str(e)}")
  414. logger.error(traceback.format_exc())
  415. return None
  416. finally:
  417. if 'driver' in locals() and driver:
  418. driver.close()
  419. def get_config_param(param_name, default_value=None):
  420. """
  421. 从config模块动态获取配置参数
  422. 参数:
  423. param_name (str): 参数名
  424. default_value: 默认值
  425. 返回:
  426. 参数值,如果不存在则返回默认值
  427. """
  428. try:
  429. config_module = load_config_module()
  430. return getattr(config_module, param_name)
  431. except Exception as e:
  432. logger.warning(f"获取配置参数 {param_name} 失败: {str(e)},使用默认值: {default_value}")
  433. return default_value
def create_table_from_neo4j(en_name: str):
    """
    Create a PostgreSQL table from its definition stored in Neo4j.

    Looks up a DataResource/DataModel/DataMetric node by English name, reads
    its HAS_COLUMN columns, builds a CREATE TABLE DDL (plus table/column
    COMMENT statements) and executes it. An already-existing table is left
    untouched and counts as success.

    Args:
        en_name (str): English name of the table (node property `en_name`)

    Returns:
        bool: True on success (or when the table already exists), False otherwise
    """
    driver = None
    conn = None
    cur = None
    try:
        # Acquire both connections via the shared helpers in this module.
        driver = get_neo4j_driver()
        pg_config = get_pg_config()
        import psycopg2
        conn = psycopg2.connect(**pg_config)
        cur = conn.cursor()
        with driver.session() as session:
            # 1. Find the target table node (DataResource/DataModel/DataMetric).
            result = session.run("""
                MATCH (t)
                WHERE t.en_name = $en_name AND (t:DataResource OR t:DataModel OR t:DataMetric)
                RETURN labels(t) AS labels, t.en_name AS en_name, t.name AS name, id(t) AS node_id
            """, en_name=en_name)
            record = result.single()
            if not record:
                logger.error(f"未找到名为 {en_name} 的表节点")
                return False
            labels = record["labels"]
            table_en_name = record["en_name"]
            table_cn_name = record["name"]
            node_id = record["node_id"]
            # DataResource tables go to the ods schema, everything else to ads.
            schema = "ods" if "DataResource" in labels else "ads"
            # 2. Fetch all columns (HAS_COLUMN) ordered by the Column node's internal id.
            column_result = session.run("""
                MATCH (t)-[:HAS_COLUMN]->(c:Column)
                WHERE id(t) = $node_id
                RETURN c.en_name AS en_name, c.data_type AS data_type,
                       c.name AS name, c.is_pk AS is_pk, id(c) AS column_id
                ORDER BY id(c) ASC
            """, node_id=node_id)
            columns = column_result.data()
            if not columns:
                logger.error(f"未找到表 {en_name} 的字段信息")
                return False
            # 3. Build the DDL.
            ddl_lines = []
            pk_fields = []
            existing_fields = set()
            for col in columns:
                col_line = f'{col["en_name"]} {col["data_type"]}'
                ddl_lines.append(col_line)
                existing_fields.add(col["en_name"].lower())
                if col.get("is_pk", False):
                    pk_fields.append(f'{col["en_name"]}')
            # Append the audit columns when the model does not define them itself.
            if 'create_time' not in existing_fields:
                ddl_lines.append('create_time timestamp')
            if 'update_time' not in existing_fields:
                ddl_lines.append('update_time timestamp')
            if pk_fields:
                ddl_lines.append(f'PRIMARY KEY ({", ".join(pk_fields)})')
            full_table_name = f"{schema}.{table_en_name}"
            ddl = f'CREATE SCHEMA IF NOT EXISTS {schema};\n'
            ddl += f'CREATE TABLE IF NOT EXISTS {full_table_name} (\n '
            ddl += ",\n ".join(ddl_lines)
            ddl += "\n);"
            # Table comment SQL (Chinese display name from Neo4j).
            table_comment_sql = f"COMMENT ON TABLE {full_table_name} IS '{table_cn_name}';"
            # Column comment SQL for columns that carry a Chinese name.
            column_comment_sqls = []
            for col in columns:
                if col["name"]:
                    column_comment_sql = f"COMMENT ON COLUMN {full_table_name}.{col['en_name']} IS '{col['name']}';"
                    column_comment_sqls.append(column_comment_sql)
            logger.info(f"DDL: {ddl}")
            logger.info(f"表注释SQL: {table_comment_sql}")
            if column_comment_sqls:
                logger.info("字段注释SQL:")
                for comment_sql in column_comment_sqls:
                    logger.info(f" {comment_sql}")
            # 4. Execute the DDL.
            try:
                # Check whether the table already exists before creating it.
                check_table_sql = """
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_schema = %s AND table_name = %s
                    );
                """
                cur.execute(check_table_sql, (schema, table_en_name))
                table_exists = cur.fetchone()[0]
                if table_exists:
                    # Existing table: nothing committed, treated as success.
                    logger.info(f"表 {full_table_name} 已存在,跳过创建")
                    return True
                else:
                    # Create the table.
                    cur.execute(ddl)
                    logger.info(f"成功创建新表: {full_table_name}")
                    # Table comment.
                    cur.execute(table_comment_sql)
                    logger.info(f"已添加表注释")
                    # Column comments.
                    for comment_sql in column_comment_sqls:
                        cur.execute(comment_sql)
                    logger.info(f"已添加 {len(column_comment_sqls)} 个字段注释")
                    # Single commit covering table creation and all comments.
                    conn.commit()
                    return True
            except Exception as e:
                logger.error(f"执行DDL失败: {e}")
                conn.rollback()
                return False
    except Exception as e:
        logger.error(f"创建表 {en_name} 时发生错误: {str(e)}")
        if conn:
            conn.rollback()
        return False
    finally:
        # Always release cursor, connection and driver in reverse order.
        if cur:
            cur.close()
        if conn:
            conn.close()
        if driver:
            driver.close()