dag_manual_unified_dependency_trigger.py 5.2 KB

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.decorators import task
from datetime import datetime, timedelta
from config import NEO4J_CONFIG
from utils import execute_script
import logging
from neo4j import GraphDatabase
import networkx as nx

logger = logging.getLogger(__name__)
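# Assumed external interfaces, inferred from how they are used below (they are
# project-local modules not shown in this file):
#   NEO4J_CONFIG: dict with 'uri', 'user' and 'password' keys
#   execute_script(script_name, table_name, execution_mode) -> bool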
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}
def build_dependency_chain_nx(start_table, dependency_level="resource"):
    """Query Neo4j for start_table and all of its upstream tables, then
    return their metadata in topological (execution) order."""
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    logger.info(f"Building dependency chain for table {start_table} (level: {dependency_level})")

    G = nx.DiGraph()
    node_info = {}
    try:
        with driver.session() as session:
            # Walk DERIVED_FROM edges upstream from the target table, then emit
            # one row per (table, upstream) pair with its script metadata.
            # start_table is bound as a query parameter rather than interpolated
            # into the string, which avoids Cypher injection.
            query = """
            MATCH path = (target:Table {en_name: $start_table})<-[:DERIVED_FROM*0..]-(source)
            WHERE all(n IN nodes(path) WHERE n:DataModel OR n:DataResource)
            WITH collect(DISTINCT source.en_name) + $start_table AS tables
            UNWIND tables AS tname
            MATCH (t:Table {en_name: tname})
            OPTIONAL MATCH (t)-[r:DERIVED_FROM]->(up)
            RETURN t.en_name AS table_name,
                   labels(t) AS labels,
                   r.script_name AS script_name,
                   r.script_exec_mode AS script_exec_mode,
                   up.en_name AS upstream
            """
            records = session.run(query, start_table=start_table)
            for record in records:
                name = record['table_name']
                script_name = record.get('script_name')
                script_exec_mode = record.get('script_exec_mode') or 'append'
                upstream = record.get('upstream')
                node_type = None
                for label in record.get('labels', []):
                    if label in ('DataModel', 'DataResource'):
                        node_type = label
                if name not in node_info:
                    # Register the node explicitly so tables without upstreams
                    # (including start_table itself) still appear in the sort.
                    G.add_node(name)
                    node_info[name] = {
                        'table_name': name,
                        'script_name': script_name or f"{name}.py",
                        'execution_mode': script_exec_mode,
                        'table_type': node_type,
                        'upstream_tables': [],
                    }
                if upstream:
                    G.add_edge(upstream, name)
                    node_info[name]['upstream_tables'].append(upstream)
    finally:
        driver.close()

    # topological_sort raises NetworkXUnfeasible if the graph contains a cycle,
    # which would indicate inconsistent DERIVED_FROM data in Neo4j.
    execution_order = list(nx.topological_sort(G))
    logger.info(f"Topological execution order: {execution_order}")

    dependency_chain = [node_info[table] for table in execution_order if table in node_info]
    return dependency_chain
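# Each entry in the returned chain has the shape:
#   {'table_name': ..., 'script_name': ..., 'execution_mode': ...,
#    'table_type': 'DataModel' or 'DataResource', 'upstream_tables': [...]}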
with DAG(
    dag_id='dag_manual_dependency_unified_trigger',
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,  # manual trigger only
    catchup=False,
    default_args=default_args,
    description='Builds tasks at runtime from dag_run.conf and wires them by topological dependency'
) as dag:
    start = EmptyOperator(task_id='start')
    end = EmptyOperator(task_id='end')
    @task()
    def get_dependency_chain(**context):
        """Read TABLE_NAME / DEPENDENCY_LEVEL from dag_run.conf and build the chain."""
        conf = context['dag_run'].conf if context.get('dag_run') else {}
        table_name = conf.get("TABLE_NAME", "book_sale_amt_2yearly")
        dependency_level = conf.get("DEPENDENCY_LEVEL", "resource")
        logger.info(f"Manual trigger parameters: TABLE_NAME={table_name}, DEPENDENCY_LEVEL={dependency_level}")
        return build_dependency_chain_nx(table_name, dependency_level)
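    # Example manual trigger via the Airflow CLI (the conf values shown here are
    # just the defaults above):
    #   airflow dags trigger dag_manual_dependency_unified_trigger \
    #     --conf '{"TABLE_NAME": "book_sale_amt_2yearly", "DEPENDENCY_LEVEL": "resource"}'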
    def create_task_callable(table_name, script_name, execution_mode):
        """Bind the per-table arguments into a closure for PythonOperator."""
        def _inner_callable():
            logger.info(f"Executing task: {table_name} using {script_name} mode={execution_mode}")
            if not execute_script(script_name, table_name, execution_mode):
                raise Exception(f"Script {script_name} failed")
        return _inner_callable
    def create_runtime_tasks(chain, dag):
        """Create one PythonOperator per table and wire edges from upstream_tables."""
        task_dict = {}
        for item in chain:
            table = item['table_name']
            script = item['script_name']
            mode = item['execution_mode']
            # Note: named `t` locally to avoid shadowing the `task` decorator import.
            task_dict[table] = PythonOperator(
                task_id=f"run_{table}",
                python_callable=create_task_callable(table, script, mode),
                dag=dag,
            )
        for item in chain:
            downstream = item['table_name']
            upstreams = item.get('upstream_tables', [])
            if not upstreams:
                # Root tables hang directly off the start marker.
                start >> task_dict[downstream]
            else:
                for up in upstreams:
                    if up in task_dict:
                        task_dict[up] >> task_dict[downstream]
        for t in task_dict.values():
            t >> end
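    # NOTE: operators instantiated while another task is already executing are
    # not registered with the Airflow scheduler, so this runtime-construction
    # pattern mainly illustrates the topology. For scheduler-visible dynamic
    # tasks on Airflow 2.3+, consider dynamic task mapping (.expand()) instead.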
    def wrapper(**context):
        """Pull the chain from XCom and materialize the runtime tasks."""
        chain = context['ti'].xcom_pull(task_ids='get_dependency_chain')
        create_runtime_tasks(chain, dag)

    chain_task = get_dependency_chain()
    build_tasks = PythonOperator(
        task_id='build_runtime_tasks',
        python_callable=wrapper,
        # provide_context was removed in Airflow 2.x; the context kwargs
        # are injected automatically.
    )

    start >> chain_task >> build_tasks >> end
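
# Local debug sketch (assumes Airflow 2.5+, where DAG.test() is available;
# runs the DAG in-process without the scheduler):
# if __name__ == "__main__":
#     dag.test(run_conf={"TABLE_NAME": "book_sale_amt_2yearly"})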