# utils.py
import psycopg2
from neo4j import GraphDatabase
from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
import logging
import importlib.util
from pathlib import Path
import networkx as nx
import os
from airflow.exceptions import AirflowFailException
from datetime import datetime, timedelta, date
import functools
import time
import pendulum

# Create a shared logger for all task utilities
logger = logging.getLogger("airflow.task")

def get_pg_conn():
    return psycopg2.connect(**PG_CONFIG)

def execute_script(script_name=None, table_name=None, execution_mode=None, script_path=None, script_exec_mode=None, args=None):
    """
    Dynamically import and execute the script matching the given name.

    Two calling conventions are supported:
    1. execute_script(script_name, table_name, execution_mode) - original implementation
    2. execute_script(script_path, script_name, script_exec_mode, args={}) - implementation migrated from common.py

    Returns:
        bool: True if execution succeeded, False otherwise
    """
    # First calling convention - original implementation
    if script_name and table_name and execution_mode is not None and script_path is None and script_exec_mode is None:
        try:
            # Use the configured deployment path directly; local development paths are not considered
            script_path = Path(SCRIPTS_BASE_PATH) / script_name
            logger.info(f"Using configured Airflow deployment path: {script_path}")
            # Dynamically import the module
            spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            # Use the standard entry function run()
            if hasattr(module, "run"):
                logger.info(f"Executing standard entry function run() of script {script_name}")
                module.run(table_name=table_name, execution_mode=execution_mode)
                return True
            else:
                logger.warning(f"Script {script_name} does not define the standard entry function run(); cannot execute")
                return False
        except Exception as e:
            logger.error(f"Error executing script {script_name}: {str(e)}")
            return False
    # Second calling convention - implementation migrated from common.py
    else:
        # Determine the calling convention and normalize parameters
        if script_path and script_name and script_exec_mode is not None:
            # Second convention - all parameters provided explicitly
            if args is None:
                args = {}
        elif script_name and table_name and execution_mode is not None:
            # Second convention, but invoked with the first convention's parameter names
            script_path = os.path.join(SCRIPTS_BASE_PATH, f"{script_name}.py")
            script_exec_mode = execution_mode
            args = {"table_name": table_name}
        else:
            logger.error("Invalid parameters; cannot execute script")
            return False
        try:
            # Make sure the script path exists
            if not os.path.exists(script_path):
                logger.error(f"Script path {script_path} does not exist")
                return False
            # Load the script module
            spec = importlib.util.spec_from_file_location("script_module", script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            # Log all callable public functions for debugging
            module_functions = [f for f in dir(module) if callable(getattr(module, f)) and not f.startswith('_')]
            logger.debug(f"Available functions in module {script_name}: {module_functions}")
            # Get the script's run function
            if not hasattr(module, "run"):
                logger.error(f"Script {script_name} has no run function")
                return False
            # Decorate run so it always returns a boolean
            original_run = module.run
            module.run = ensure_boolean_result(original_run)
            logger.info(f"Starting script {script_name}, execution mode: {script_exec_mode}, args: {args}")
            start_time = time.time()
            # Execute the script
            if table_name is not None:
                # Invoke with the table_name parameter
                exec_result = module.run(table_name=table_name, execution_mode=script_exec_mode)
            else:
                # Invoke with script_exec_mode and args
                exec_result = module.run(script_exec_mode, args)
            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"Script {script_name} finished, result: {exec_result}, duration: {duration:.2f}s")
            return exec_result
        except Exception as e:
            logger.error(f"Error executing script {script_name}: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return False
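
# Usage sketch for execute_script (illustrative only; the script and table names
# below are hypothetical, not taken from this repository):
#
#   # Convention 1: script name + table name + execution mode
#   ok = execute_script(script_name="load_example.py", table_name="example",
#                       execution_mode="append")
#
#   # Convention 2: explicit path, name, exec mode and args dict;
#   # run() is then invoked positionally as run(script_exec_mode, args)
#   ok = execute_script(script_name="load_example",
#                       script_path="/opt/airflow/scripts/load_example.py",
#                       script_exec_mode="append",
#                       args={"table_name": "example"})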

def get_resource_subscribed_tables(enabled_tables: list) -> list:
    result = []
    for t in enabled_tables:
        if is_data_resource_table(t['table_name']):
            result.append(t)
    return result

# Starting from the target tables, recursively collect all upstream DataResource dependencies (any depth)
def get_dependency_resource_tables(enabled_tables: list) -> list:
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    resource_set = set()
    try:
        with driver.session() as session:
            for t in enabled_tables:
                query = """
                    MATCH (target:Table {name: $table_name})
                    MATCH (res:DataResource)-[:ORIGINATES_FROM]->(:DataSource)
                    WHERE (target)-[:DERIVED_FROM*1..]->(res)
                    RETURN DISTINCT res.en_name AS name
                """
                result = session.run(query, table_name=t['table_name'])
                for record in result:
                    resource_set.add(record['name'])
    finally:
        driver.close()
    output = []
    for name in resource_set:
        output.append({"table_name": name, "execution_mode": "append"})
    return output

# Fetch enabled tables from PostgreSQL, filtered by schedule frequency (daily/weekly/monthly)
def get_enabled_tables(frequency: str) -> list:
    conn = get_pg_conn()
    cursor = conn.cursor()
    cursor.execute("""
        SELECT table_name, execution_mode
        FROM table_schedule
        WHERE is_enabled = TRUE AND schedule_frequency = %s
    """, (frequency,))
    result = cursor.fetchall()
    cursor.close()
    conn.close()
    output = []
    for r in result:
        output.append({"table_name": r[0], "execution_mode": r[1]})
    return output
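
# Usage sketch (illustrative): fetch the daily tables and derive their resource lists.
# Assumes table_schedule rows such as (table_name='example', execution_mode='append'):
#
#   daily_tables = get_enabled_tables("daily")
#   resources = get_resource_subscribed_tables(daily_tables)
#   upstream_resources = get_dependency_resource_tables(daily_tables)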

# Check whether the given table name is a DataResource node in Neo4j
def is_data_resource_table(table_name: str) -> bool:
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    query = """
        MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
    """
    try:
        with driver.session() as session:
            result = session.run(query, table_name=table_name)
            record = result.single()
            return bool(record and record["exists"])
    finally:
        driver.close()

# Read the script_name property from the DERIVED_FROM relationships of a DataModel table in Neo4j
def get_script_name_from_neo4j(table_name):
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    logger.info(f"Querying Neo4j for the script name of table {table_name}")
    # Check the direction of the DERIVED_FROM relationships being queried
    check_query = """
        MATCH (a:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(b)
        RETURN b.en_name AS upstream_name LIMIT 5
    """
    try:
        with driver.session() as session:
            # Inspect the upstream dependencies first
            logger.info(f"Checking upstream dependency direction for table {table_name}")
            check_result = session.run(check_query, table_name=table_name)
            upstreams = [record['upstream_name'] for record in check_result if 'upstream_name' in record]
            logger.info(f"Upstream dependencies of table {table_name}: {upstreams}")
            # Query the script name
            query = """
                MATCH (target:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(n)
                WHERE n:DataModel OR n:DataResource
                RETURN r.script_name AS script_name
            """
            result = session.run(query, table_name=table_name)
            record = result.single()
            if record:
                try:
                    script_name = record['script_name']
                    logger.info(f"Found script name for table {table_name}: {script_name}")
                    return script_name
                except (KeyError, TypeError) as e:
                    logger.warning(f"Record does not contain a script_name field: {e}")
                    return None
            else:
                logger.warning(f"No script name found for table {table_name}")
                return None
    except Exception as e:
        logger.error(f"Error querying script name for table {table_name}: {str(e)}")
        return None
    finally:
        driver.close()

# Check whether the given table name is a DataModel node in Neo4j
def is_data_model_table(table_name):
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    query = """
        MATCH (n:DataModel {en_name: $table_name}) RETURN count(n) > 0 AS exists
    """
    try:
        with driver.session() as session:
            result = session.run(query, table_name=table_name)
            record = result.single()
            return bool(record and record['exists'])
    finally:
        driver.close()

def check_script_exists(script_name):
    """
    Check whether the script file exists in the configured scripts directory.

    Args:
        script_name (str): script file name

    Returns:
        bool: True if the script exists, False otherwise
        str: full script path
    """
    if not script_name:
        logger.error("Script name is empty; nothing to check")
        return False, None
    script_path = Path(SCRIPTS_BASE_PATH) / script_name
    script_path_str = str(script_path)
    logger.info(f"Checking script path: {script_path_str}")
    if os.path.exists(script_path_str):
        logger.info(f"Script file found: {script_path_str}")
        return True, script_path_str
    else:
        logger.error(f"Script file does not exist: {script_path_str}")
        # Try to list the files in the directory for debugging
        try:
            base_dir = Path(SCRIPTS_BASE_PATH)
            if base_dir.exists():
                files = list(base_dir.glob("*.py"))
                logger.info(f"Python files in directory {SCRIPTS_BASE_PATH}: {[f.name for f in files]}")
            else:
                logger.error(f"Base directory does not exist: {SCRIPTS_BASE_PATH}")
        except Exception as e:
            logger.error(f"Error listing directory contents: {str(e)}")
        return False, script_path_str
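
# Usage sketch (illustrative; "example_script.py" is a hypothetical file name):
#
#   exists, path = check_script_exists("example_script.py")
#   if exists:
#       logger.info(f"will execute {path}")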

def run_model_script(table_name, execution_mode):
    """
    Look up and execute the model script for the given table.

    Args:
        table_name (str): table to process
        execution_mode (str): execution mode (append/full_refresh)

    Returns:
        bool: True if execution succeeded

    Raises:
        AirflowFailException: if the script is missing or execution fails
    """
    # Get the script name from Neo4j
    script_name = get_script_name_from_neo4j(table_name)
    if not script_name:
        error_msg = f"No script name found for table {table_name}; task failed"
        logger.error(error_msg)
        raise AirflowFailException(error_msg)
    logger.info(f"Got script name for table {table_name} from Neo4j: {script_name}")
    # Check that the script file exists
    exists, script_path = check_script_exists(script_name)
    if not exists:
        error_msg = f"Script file {script_name} for table {table_name} does not exist; task failed"
        logger.error(error_msg)
        raise AirflowFailException(error_msg)
    # Execute the script
    logger.info(f"Starting script: {script_path}")
    try:
        # Dynamically import the module
        spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Check for and call the standard entry function run()
        if hasattr(module, "run"):
            logger.info(f"Calling standard entry function run() of script {script_name}")
            module.run(table_name=table_name, execution_mode=execution_mode)
            logger.info(f"Script {script_name} executed successfully")
            return True
        else:
            error_msg = f"Script {script_name} does not define the standard entry function run(); task failed"
            logger.error(error_msg)
            raise AirflowFailException(error_msg)
    except AirflowFailException:
        # Re-raise Airflow exceptions as-is
        raise
    except Exception as e:
        error_msg = f"Error executing script {script_name}: {str(e)}"
        logger.error(error_msg)
        import traceback
        logger.error(traceback.format_exc())
        raise AirflowFailException(error_msg)

def get_model_dependency_graph(table_names: list) -> dict:
    """
    Use networkx to build the dependency graph between the given DataModel tables from Neo4j.

    Args:
        table_names: list of table names

    Returns:
        dict: dependency mapping {target table: [upstream table 1, upstream table 2, ...]}
    """
    logger.info(f"Building dependency graph for tables: {table_names}")
    # Create a directed graph
    G = nx.DiGraph()
    # Add all nodes
    for table_name in table_names:
        G.add_node(table_name)
    # Fetch dependencies from Neo4j and add edges
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    try:
        with driver.session() as session:
            # Fetch all dependencies between the tables in a single query.
            # Note: the query matches A-[:DERIVED_FROM]->B, meaning A depends on B.
            # Log the raw query parameters for debugging
            logger.info(f"Query parameter table_names: {table_names}, type: {type(table_names)}")
            # First pass - explicit single-query form
            query = """
                MATCH (source)-[r:DERIVED_FROM]->(target)
                WHERE source.en_name IN $table_names AND target.en_name IN $table_names
                RETURN source.en_name AS source, target.en_name AS target, r.script_name AS script_name
            """
            logger.info("Running Neo4j query: find all dependencies between the tables")
            result = session.run(query, table_names=table_names)
            # Materialize the result as a list so it is fully consumed
            result_records = list(result)
            logger.info(f"First-pass query returned {len(result_records)} records")
            # Process the dependencies
            found_deps = 0
            # Initialize the dependency dictionary
            dependency_dict = {name: [] for name in table_names}
            # Convert each record to a dict before reading it, to avoid access issues
            for record in result_records:
                record_dict = dict(record)
                # Read the values from the dict
                source = record_dict.get('source')
                target = record_dict.get('target')
                script_name = record_dict.get('script_name', 'unknown_script')
                # Make sure both fields are present and non-empty
                if source and target:
                    logger.info(f"Found dependency: {source} -[:DERIVED_FROM]-> {target}, script: {script_name}")
                    # Record the dependency in the dictionary
                    if source in dependency_dict:
                        dependency_dict[source].append(target)
                        found_deps += 1
                        # Add an edge from the depended-on table to the dependent table,
                        # encoding execution order (the depended-on table runs first)
                        G.add_edge(target, source)
                        logger.info(f"Added execution-order edge: {target} -> {source} (because {source} depends on {target})")
            logger.info(f"Found {found_deps} dependencies in total")
            # If no dependencies were found, fall back to checking every possible table pair
            if found_deps == 0:
                logger.warning("No dependencies found; checking every table pair individually")
                logger.info("Fallback pass: starting pairwise relationship checks")
                logger.info(f"Number of table pairs to check: {len(table_names) * (len(table_names) - 1)}")
                pair_count = 0
                for source_table in table_names:
                    for target_table in table_names:
                        if source_table != target_table:
                            pair_count += 1
                            logger.info(f"Checking pair [{pair_count}]: {source_table} -> {target_table}")
                            check_result = check_table_relationship(source_table, target_table)
                            # Check the forward direction
                            if 'forward' in check_result and check_result['forward']['exists']:
                                script_name = check_result['forward'].get('script_name', 'unknown_script')
                                logger.info(f"Pairwise check found relationship: {source_table} -[:DERIVED_FROM]-> {target_table}, script: {script_name}")
                                dependency_dict[source_table].append(target_table)
                                G.add_edge(target_table, source_table)
                                found_deps += 1
                logger.info(f"Found {found_deps} dependencies after pairwise checks")
    finally:
        driver.close()
    # Detect circular dependencies
    try:
        cycles = list(nx.simple_cycles(G))
        if cycles:
            logger.warning(f"Detected circular dependencies between tables: {cycles}")
    except Exception as e:
        logger.error(f"Circular dependency check failed: {str(e)}")
    # Convert the result to dictionary form
    final_dependency_dict = {}
    for table_name in table_names:
        final_dependency_dict[table_name] = dependency_dict.get(table_name, [])
        logger.info(f"Final dependencies - table {table_name} depends on: {final_dependency_dict[table_name]}")
    logger.info(f"Complete dependency graph: {final_dependency_dict}")
    return final_dependency_dict
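
# Shape sketch (illustrative; table names are hypothetical). Given a Neo4j relationship
# model_b -[:DERIVED_FROM]-> model_a, the returned mapping is keyed by the dependent table:
#
#   get_model_dependency_graph(["model_a", "model_b"])
#   # -> {"model_a": [], "model_b": ["model_a"]}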

def generate_optimized_execution_order(table_names, dependency_dict=None):
    """
    Generate an optimized execution order, handling circular dependencies.

    Args:
        table_names: list of table names
        dependency_dict: dependency mapping {table: [dependency 1, dependency 2, ...]};
            if None, it is obtained via get_model_dependency_graph

    Returns:
        list: optimized execution order
    """
    # Create a directed graph
    G = nx.DiGraph()
    # Add all nodes
    for table_name in table_names:
        G.add_node(table_name)
    # Gather the dependencies
    if dependency_dict is None:
        # Fall back to the original utils.py helper get_model_dependency_graph
        dependency_dict = get_model_dependency_graph(table_names)
        # Add dependency edges - from upstream to target
        for target, upstreams in dependency_dict.items():
            for upstream in upstreams:
                G.add_edge(upstream, target)
    else:
        # Use the provided dependency_dict - edges point from dependency to target
        for target, sources in dependency_dict.items():
            for source in sources:
                if source in table_names:  # only consider tables within the target set
                    G.add_edge(source, target)
    # Detect circular dependencies
    cycles = list(nx.simple_cycles(G))
    if cycles:
        logger.warning(f"Detected circular dependencies; attempting to break them: {cycles}")
        # Break the cycles (simple strategy: remove one edge per cycle)
        for cycle in cycles:
            # Remove the closing edge of the cycle
            G.remove_edge(cycle[-1], cycle[0])
            logger.info(f"Broke circular dependency: removed edge {cycle[-1]} -> {cycle[0]}")
    # Produce a topological sort
    try:
        execution_order = list(nx.topological_sort(G))
        return execution_order
    except Exception as e:
        logger.error(f"Failed to generate execution order: {str(e)}")
        # Fall back to the original list
        return table_names
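
# Worked sketch (illustrative names): with b depending on a, and c depending on b,
# the topological sort yields dependencies first:
#
#   order = generate_optimized_execution_order(
#       ["c", "a", "b"],
#       dependency_dict={"c": ["b"], "b": ["a"], "a": []},
#   )
#   # -> ["a", "b", "c"]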

def check_table_relationship(table1, table2):
    """
    Check the relationship between two tables directly in Neo4j.

    Args:
        table1: first table name
        table2: second table name

    Returns:
        dict with relationship information
    """
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    relationship_info = {}
    try:
        with driver.session() as session:
            # Check the table1 -> table2 direction
            forward_query = """
                MATCH (a:DataModel {en_name: $table1})-[r:DERIVED_FROM]->(b:DataModel {en_name: $table2})
                RETURN count(r) > 0 AS has_relationship, r.script_name AS script_name
            """
            forward_result = session.run(forward_query, table1=table1, table2=table2)
            forward_record = forward_result.single()
            if forward_record and forward_record['has_relationship']:
                relationship_info['forward'] = {
                    'exists': True,
                    'direction': f"{table1} -> {table2}",
                    'script_name': forward_record.get('script_name')
                }
                logger.info(f"Found relationship: {table1} -[:DERIVED_FROM]-> {table2}, script: {forward_record.get('script_name')}")
            else:
                relationship_info['forward'] = {'exists': False}
            # Check the table2 -> table1 direction
            backward_query = """
                MATCH (a:DataModel {en_name: $table2})-[r:DERIVED_FROM]->(b:DataModel {en_name: $table1})
                RETURN count(r) > 0 AS has_relationship, r.script_name AS script_name
            """
            backward_result = session.run(backward_query, table1=table1, table2=table2)
            backward_record = backward_result.single()
            if backward_record and backward_record['has_relationship']:
                relationship_info['backward'] = {
                    'exists': True,
                    'direction': f"{table2} -> {table1}",
                    'script_name': backward_record.get('script_name')
                }
                logger.info(f"Found relationship: {table2} -[:DERIVED_FROM]-> {table1}, script: {backward_record.get('script_name')}")
            else:
                relationship_info['backward'] = {'exists': False}
    except Exception as e:
        logger.error(f"Error checking table relationship: {str(e)}")
        relationship_info['error'] = str(e)
    finally:
        driver.close()
    return relationship_info
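
# Return-shape sketch (illustrative; names and script are hypothetical):
#
#   info = check_table_relationship("model_x", "model_y")
#   # e.g. {"forward": {"exists": True, "direction": "model_x -> model_y",
#   #                   "script_name": "x_from_y.py"},
#   #       "backward": {"exists": False}}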

def build_model_dependency_dag(table_names, model_tables):
    """
    Build a model dependency DAG from the table list; returns the optimized execution order and the dependency graph.

    Args:
        table_names: list of table names
        model_tables: list of table configurations

    Returns:
        tuple: (optimized table execution order, dependency graph)
    """
    # Use the optimizer to generate the execution order; it can handle circular dependencies
    optimized_table_order = generate_optimized_execution_order(table_names)
    logger.info(f"Generated optimized execution order with {len(optimized_table_order)} tables")
    # Get the dependency graph
    dependency_graph = get_model_dependency_graph(table_names)
    logger.info(f"Built dependency graph for {len(dependency_graph)} tables")
    return optimized_table_order, dependency_graph

def create_task_dict(optimized_table_order, model_tables, dag, execution_type, **task_options):
    """
    Create the task dictionary from the optimized table execution order.

    Args:
        optimized_table_order: optimized table execution order
        model_tables: list of table configurations
        dag: Airflow DAG object
        execution_type: execution type (daily, monthly, etc.)
        task_options: extra options for task creation

    Returns:
        dict: task dictionary {table name: task object}
    """
    from airflow.operators.python import PythonOperator
    task_dict = {}
    for table_name in optimized_table_order:
        # Look up the table's configuration
        table_config = next((t for t in model_tables if t['table_name'] == table_name), None)
        if table_config:
            try:
                # Build the base parameters
                task_params = {
                    "task_id": f"process_{execution_type}_{table_name}",
                    "python_callable": run_model_script,
                    "op_kwargs": {"table_name": table_name, "execution_mode": table_config['execution_mode']},
                    "dag": dag
                }
                # Apply extra options
                if task_options:
                    # Table-specific options take precedence
                    if table_name in task_options:
                        task_params.update(task_options[table_name])
                    # Otherwise fall back to the global options
                    elif 'default' in task_options:
                        task_params.update(task_options['default'])
                task = PythonOperator(**task_params)
                task_dict[table_name] = task
                logger.info(f"Created model processing task: {task_params['task_id']}")
            except Exception as e:
                logger.error(f"Error creating task process_{execution_type}_{table_name}: {str(e)}")
                raise
    return task_dict
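
# Usage sketch (illustrative; assumes an existing `dag` object, a `model_tables`
# config list, and hypothetical table names). The `default` option is spread onto
# every task; `retries` is a standard Airflow operator argument:
#
#   order, graph = build_model_dependency_dag(["model_a", "model_b"], model_tables)
#   tasks = create_task_dict(order, model_tables, dag, "daily",
#                            default={"retries": 1})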

def build_task_dependencies(task_dict, dependency_graph):
    """
    Wire up inter-task dependencies according to the dependency graph.

    Args:
        task_dict: task dictionary
        dependency_graph: dependency graph

    Returns:
        tuple: (tasks_with_upstream, tasks_with_downstream, dependency_count)
    """
    tasks_with_upstream = set()  # tracks nodes that already have an upstream task
    dependency_count = 0
    for target, upstream_list in dependency_graph.items():
        if target in task_dict:
            for upstream in upstream_list:
                if upstream in task_dict:
                    logger.info(f"Setting task dependency: {upstream} >> {target}")
                    task_dict[upstream] >> task_dict[target]
                    tasks_with_upstream.add(target)  # this task now has an upstream
                    dependency_count += 1
    # Find the nodes that have downstream tasks
    tasks_with_downstream = set()
    for target, upstream_list in dependency_graph.items():
        if target in task_dict:  # the target task is in this DAG
            for upstream in upstream_list:
                if upstream in task_dict:  # the upstream task is also in this DAG
                    tasks_with_downstream.add(upstream)  # this upstream task has a downstream
    logger.info(f"Established {dependency_count} dependencies between tasks")
    logger.info(f"Nodes that already have upstream tasks: {tasks_with_upstream}")
    return tasks_with_upstream, tasks_with_downstream, dependency_count

def connect_start_and_end_tasks(task_dict, tasks_with_upstream, tasks_with_downstream,
                                wait_task, completed_task, dag_type):
    """
    Connect the starting nodes to the wait task and the terminal nodes to the completion marker.

    Args:
        task_dict: task dictionary
        tasks_with_upstream: set of nodes that have upstream tasks
        tasks_with_downstream: set of nodes that have downstream tasks
        wait_task: wait task
        completed_task: completion marker task
        dag_type: DAG type name (used in logs)

    Returns:
        tuple: (start_tasks, end_tasks)
    """
    # Connect the starting nodes
    start_tasks = []
    for table_name, task in task_dict.items():
        if table_name not in tasks_with_upstream:
            start_tasks.append(table_name)
            logger.info(f"Task {table_name} has no upstream task; it should connect to the {dag_type} wait task")
    logger.info(f"Tasks to connect to the {dag_type} wait task: {start_tasks}")
    for task_name in start_tasks:
        wait_task >> task_dict[task_name]
        logger.info(f"Connected {wait_task.task_id} >> {task_name}")
    # Connect the terminal nodes
    end_tasks = []
    for table_name, task in task_dict.items():
        if table_name not in tasks_with_downstream:
            end_tasks.append(table_name)
            logger.info(f"Task {table_name} has no downstream task; it is a terminal task")
    logger.info(f"Terminal tasks to connect to the {dag_type} completion marker: {end_tasks}")
    for end_task in end_tasks:
        task_dict[end_task] >> completed_task
        logger.info(f"Connected {end_task} >> {completed_task.task_id}")
    # Handle edge cases
    logger.info("Handling edge cases")
    if not start_tasks:
        logger.warning(f"No start tasks found; connecting the {dag_type} wait task directly to the completion marker")
        wait_task >> completed_task
    if not end_tasks:
        logger.warning(f"No terminal tasks found; connecting all tasks to the {dag_type} completion marker")
        for table_name, task in task_dict.items():
            task >> completed_task
            logger.info(f"Connected task directly to completion marker: {table_name} >> {completed_task.task_id}")
    return start_tasks, end_tasks
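
# End-to-end wiring sketch (illustrative; `wait_task` and `completed_task` would be
# operators defined in the calling DAG, and `order`, `graph`, `model_tables` come
# from the helpers above):
#
#   task_dict = create_task_dict(order, model_tables, dag, "daily")
#   with_up, with_down, n = build_task_dependencies(task_dict, graph)
#   connect_start_and_end_tasks(task_dict, with_up, with_down,
#                               wait_task, completed_task, "daily")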

def get_neo4j_driver():
    """Get a Neo4j driver instance"""
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    return GraphDatabase.driver(uri, auth=auth)

def update_task_start_time(exec_date, target_table, script_name, start_time):
    """Update the task's start time"""
    logger.info("===== Updating task start time =====")
    logger.info(f"Parameters: exec_date={exec_date} ({type(exec_date).__name__}), target_table={target_table}, script_name={script_name}")
    conn = get_pg_conn()
    cursor = conn.cursor()
    try:
        # First check whether the record exists
        cursor.execute("""
            SELECT COUNT(*)
            FROM airflow_dag_schedule
            WHERE exec_date = %s AND target_table = %s AND script_name = %s
        """, (exec_date, target_table, script_name))
        count = cursor.fetchone()[0]
        logger.info(f"Matching records found: {count}")
        if count == 0:
            logger.warning(f"No matching record: exec_date={exec_date}, target_table={target_table}, script_name={script_name}")
            logger.info("Sample records found in airflow_dag_schedule:")
            cursor.execute("""
                SELECT exec_date, target_table, script_name
                FROM airflow_dag_schedule
                LIMIT 5
            """)
            sample_records = cursor.fetchall()
            for record in sample_records:
                logger.info(f"Sample record: exec_date={record[0]} ({type(record[0]).__name__}), target_table={record[1]}, script_name={record[2]}")
        # Run the update
        sql = """
            UPDATE airflow_dag_schedule
            SET exec_start_time = %s
            WHERE exec_date = %s AND target_table = %s AND script_name = %s
        """
        logger.info(f"Executing SQL: {sql}")
        logger.info(f"Parameters: start_time={start_time}, exec_date={exec_date}, target_table={target_table}, script_name={script_name}")
        cursor.execute(sql, (start_time, exec_date, target_table, script_name))
        affected_rows = cursor.rowcount
        logger.info(f"Rows affected by the update: {affected_rows}")
        conn.commit()
        logger.info("Transaction committed")
    except Exception as e:
        logger.error(f"Failed to update task start time: {str(e)}")
        import traceback
        logger.error(f"Stack trace: {traceback.format_exc()}")
        conn.rollback()
        logger.info("Transaction rolled back")
        raise
    finally:
        cursor.close()
        conn.close()
        logger.info("Database connection closed")
        logger.info("===== Task start time update finished =====")

def update_task_completion(exec_date, target_table, script_name, success, end_time, duration):
    """Update the task's completion information"""
    logger.info("===== Updating task completion info =====")
    logger.info(f"Parameters: exec_date={exec_date} ({type(exec_date).__name__}), target_table={target_table}, script_name={script_name}")
    logger.info(f"Parameters: success={success} ({type(success).__name__}), end_time={end_time}, duration={duration}")
    conn = get_pg_conn()
    cursor = conn.cursor()
    try:
        # First check whether the record exists
        cursor.execute("""
            SELECT COUNT(*)
            FROM airflow_dag_schedule
            WHERE exec_date = %s AND target_table = %s AND script_name = %s
        """, (exec_date, target_table, script_name))
        count = cursor.fetchone()[0]
        logger.info(f"Matching records found: {count}")
        if count == 0:
            logger.warning(f"No matching record: exec_date={exec_date}, target_table={target_table}, script_name={script_name}")
            # Query the first few rows of the table for reference
            cursor.execute("""
                SELECT exec_date, target_table, script_name
                FROM airflow_dag_schedule
                LIMIT 5
            """)
            sample_records = cursor.fetchall()
            logger.info("Sample records in airflow_dag_schedule:")
            for record in sample_records:
                logger.info(f"Sample record: exec_date={record[0]} ({type(record[0]).__name__}), target_table={record[1]}, script_name={record[2]}")
        # Make sure success is a boolean
        if not isinstance(success, bool):
            original_success = success
            success = bool(success)
            logger.warning(f"success parameter was not a boolean; original value: {original_success}, converted to: {success}")
        # Run the update
        sql = """
            UPDATE airflow_dag_schedule
            SET exec_result = %s, exec_end_time = %s, exec_duration = %s
            WHERE exec_date = %s AND target_table = %s AND script_name = %s
        """
        logger.info(f"Executing SQL: {sql}")
        logger.info(f"Parameters: success={success}, end_time={end_time}, duration={duration}, exec_date={exec_date}, target_table={target_table}, script_name={script_name}")
        cursor.execute(sql, (success, end_time, duration, exec_date, target_table, script_name))
        affected_rows = cursor.rowcount
        logger.info(f"Rows affected by the update: {affected_rows}")
        if affected_rows == 0:
            logger.warning("The update did not affect any rows; the conditions probably did not match")
            # Retry with a differently formatted exec_date
            if isinstance(exec_date, str):
                try:
                    # Try to parse the date string
                    parsed_date = datetime.strptime(exec_date, "%Y-%m-%d").date()
                    logger.info(f"Retrying with the parsed date: {parsed_date}")
                    cursor.execute("""
                        SELECT COUNT(*)
                        FROM airflow_dag_schedule
                        WHERE exec_date = %s AND target_table = %s AND script_name = %s
                    """, (parsed_date, target_table, script_name))
                    parsed_count = cursor.fetchone()[0]
                    logger.info(f"Matching records with the parsed date: {parsed_count}")
                    if parsed_count > 0:
                        # Retry the update with the parsed date
                        cursor.execute("""
                            UPDATE airflow_dag_schedule
                            SET exec_result = %s, exec_end_time = %s, exec_duration = %s
                            WHERE exec_date = %s AND target_table = %s AND script_name = %s
                        """, (success, end_time, duration, parsed_date, target_table, script_name))
                        new_affected_rows = cursor.rowcount
                        logger.info(f"Rows affected after retrying with the parsed date: {new_affected_rows}")
                except Exception as parse_e:
                    logger.error(f"Error parsing the date format: {str(parse_e)}")
        conn.commit()
        logger.info("Transaction committed")
    except Exception as e:
        logger.error(f"Failed to update task completion info: {str(e)}")
        import traceback
        logger.error(f"Stack trace: {traceback.format_exc()}")
        conn.rollback()
        logger.info("Transaction rolled back")
        raise
    finally:
        cursor.close()
        conn.close()
        logger.info("Database connection closed")
        logger.info("===== Task completion info update finished =====")
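
# The two update helpers above assume an airflow_dag_schedule table with at least the
# columns referenced in their SQL. A minimal schema sketch inferred from those queries
# (types are guesses, not taken from a migration file):
#
#   CREATE TABLE airflow_dag_schedule (
#       exec_date       date,
#       target_table    varchar,
#       script_name     varchar,
#       exec_start_time timestamp,
#       exec_result     boolean,
#       exec_end_time   timestamp,
#       exec_duration   double precision
#   );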

def execute_with_monitoring(target_table, script_name, script_exec_mode, exec_date, **kwargs):
    """Execute a script and monitor its execution"""
    # Detailed logging
    logger.info("===== Starting monitored execution =====")
    logger.info(f"target_table: {target_table}, type: {type(target_table)}")
    logger.info(f"script_name: {script_name}, type: {type(script_name)}")
    logger.info(f"script_exec_mode: {script_exec_mode}, type: {type(script_exec_mode)}")
    logger.info(f"exec_date: {exec_date}, type: {type(exec_date)}")
    # Check that script_name is not empty
    if not script_name:
        logger.error(f"script_name for table {target_table} is empty; cannot execute")
        # Record the failure
        now = datetime.now()
        update_task_completion(exec_date, target_table, script_name or "", False, now, 0)
        return False
    # Record the execution start time
    start_time = datetime.now()
    # Try to update the start time and log the outcome
    try:
        update_task_start_time(exec_date, target_table, script_name, start_time)
        logger.info(f"Task start time updated: {start_time}")
    except Exception as e:
        logger.error(f"Failed to update task start time: {str(e)}")
    try:
        # Run the actual script
        logger.info(f"Starting script: {script_name}")
        result = execute_script(script_name, target_table, script_exec_mode)
        logger.info(f"Script finished, raw return value: {result}, type: {type(result)}")
        # Make sure result is a boolean
        if result is None:
            logger.warning("Script returned None; converting to False")
            result = False
        elif not isinstance(result, bool):
            original_result = result
            result = bool(result)
            logger.warning(f"Script returned non-boolean {original_result}; converted to boolean: {result}")
        # Record the end time and result
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        # Try to update the completion status and log the outcome
        try:
            logger.info(f"Updating completion status: result={result}, end_time={end_time}, duration={duration}")
            update_task_completion(exec_date, target_table, script_name, result, end_time, duration)
            logger.info(f"Task completion status updated, result: {result}")
        except Exception as e:
            logger.error(f"Failed to update task completion status: {str(e)}")
        logger.info("===== Monitored execution finished =====")
        return result
    except Exception as e:
        # Handle the exception
        logger.error(f"Error executing task: {str(e)}")
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        # Try to record the failure and log the outcome
        try:
            logger.info(f"Updating failure status: end_time={end_time}, duration={duration}")
            update_task_completion(exec_date, target_table, script_name, False, end_time, duration)
            logger.info("Task failure status updated")
        except Exception as update_e:
            logger.error(f"Failed to update task failure status: {str(update_e)}")
        logger.info("===== Monitored execution ended with an exception =====")
        raise e
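
# Usage sketch (illustrative) as an Airflow task callable; the table, script and
# date values are hypothetical (PythonOperator is from airflow.operators.python):
#
#   PythonOperator(
#       task_id="exec_example",
#       python_callable=execute_with_monitoring,
#       op_kwargs={"target_table": "example", "script_name": "load_example.py",
#                  "script_exec_mode": "append", "exec_date": "2024-01-01"},
#       dag=dag,
#   )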

def ensure_boolean_result(func):
    """Decorator: ensure the wrapped function returns a boolean"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
            logger.debug(f"Raw script return value: {result} (type: {type(result).__name__})")
            # Handle None
            if result is None:
                logger.warning(f"Script function {func.__name__} returned None; defaulting to False")
                return False
            # Handle non-boolean values
            if not isinstance(result, bool):
                try:
                    # Attempt the boolean conversion
                    bool_result = bool(result)
                    logger.warning(f"Script function {func.__name__} returned non-boolean {result}; converted to boolean {bool_result}")
                    return bool_result
                except Exception as e:
                    logger.error(f"Could not convert script return value {result} to a boolean: {str(e)}")
                    return False
            return result
        except Exception as e:
            logger.error(f"Script function {func.__name__} raised an error: {str(e)}")
            return False
    return wrapper
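
# Decorator sketch (illustrative): a run() returning a row count is coerced to True,
# with a warning logged about the conversion.
#
#   @ensure_boolean_result
#   def run(table_name, execution_mode):
#       return 42  # non-boolean -> logged and converted to True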

def get_today_date():
    """Return today's date as a YYYY-MM-DD string"""
    return datetime.now().strftime("%Y-%m-%d")

def get_cn_exec_date(logical_date):
    """
    Get the logical execution date in Beijing time.

    Args:
        logical_date: logical execution date in UTC

    Returns:
        exec_date: logical execution date string in Beijing time
        local_logical_date: logical_date converted to the Beijing timezone
    """
    # Convert the logical execution date to Beijing time
    local_logical_date = pendulum.instance(logical_date).in_timezone('Asia/Shanghai')
    exec_date = local_logical_date.strftime('%Y-%m-%d')
    return exec_date, local_logical_date
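
# Worked sketch (illustrative date): a UTC logical date late in the day rolls over
# to the next calendar day in Beijing time (UTC+8):
#
#   import pendulum
#   exec_date, local_dt = get_cn_exec_date(pendulum.datetime(2024, 1, 1, 22, 0, tz="UTC"))
#   # exec_date == "2024-01-02"; local_dt is 2024-01-02 06:00 in Asia/Shanghai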