# dag_manual_trigger_chain_optimized.py (extraction header repaired; original showed "17 KB" and a run of concatenated line numbers)
  1. # dag_manual_trigger_chain_optimized.py
  2. from airflow import DAG
  3. from airflow.operators.python import PythonOperator
  4. from airflow.operators.empty import EmptyOperator
  5. from airflow.models.param import Param
  6. from datetime import datetime, timedelta
  7. import logging
  8. import os
  9. from pathlib import Path
  10. import networkx as nx
  11. from neo4j import GraphDatabase
  12. from config import NEO4J_CONFIG, SCRIPTS_BASE_PATH
  13. # 导入工具函数
  14. from utils import (
  15. get_pg_conn, is_data_model_table, is_data_resource_table,
  16. get_script_name_from_neo4j, execute_script,
  17. get_model_dependency_graph, generate_optimized_execution_order,
  18. build_model_dependency_dag, create_task_dict, build_task_dependencies,
  19. connect_start_and_end_tasks
  20. )
  21. # 设置logger
  22. logger = logging.getLogger(__name__)
# Default task-level arguments applied to every task in this DAG.
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2024, 1, 1),
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}
  33. def get_execution_mode(table_name):
  34. """
  35. 从PostgreSQL获取表的执行模式
  36. 参数:
  37. table_name (str): 表名
  38. 注意:
  39. "AND is_enabled = TRUE" 这个条件在这里不适用,因为这是强制执行的。
  40. 即使订阅表中没有这个表名,也会强制执行。
  41. 返回:
  42. str: 执行模式,如果未找到则返回"append"作为默认值
  43. """
  44. try:
  45. conn = get_pg_conn()
  46. cursor = conn.cursor()
  47. cursor.execute("""
  48. SELECT execution_mode
  49. FROM table_schedule
  50. WHERE table_name = %s
  51. """, (table_name,))
  52. result = cursor.fetchone()
  53. cursor.close()
  54. conn.close()
  55. if result:
  56. return result[0]
  57. else:
  58. logger.warning(f"未找到表 {table_name} 的执行模式,使用默认值 'append'")
  59. return "append"
  60. except Exception as e:
  61. logger.error(f"获取表 {table_name} 的执行模式时出错: {str(e)}")
  62. return "append"
  63. def get_upstream_models(table_name):
  64. """获取表的上游DataModel依赖"""
  65. driver = GraphDatabase.driver(
  66. NEO4J_CONFIG['uri'],
  67. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  68. )
  69. query = """
  70. MATCH (target:DataModel {en_name: $table_name})-[:DERIVED_FROM]->(up:DataModel)
  71. RETURN up.en_name AS upstream
  72. """
  73. try:
  74. with driver.session() as session:
  75. result = session.run(query, table_name=table_name)
  76. upstream_list = [record["upstream"] for record in result]
  77. logger.info(f"表 {table_name} 的上游DataModel依赖: {upstream_list}")
  78. return upstream_list
  79. finally:
  80. driver.close()
  81. def get_upstream_resources(table_name):
  82. """获取表的上游DataResource依赖"""
  83. driver = GraphDatabase.driver(
  84. NEO4J_CONFIG['uri'],
  85. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  86. )
  87. query = """
  88. MATCH (target:DataModel {en_name: $table_name})-[:DERIVED_FROM]->(up:DataResource)
  89. RETURN up.en_name AS upstream
  90. """
  91. try:
  92. with driver.session() as session:
  93. result = session.run(query, table_name=table_name)
  94. upstream_list = [record["upstream"] for record in result]
  95. logger.info(f"表 {table_name} 的上游DataResource依赖: {upstream_list}")
  96. return upstream_list
  97. finally:
  98. driver.close()
  99. def get_data_sources(resource_table_name):
  100. """获取DataResource表的上游DataSource"""
  101. driver = GraphDatabase.driver(
  102. NEO4J_CONFIG['uri'],
  103. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  104. )
  105. query = """
  106. MATCH (dr:DataResource {en_name: $table_name})-[:ORIGINATES_FROM]->(ds:DataSource)
  107. RETURN ds.en_name AS source_name
  108. """
  109. try:
  110. with driver.session() as session:
  111. result = session.run(query, table_name=resource_table_name)
  112. return [record["source_name"] for record in result]
  113. finally:
  114. driver.close()
  115. def build_dependency_chain_with_networkx(start_table, dependency_level='resource'):
  116. """
  117. 使用networkx构建依赖链
  118. 参数:
  119. start_table (str): 起始表名
  120. dependency_level (str): 依赖级别,可选值:
  121. 'self' - 只包含起始表自身
  122. 'resource' - 包含到DataResource层级(默认)
  123. 'source' - 包含到DataSource层级
  124. 返回:
  125. dict: 依赖图 {表名: 表信息字典}
  126. networkx.DiGraph: 表示依赖关系的有向图
  127. """
  128. logger.info(f"使用networkx构建依赖链, 起始表: {start_table}, 依赖级别: {dependency_level}")
  129. # 创建有向图
  130. G = nx.DiGraph()
  131. # 添加起始节点
  132. G.add_node(start_table)
  133. # 记录表类型和脚本信息的字典
  134. table_info = {}
  135. # 只执行起始表自身
  136. if dependency_level == 'self':
  137. # 确定表类型并记录信息
  138. if is_data_model_table(start_table):
  139. script_name = get_script_name_from_neo4j(start_table)
  140. execution_mode = get_execution_mode(start_table)
  141. table_info[start_table] = {
  142. 'table_name': start_table,
  143. 'script_name': script_name,
  144. 'table_type': 'DataModel',
  145. 'execution_mode': execution_mode
  146. }
  147. elif is_data_resource_table(start_table):
  148. script_name = get_script_name_from_neo4j(start_table)
  149. execution_mode = get_execution_mode(start_table)
  150. table_info[start_table] = {
  151. 'table_name': start_table,
  152. 'script_name': script_name,
  153. 'table_type': 'DataResource',
  154. 'execution_mode': execution_mode
  155. }
  156. logger.info(f"依赖级别为'self',只处理起始表: {start_table}")
  157. return table_info, G
  158. # 处理完整依赖链
  159. # 用于检测循环的已访问集合
  160. visited = set()
  161. def add_dependencies(table, level):
  162. """递归添加依赖到图中"""
  163. if table in visited:
  164. return
  165. visited.add(table)
  166. # 确定表类型并记录信息
  167. if is_data_model_table(table):
  168. script_name = get_script_name_from_neo4j(table)
  169. execution_mode = get_execution_mode(start_table)
  170. table_info[table] = {
  171. 'table_name': table,
  172. 'script_name': script_name,
  173. 'table_type': 'DataModel',
  174. 'execution_mode': execution_mode
  175. }
  176. # 添加DataModel上游依赖
  177. upstream_models = get_upstream_models(table)
  178. for upstream in upstream_models:
  179. G.add_node(upstream)
  180. G.add_edge(upstream, table) # 上游指向下游,执行时上游先执行
  181. add_dependencies(upstream, level)
  182. # 添加DataResource上游依赖
  183. upstream_resources = get_upstream_resources(table)
  184. for upstream in upstream_resources:
  185. G.add_node(upstream)
  186. G.add_edge(upstream, table)
  187. add_dependencies(upstream, level)
  188. elif is_data_resource_table(table):
  189. script_name = get_script_name_from_neo4j(table)
  190. execution_mode = get_execution_mode(start_table)
  191. table_info[table] = {
  192. 'table_name': table,
  193. 'script_name': script_name,
  194. 'table_type': 'DataResource',
  195. 'execution_mode': execution_mode
  196. }
  197. # 如果依赖级别为source,则继续查找DataSource
  198. if level == 'source':
  199. data_sources = get_data_sources(table)
  200. for source in data_sources:
  201. G.add_node(source)
  202. G.add_edge(source, table)
  203. table_info[source] = {
  204. 'table_name': source,
  205. 'script_name': None,
  206. 'table_type': 'DataSource',
  207. 'execution_mode': None
  208. }
  209. # 开始递归构建依赖图
  210. add_dependencies(start_table, dependency_level)
  211. # 检测和处理循环依赖
  212. cycles = list(nx.simple_cycles(G))
  213. if cycles:
  214. logger.warning(f"检测到循环依赖: {cycles}")
  215. for cycle in cycles:
  216. # 移除循环中的最后一条边来打破循环
  217. G.remove_edge(cycle[-1], cycle[0])
  218. logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")
  219. return table_info, G
  220. def process_table_chain_manual(target_table, dependency_level, **context):
  221. """
  222. 处理指定表及其依赖链
  223. 参数:
  224. target_table: 目标表名
  225. dependency_level: 依赖级别
  226. """
  227. logger.info(f"开始处理表 {target_table} 的依赖链,依赖级别: {dependency_level}")
  228. # 构建依赖链
  229. table_info, dependency_graph = build_dependency_chain_with_networkx(target_table, dependency_level)
  230. if not table_info:
  231. logger.warning(f"没有找到表 {target_table} 的依赖信息")
  232. return
  233. # 使用networkx生成拓扑排序
  234. try:
  235. execution_order = list(nx.topological_sort(dependency_graph))
  236. logger.info(f"生成拓扑排序结果: {execution_order}")
  237. except nx.NetworkXUnfeasible:
  238. logger.error("无法生成拓扑排序,图可能仍然包含循环")
  239. execution_order = [target_table] # 至少包含目标表
  240. # 过滤掉DataSource类型的表
  241. model_tables = []
  242. for table_name in execution_order:
  243. if table_name in table_info and table_info[table_name]['table_type'] != 'DataSource':
  244. model_tables.append({
  245. 'table_name': table_name,
  246. 'script_name': table_info[table_name]['script_name'],
  247. 'execution_mode': table_info[table_name]['execution_mode'],
  248. 'table_type': table_info[table_name]['table_type']
  249. })
  250. # 按顺序处理表
  251. processed_count = 0
  252. failed_tables = []
  253. for table_info in model_tables:
  254. table_name = table_info['table_name']
  255. script_name = table_info['script_name']
  256. execution_mode = table_info['execution_mode']
  257. logger.info(f"处理表: {table_name} ({processed_count + 1}/{len(model_tables)})")
  258. result = execute_script(script_name, table_name, execution_mode)
  259. if result:
  260. processed_count += 1
  261. logger.info(f"表 {table_name} 处理成功")
  262. else:
  263. failed_tables.append(table_name)
  264. logger.error(f"表 {table_name} 处理失败")
  265. # 是否要中断处理?取决于您的要求
  266. # break
  267. # 处理结果
  268. logger.info(f"依赖链处理完成: 成功 {processed_count} 个表, 失败 {len(failed_tables)} 个表")
  269. if failed_tables:
  270. logger.warning(f"失败的表: {', '.join(failed_tables)}")
  271. return False
  272. return True
  273. def run_model_script(table_name, execution_mode):
  274. """
  275. 运行模型脚本的封装函数,用于Airflow任务
  276. 参数:
  277. table_name: 表名
  278. execution_mode: 执行模式
  279. """
  280. script_name = get_script_name_from_neo4j(table_name)
  281. return execute_script(script_name, table_name, execution_mode)
  282. # 创建DAG
  283. with DAG(
  284. 'dag_manual_trigger_chain_optimized',
  285. default_args=default_args,
  286. description='手动触发指定表的依赖链执行(串行任务链)',
  287. schedule_interval=None, # 设置为None表示只能手动触发
  288. catchup=False,
  289. is_paused_upon_creation=False, # DAG创建时不处于暂停状态
  290. params={
  291. 'TABLE_NAME': Param('', type='string', description='目标表名称'),
  292. 'DEPENDENCY_LEVEL': Param('resource', type='string', enum=['self', 'resource', 'source'], description='依赖级别: self-仅本表, resource-到Resource层, source-到Source层')
  293. },
  294. ) as dag:
  295. # 起始任务
  296. start_task = EmptyOperator(
  297. task_id='start',
  298. dag=dag,
  299. )
  300. # 结束任务
  301. end_task = EmptyOperator(
  302. task_id='end',
  303. dag=dag,
  304. )
  305. # 分析依赖的任务
  306. def analyze_dependencies(**context):
  307. """分析依赖链,准备表信息和执行顺序"""
  308. # 获取参数
  309. params = context['params']
  310. target_table = params.get('TABLE_NAME')
  311. dependency_level = params.get('DEPENDENCY_LEVEL', 'resource')
  312. if not target_table:
  313. raise ValueError("必须提供TABLE_NAME参数")
  314. # 验证依赖级别参数
  315. valid_levels = ['self', 'resource', 'source']
  316. if dependency_level not in valid_levels:
  317. logger.warning(f"无效的依赖级别: {dependency_level},使用默认值 'resource'")
  318. dependency_level = 'resource'
  319. logger.info(f"开始分析表 {target_table} 的依赖链, 依赖级别: {dependency_level}")
  320. # 构建依赖链
  321. table_info, dependency_graph = build_dependency_chain_with_networkx(target_table, dependency_level)
  322. if not table_info:
  323. logger.warning(f"没有找到表 {target_table} 的依赖信息")
  324. return []
  325. # 使用networkx生成拓扑排序
  326. try:
  327. execution_order = list(nx.topological_sort(dependency_graph))
  328. logger.info(f"生成拓扑排序结果: {execution_order}")
  329. except nx.NetworkXUnfeasible:
  330. logger.error("无法生成拓扑排序,图可能仍然包含循环")
  331. execution_order = [target_table] # 至少包含目标表
  332. # 准备表信息
  333. model_tables = []
  334. for table_name in execution_order:
  335. if table_name in table_info and table_info[table_name]['table_type'] != 'DataSource':
  336. model_tables.append({
  337. 'table_name': table_name,
  338. 'script_name': table_info[table_name]['script_name'],
  339. 'execution_mode': table_info[table_name]['execution_mode']
  340. })
  341. # 将结果保存到XCom
  342. ti = context['ti']
  343. ti.xcom_push(key='model_tables', value=model_tables)
  344. ti.xcom_push(key='dependency_graph', value={k: list(v) for k, v in dependency_graph.items()})
  345. return model_tables
  346. # 创建分析任务
  347. analyze_task = PythonOperator(
  348. task_id='analyze_dependencies',
  349. python_callable=analyze_dependencies,
  350. provide_context=True,
  351. dag=dag,
  352. )
  353. # 创建构建任务链的任务
  354. def create_dynamic_task_chain(**context):
  355. """创建动态任务链"""
  356. ti = context['ti']
  357. model_tables = ti.xcom_pull(task_ids='analyze_dependencies', key='model_tables')
  358. dependency_data = ti.xcom_pull(task_ids='analyze_dependencies', key='dependency_graph')
  359. if not model_tables:
  360. logger.warning("没有找到需要处理的表")
  361. return 'end'
  362. # 重建依赖图为networkx格式
  363. dependency_graph = nx.DiGraph()
  364. for target, upstreams in dependency_data.items():
  365. for upstream in upstreams:
  366. dependency_graph.add_edge(upstream, target)
  367. # 提取表名列表
  368. table_names = [t['table_name'] for t in model_tables]
  369. # 创建任务字典
  370. task_dict = {}
  371. for table_info in model_tables:
  372. table_name = table_info['table_name']
  373. task_id = f"process_{table_name}"
  374. # 创建处理任务
  375. task = PythonOperator(
  376. task_id=task_id,
  377. python_callable=run_model_script,
  378. op_kwargs={
  379. 'table_name': table_name,
  380. 'execution_mode': table_info['execution_mode']
  381. },
  382. dag=dag,
  383. )
  384. task_dict[table_name] = task
  385. # 设置任务间依赖关系
  386. for i, table_name in enumerate(table_names):
  387. if i > 0:
  388. prev_table = table_names[i-1]
  389. task_dict[prev_table] >> task_dict[table_name]
  390. logger.info(f"设置任务依赖: {prev_table} >> {table_name}")
  391. # 连接第一个和最后一个任务
  392. if table_names:
  393. first_task = task_dict[table_names[0]]
  394. last_task = task_dict[table_names[-1]]
  395. analyze_task >> first_task
  396. last_task >> end_task
  397. else:
  398. # 如果没有表需要处理,直接连接到结束任务
  399. analyze_task >> end_task
  400. return 'end'
  401. # 创建构建任务链的任务
  402. build_task = PythonOperator(
  403. task_id='build_task_chain',
  404. python_callable=create_dynamic_task_chain,
  405. provide_context=True,
  406. dag=dag,
  407. )
  408. # 设置任务链
  409. start_task >> analyze_task >> build_task