script_utils.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file lives in the dataops_scripts directory - used to verify that the path change succeeded
import logging
import sys
import os

# Add the parent directory to the Python path so the config module under dags/ can be imported
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import importlib.util
from datetime import datetime, timedelta
import pytz
import re  # add the re module to support regular expressions

# Import Airflow packages
try:
    from airflow.models import Variable
except ImportError:
    # Fallback for running outside an Airflow environment
    class Variable:
        @staticmethod
        def get(key, default_var=None):
            return default_var

# Configure the logger
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

logger = logging.getLogger("script_utils")
def get_config_path():
    """
    Get DATAOPS_DAGS_PATH from an Airflow Variable.

    Returns:
        str: full path to config.py
    """
    try:
        # Read DATAOPS_DAGS_PATH from the Airflow Variable store
        dags_path = Variable.get("DATAOPS_DAGS_PATH", "/opt/airflow/dags")
        logger.info(f"Got DATAOPS_DAGS_PATH from Airflow Variable: {dags_path}")

        # Build the full path to config.py
        config_path = os.path.join(dags_path, "config.py")

        if not os.path.exists(config_path):
            logger.warning(f"Config file path does not exist: {config_path}, falling back to the default path")
            # Try a path relative to this file
            alt_config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dags/config.py"))
            if os.path.exists(alt_config_path):
                logger.info(f"Using alternative config path: {alt_config_path}")
                return alt_config_path

        return config_path
    except Exception as e:
        logger.error(f"Error while resolving the config path: {str(e)}")
        # Fall back to the default path
        return os.path.abspath(os.path.join(os.path.dirname(__file__), "../dags/config.py"))
def load_config_module():
    """
    Dynamically load the config.py module.

    Returns:
        module: the loaded config module
    """
    try:
        config_path = get_config_path()
        logger.info(f"Loading config file: {config_path}")

        # Load config.py dynamically via importlib
        spec = importlib.util.spec_from_file_location("config", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        return config_module
    except Exception as e:
        logger.error(f"Error while loading the config module: {str(e)}")
        raise ImportError(f"Unable to load the config module: {str(e)}")
def get_pg_config():
    """
    Get the PostgreSQL database configuration from config.py.

    Returns:
        dict: PostgreSQL configuration
    """
    try:
        config_module = load_config_module()
        pg_config = getattr(config_module, "PG_CONFIG", None)

        if pg_config is None:
            logger.warning("PG_CONFIG not found in the config module")
            # Return the default configuration
            return {
                "host": "localhost",
                "port": 5432,
                "user": "postgres",
                "password": "postgres",
                "database": "dataops"
            }

        logger.info(f"Got PostgreSQL configuration: {pg_config}")
        return pg_config
    except Exception as e:
        logger.error(f"Error while getting the PostgreSQL configuration: {str(e)}")
        # Return the default configuration
        return {
            "host": "localhost",
            "port": 5432,
            "user": "postgres",
            "password": "postgres",
            "database": "dataops"
        }
def get_upload_paths():
    """
    Get the file upload and archive paths from config.py.

    Returns:
        tuple: (upload path, archive path)
    """
    try:
        config_module = load_config_module()
        upload_path = getattr(config_module, "STRUCTURE_UPLOAD_BASE_PATH", "/data/csv")
        archive_path = getattr(config_module, "STRUCTURE_UPLOAD_ARCHIVE_BASE_PATH", "/data/archive")
        logger.info(f"Got upload path: {upload_path}, archive path: {archive_path}")
        return upload_path, archive_path
    except Exception as e:
        logger.error(f"Error while getting the upload paths: {str(e)}")
        # Return the default paths
        return "/data/csv", "/data/archive"
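
# Example usage of the two config getters above (a minimal sketch; it assumes psycopg2 is
# installed and that the keys in PG_CONFIG - host, port, user, password, database - are passed
# straight through as psycopg2.connect() keyword arguments):
#
#   import psycopg2
#   conn = psycopg2.connect(**get_pg_config())
#   upload_path, archive_path = get_upload_paths()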
def get_date_range(exec_date, frequency):
    """
    Compute the start and end date from the execution date and frequency.

    Args:
        exec_date (str): execution date in YYYY-MM-DD format
        frequency (str): one of daily, weekly, monthly, quarterly, yearly

    Returns:
        tuple: (start_date, end_date) as YYYY-MM-DD strings
    """
    logger.info(f"Computing date range - execution date: {exec_date}, frequency: {frequency}")

    # Convert the input date into a datetime in the Shanghai timezone
    shanghai_tz = pytz.timezone('Asia/Shanghai')

    try:
        # Parse the incoming exec_date
        if isinstance(exec_date, str):
            date_obj = datetime.strptime(exec_date, '%Y-%m-%d')
        elif isinstance(exec_date, datetime):
            date_obj = exec_date
        else:
            raise ValueError(f"Unsupported exec_date type: {type(exec_date)}")

        # Localize to the Shanghai timezone
        date_obj = shanghai_tz.localize(date_obj)
        logger.info(f"Execution date in Shanghai timezone: {date_obj}")

        # Compute the date range according to the frequency
        if frequency.lower() == 'daily':
            # Daily: start_date = exec_date, end_date = exec_date + 1 day
            start_date = date_obj.strftime('%Y-%m-%d')
            end_date = (date_obj + timedelta(days=1)).strftime('%Y-%m-%d')

        elif frequency.lower() == 'weekly':
            # Weekly: start_date = Monday of this week, end_date = next Monday
            days_since_monday = date_obj.weekday()  # 0 = Monday, 6 = Sunday
            monday = date_obj - timedelta(days=days_since_monday)
            next_monday = monday + timedelta(days=7)
            start_date = monday.strftime('%Y-%m-%d')
            end_date = next_monday.strftime('%Y-%m-%d')

        elif frequency.lower() == 'monthly':
            # Monthly: start_date = first day of this month, end_date = first day of next month
            first_day = date_obj.replace(day=1)
            # Compute the first day of the next month
            if first_day.month == 12:
                next_month_first_day = first_day.replace(year=first_day.year + 1, month=1)
            else:
                next_month_first_day = first_day.replace(month=first_day.month + 1)
            start_date = first_day.strftime('%Y-%m-%d')
            end_date = next_month_first_day.strftime('%Y-%m-%d')

        elif frequency.lower() == 'quarterly':
            # Quarterly: start_date = first day of this quarter, end_date = first day of next quarter
            quarter = (date_obj.month - 1) // 3 + 1  # quarters 1-4
            first_month_of_quarter = (quarter - 1) * 3 + 1  # first month of the quarter
            quarter_first_day = date_obj.replace(month=first_month_of_quarter, day=1)
            # Compute the first day of the next quarter
            if quarter == 4:
                next_quarter_first_day = quarter_first_day.replace(year=quarter_first_day.year + 1, month=1)
            else:
                next_quarter_first_day = quarter_first_day.replace(month=first_month_of_quarter + 3)
            start_date = quarter_first_day.strftime('%Y-%m-%d')
            end_date = next_quarter_first_day.strftime('%Y-%m-%d')

        elif frequency.lower() == 'yearly':
            # Yearly: start_date = first day of this year, end_date = first day of next year
            year_first_day = date_obj.replace(month=1, day=1)
            next_year_first_day = date_obj.replace(year=date_obj.year + 1, month=1, day=1)
            start_date = year_first_day.strftime('%Y-%m-%d')
            end_date = next_year_first_day.strftime('%Y-%m-%d')

        else:
            logger.error(f"Unsupported frequency: {frequency}")
            raise ValueError(f"Unsupported frequency: {frequency}")

        logger.info(f"Result - start date: {start_date}, end date: {end_date}")
        return start_date, end_date

    except Exception as e:
        logger.error(f"Error while computing the date range: {str(e)}", exc_info=True)
        raise
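
# Illustrative results for get_date_range (a sketch worked through the logic above;
# 2024-03-15 falls on a Friday):
#
#   get_date_range('2024-03-15', 'daily')     -> ('2024-03-15', '2024-03-16')
#   get_date_range('2024-03-15', 'weekly')    -> ('2024-03-11', '2024-03-18')
#   get_date_range('2024-03-15', 'monthly')   -> ('2024-03-01', '2024-04-01')
#   get_date_range('2024-03-15', 'quarterly') -> ('2024-01-01', '2024-04-01')
#   get_date_range('2024-03-15', 'yearly')    -> ('2024-01-01', '2025-01-01')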
from typing import Dict, List, Optional, Set
def extract_source_fields_linked_to_template(sql: str, jinja_vars: List[str]) -> Set[str]:
    """
    Extract the source fields in the SQL that are bound to Jinja template variables
    (supports several comparison forms).
    """
    fields = set()
    sql = re.sub(r"\s+", " ", sql)

    for var in jinja_vars:
        # Plain comparisons, optionally wrapped in a function call
        pattern = re.compile(
            r"""
            (?P<field>
                (?:\w+\s*\(\s*)?        # optional start of a function call, e.g. DATE(
                [\w\.]+                 # the field name
                (?:\s+AS\s+\w+)?        # optional CAST-style "AS <type>"
                \)?                     # optional closing parenthesis
            )
            \s*(=|<|>|<=|>=)\s*['"]?\{\{\s*""" + var + r"""\s*\}\}['"]?
            """, re.IGNORECASE | re.VERBOSE
        )
        fields.update(match.group("field").strip() for match in pattern.finditer(sql))

        # BETWEEN '{{ start_date }}' AND '{{ end_date }}'
        if var == "start_date":
            pattern_between = re.compile(
                r"""(?P<field>
                    (?:\w+\s*\(\s*)?[\w\.]+(?:\s+AS\s+\w+)?\)?   # the field (optionally function-wrapped)
                )
                \s+BETWEEN\s+['"]?\{\{\s*start_date\s*\}\}['"]?\s+AND\s+['"]?\{\{\s*end_date\s*\}\}
                """, re.IGNORECASE | re.VERBOSE
            )
            fields.update(match.group("field").strip() for match in pattern_between.finditer(sql))

    return {extract_core_field(f) for f in fields}
def extract_core_field(expr: str) -> str:
    """
    Strip function wrappers from a field expression: DATE(sd.sale_date) -> sd.sale_date, CAST(...) -> ...
    """
    expr = re.sub(r"CAST\s*\(\s*([\w\.]+)\s+AS\s+\w+\s*\)", r"\1", expr, flags=re.IGNORECASE)
    expr = re.sub(r"\b\w+\s*\(\s*([\w\.]+)\s*\)", r"\1", expr)
    return expr.strip()
def parse_select_aliases(sql: str) -> Dict[str, str]:
    """
    Extract the field alias mapping from the SELECT clause: source field -> target alias.
    """
    sql = re.sub(r"\s+", " ", sql)
    select_clause_match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)
    if not select_clause_match:
        return {}

    select_clause = select_clause_match.group(1)
    mappings = {}

    # Note: a plain split on "," handles simple "field AS alias" expressions,
    # but not commas nested inside function calls.
    for expr in select_clause.split(","):
        expr = expr.strip()
        alias_match = re.match(r"([\w\.]+)\s+AS\s+([\w]+)", expr, re.IGNORECASE)
        if alias_match:
            source, alias = alias_match.groups()
            mappings[source.strip()] = alias.strip()

    return mappings
def find_target_date_field(sql: str, jinja_vars: List[str] = ["start_date", "end_date"]) -> Optional[str]:
    """
    Find the target-table field bound to the template date variables in the SQL (returns at most one).
    """
    source_fields = extract_source_fields_linked_to_template(sql, jinja_vars)
    alias_map = parse_select_aliases(sql)

    # Look for the source field in the SELECT alias mapping
    for src_field in source_fields:
        if src_field in alias_map:
            return alias_map[src_field]  # the target field the source field maps to

    # If there is no AS mapping, the SELECT may use the bare field directly, e.g. SELECT sd.sale_date
    for src_field in source_fields:
        if '.' not in src_field:
            return src_field  # a bare field is used directly as the target field name

    return None
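
# Worked example for the three helpers above (a sketch; the query, table, and column names are
# hypothetical):
#
#   sql = ("SELECT sd.sale_date AS summary_date, sd.amount AS total_amount FROM sales_detail sd "
#          "WHERE sd.sale_date >= '{{ start_date }}' AND sd.sale_date < '{{ end_date }}'")
#   extract_source_fields_linked_to_template(sql, ["start_date", "end_date"])  # {'sd.sale_date'}
#   parse_select_aliases(sql)    # {'sd.sale_date': 'summary_date', 'sd.amount': 'total_amount'}
#   find_target_date_field(sql)  # 'summary_date'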
def generate_delete_sql(sql_content, target_table=None):
    """
    Generate a DELETE statement for cleaning up data, based on the SQL script content.

    Args:
        sql_content (str): the original SQL script content
        target_table (str, optional): target table name, used when no table name can be parsed from the SQL

    Returns:
        str: a DELETE statement used to clean up data
    """
    logger.info("Generating cleanup SQL to make the ETL job idempotent")

    # If a target table name was provided, use it directly
    if target_table:
        logger.info(f"Using the provided target table name: {target_table}")
        delete_stmt = f"""DELETE FROM {target_table}
WHERE summary_date >= '{{{{ start_date }}}}'
AND summary_date < '{{{{ end_date }}}}';"""
        logger.info(f"Generated cleanup SQL: {delete_stmt}")
        return delete_stmt

    # Try to parse the target table name out of the SQL content
    try:
        # Simple parsing: look for the target table of an INSERT statement
        # Matches INSERT INTO xxx, INSERT INTO "xxx", INSERT INTO `xxx` or INSERT INTO [xxx]
        insert_match = re.search(r'INSERT\s+INTO\s+(?:["\[`])?([a-zA-Z0-9_\.]+)(?:["\]`])?', sql_content, re.IGNORECASE)

        if insert_match:
            table_name = insert_match.group(1)
            logger.info(f"Parsed target table name from the SQL: {table_name}")
            delete_stmt = f"""DELETE FROM {table_name}
WHERE summary_date >= '{{{{ start_date }}}}'
AND summary_date < '{{{{ end_date }}}}';"""
            logger.info(f"Generated cleanup SQL: {delete_stmt}")
            return delete_stmt
        else:
            logger.warning("Could not parse a target table name from the SQL; cannot generate cleanup SQL")
            return None
    except Exception as e:
        logger.error(f"Error while parsing the SQL to generate the cleanup statement: {str(e)}", exc_info=True)
        return None
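
# Illustrative output (a sketch; the table name below is hypothetical):
#
#   generate_delete_sql("INSERT INTO ads_sales_summary (summary_date, total) SELECT ...")
#
#   returns the Jinja-templated statement:
#
#     DELETE FROM ads_sales_summary
#     WHERE summary_date >= '{{ start_date }}'
#     AND summary_date < '{{ end_date }}';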
def get_one_day_range(exec_date):
    """
    Return 00:00:00 of exec_date and 00:00:00 of the following day, both as datetime objects.

    Args:
        exec_date (str or datetime): execution date, either a YYYY-MM-DD string or a datetime object

    Returns:
        tuple(datetime, datetime): (start_datetime, end_datetime)
    """
    shanghai_tz = pytz.timezone('Asia/Shanghai')
    if isinstance(exec_date, str):
        date_obj = datetime.strptime(exec_date, '%Y-%m-%d')
    elif isinstance(exec_date, datetime):
        date_obj = exec_date
    else:
        raise ValueError(f"Unsupported exec_date type: {type(exec_date)}")

    # 00:00:00 of the given day
    start_datetime = shanghai_tz.localize(datetime(date_obj.year, date_obj.month, date_obj.day, 0, 0, 0))
    # 00:00:00 of the following day
    end_datetime = start_datetime + timedelta(days=1)
    return start_datetime, end_datetime
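

# Minimal self-check when the module is run directly (a sketch added for illustration; it only
# exercises helpers that need neither an Airflow connection nor a config.py on disk, and the
# sample SQL is hypothetical).
if __name__ == "__main__":
    print(get_date_range("2024-03-15", "monthly"))   # ('2024-03-01', '2024-04-01')
    print(get_one_day_range("2024-03-15"))           # tz-aware midnight-to-midnight range
    print(find_target_date_field(
        "SELECT sd.sale_date AS summary_date FROM sales_detail sd "
        "WHERE sd.sale_date >= '{{ start_date }}' AND sd.sale_date < '{{ end_date }}'"
    ))                                               # 'summary_date'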