# dag_dependency_analysis.py
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.models.param import Param
from datetime import datetime, timedelta
import logging
import os
from pathlib import Path
import networkx as nx
from neo4j import GraphDatabase
from config import NEO4J_CONFIG, SCRIPTS_BASE_PATH
# Import utility functions
from utils import (
    get_pg_conn, is_data_model_table, is_data_resource_table,
    get_script_name_from_neo4j, execute_script,
    check_script_exists, run_model_script
)

# Module-level logger
logger = logging.getLogger(__name__)

# Default DAG arguments
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2024, 1, 1),
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

def get_execution_mode(table_name):
    """
    Fetch a table's execution mode from PostgreSQL.

    Args:
        table_name (str): table name

    Returns:
        str: the execution mode, or "append" as the default if none is found
    """
    try:
        conn = get_pg_conn()
        cursor = conn.cursor()
        cursor.execute("""
            SELECT execution_mode
            FROM table_schedule
            WHERE table_name = %s
        """, (table_name,))
        result = cursor.fetchone()
        cursor.close()
        conn.close()
        if result:
            return result[0]
        else:
            logger.warning(f"No execution mode found for table {table_name}; using default 'append'")
            return "append"
    except Exception as e:
        logger.error(f"Error fetching execution mode for table {table_name}: {str(e)}")
        return "append"
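
# Illustrative rows for the table_schedule lookup above (the schema is assumed,
# not defined in this file):
#   table_name        | execution_mode
#   ------------------+---------------
#   ods_orders        | append
#   dm_sales_summary  | full_refresh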

def get_table_metadata(table_name):
    """
    Fetch a table's metadata.

    Args:
        table_name (str): table name

    Returns:
        dict: the table's metadata
    """
    driver = GraphDatabase.driver(
        NEO4J_CONFIG['uri'],
        auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    )
    metadata = {
        'table_name': table_name,
        'type': None,
        'script_name': None,
        'execution_mode': get_execution_mode(table_name)
    }
    try:
        # Determine the table type
        if is_data_model_table(table_name):
            metadata['type'] = 'DataModel'
        elif is_data_resource_table(table_name):
            metadata['type'] = 'DataResource'
        else:
            # Check whether it is a DataSource node
            with driver.session() as session:
                query = """
                    MATCH (ds:DataSource {en_name: $table_name})
                    RETURN count(ds) > 0 AS exists
                """
                result = session.run(query, table_name=table_name)
                record = result.single()
                if record and record['exists']:
                    metadata['type'] = 'DataSource'
        # Look up the script name for model and resource tables
        if metadata['type'] in ['DataModel', 'DataResource']:
            metadata['script_name'] = get_script_name_from_neo4j(table_name)
        return metadata
    finally:
        driver.close()
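
# Example of the metadata dict returned for a model table (illustrative values):
#   {'table_name': 'dm_sales_summary', 'type': 'DataModel',
#    'script_name': 'dm_sales_summary.py', 'execution_mode': 'append'}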

def get_upstream_tables(table_name, dependency_level):
    """
    Fetch a table's upstream dependencies.

    Args:
        table_name (str): table name
        dependency_level (str): dependency level (self/resource/source)

    Returns:
        list: upstream table entries (dicts with 'table_name' and 'type')
    """
    # The 'self' level needs no upstream lookup
    if dependency_level == 'self':
        return []
    driver = GraphDatabase.driver(
        NEO4J_CONFIG['uri'],
        auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    )
    upstream_tables = []
    try:
        with driver.session() as session:
            # Build a different query depending on the dependency level
            if dependency_level == 'resource':
                # Only upstream DataModel and DataResource tables
                query = """
                    MATCH (target {en_name: $table_name})-[:DERIVED_FROM]->(up)
                    WHERE up:DataModel OR up:DataResource
                    RETURN up.en_name AS upstream, labels(up) AS types
                """
            else:  # 'source' level
                # All upstream tables, including DataSource
                query = """
                    MATCH (target {en_name: $table_name})-[:DERIVED_FROM]->(up)
                    RETURN up.en_name AS upstream, labels(up) AS types
                """
            result = session.run(query, table_name=table_name)
            for record in result:
                upstream_tables.append({
                    'table_name': record['upstream'],
                    'type': record['types'][0] if record['types'] else 'Unknown'
                })
        return upstream_tables
    finally:
        driver.close()
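
# Note on the Cypher above: DERIVED_FROM points from a table to the tables it
# is derived from, so (target)-[:DERIVED_FROM]->(up) walks exactly one hop
# upstream; build_dependency_graph below recurses to cover the full path.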

def build_dependency_graph(start_table, dependency_level):
    """
    Build the dependency graph.

    Args:
        start_table (str): starting table name
        dependency_level (str): dependency level (self/resource/source)

    Returns:
        tuple: (table info dict, dependency graph)
    """
    logger.info(f"Building dependency graph for {start_table}, dependency level: {dependency_level}")
    # Create a directed graph
    G = nx.DiGraph()
    # Add the starting node
    G.add_node(start_table)
    # Dict recording per-table metadata
    table_info = {}
    # Fetch the starting table's metadata
    table_metadata = get_table_metadata(start_table)
    table_info[start_table] = table_metadata
    # For the 'self' level, return only the starting table
    if dependency_level == 'self':
        logger.info(f"Dependency level is 'self'; including only the starting table: {start_table}")
        return table_info, G
    # Track visited tables to avoid looping on cycles
    visited = set()

    def add_dependencies(table_name):
        """Recursively add dependencies to the graph."""
        if table_name in visited:
            return
        visited.add(table_name)
        # Fetch upstream dependencies
        upstream_tables = get_upstream_tables(table_name, dependency_level)
        for upstream in upstream_tables:
            up_table_name = upstream['table_name']
            # Add node and edge; upstream points to downstream, so upstream runs first
            G.add_node(up_table_name)
            G.add_edge(up_table_name, table_name)
            # Recurse into upstream tables we have not seen yet
            if up_table_name not in table_info:
                up_metadata = get_table_metadata(up_table_name)
                table_info[up_table_name] = up_metadata
                # At the 'resource' level, do not walk past DataSource nodes
                if dependency_level == 'resource' and up_metadata['type'] == 'DataSource':
                    continue
                add_dependencies(up_table_name)

    # Recursively build the dependency graph
    add_dependencies(start_table)
    # Detect and break cyclic dependencies
    cycles = list(nx.simple_cycles(G))
    if cycles:
        logger.warning(f"Cyclic dependencies detected: {cycles}")
        for cycle in cycles:
            # Break the cycle by removing its closing edge; guard against edges
            # already removed while breaking an overlapping cycle
            if G.has_edge(cycle[-1], cycle[0]):
                G.remove_edge(cycle[-1], cycle[0])
                logger.info(f"Broke cyclic dependency: removed edge {cycle[-1]} -> {cycle[0]}")
    return table_info, G
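
# Because every edge points upstream -> downstream, a topological sort of the
# graph yields a valid execution order: each table appears only after all of
# the tables it depends on.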

def optimize_execution_order(dependency_graph):
    """
    Derive the execution order.

    Args:
        dependency_graph: NetworkX dependency graph

    Returns:
        list: the ordered list of tables to execute
    """
    # Use a topological sort to generate the execution order
    try:
        execution_order = list(nx.topological_sort(dependency_graph))
        logger.info(f"Topological sort produced: {execution_order}")
        return execution_order
    except nx.NetworkXUnfeasible:
        logger.error("Topological sort failed; the graph may still contain cycles")
        # Fall back to the graph's node list
        return list(dependency_graph.nodes())

def analyze_and_prepare_dag(**context):
    """
    Analyze dependencies and prepare the DAG structure without executing any scripts.
    """
    # Read parameters
    params = context['params']
    target_table = params.get('TABLE_NAME')
    dependency_level = params.get('DEPENDENCY_LEVEL', 'resource')
    if not target_table:
        raise ValueError("The TABLE_NAME parameter is required")
    logger.info(f"Analyzing dependencies for table {target_table}, dependency level: {dependency_level}")
    # Build the dependency graph
    table_info, dependency_graph = build_dependency_graph(target_table, dependency_level)
    if not table_info:
        logger.warning(f"No dependency information found for table {target_table}")
        return {}
    # Derive the execution order
    execution_order = optimize_execution_order(dependency_graph)
    # Filter out tables that have no script
    executable_tables = [
        table_name for table_name in execution_order
        if table_name in table_info and table_info[table_name]['script_name']
    ]
    logger.info(f"Tables to execute: {executable_tables}")
    # Return the execution plan: per-table info plus the execution order
    execution_plan = {
        'executable_tables': executable_tables,
        'table_info': {k: v for k, v in table_info.items() if k in executable_tables},
        'dependencies': {
            k: list(dependency_graph.predecessors(k))
            for k in executable_tables
        }
    }
    return execution_plan
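
# Shape of the execution plan pushed to XCom (illustrative table names):
#   {'executable_tables': ['ods_orders', 'dm_sales_summary'],
#    'table_info': {'ods_orders': {...}, 'dm_sales_summary': {...}},
#    'dependencies': {'ods_orders': [], 'dm_sales_summary': ['ods_orders']}}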

# Create the DAG
with DAG(
    'dag_dependency_analysis',
    default_args=default_args,
    description='Analyze a table dependency path and run the related scripts',
    schedule_interval=None,  # None: the DAG can only be triggered manually
    catchup=False,
    is_paused_upon_creation=False,
    params={
        'TABLE_NAME': Param('', type='string', description='Target table name'),
        'DEPENDENCY_LEVEL': Param('resource', type='string', enum=['self', 'resource', 'source'],
                                  description='Dependency level: self - this table only, '
                                              'resource - up to the Resource layer, '
                                              'source - up to the Source layer')
    },
) as dag:
    # Task that analyzes dependencies
    analyze_task = PythonOperator(
        task_id='analyze_dependencies',
        python_callable=analyze_and_prepare_dag,
    )

    # Dynamically determine the list of tasks to execute
    def determine_and_create_tasks(**context):
        """
        Determine which tasks to run from the analysis result and create them dynamically.
        """
        # Pull the output of the analyze_dependencies task
        ti = context['ti']
        execution_plan = ti.xcom_pull(task_ids='analyze_dependencies')
        if not execution_plan or 'executable_tables' not in execution_plan:
            logger.warning("No execution plan received; cannot create tasks")
            return None
        executable_tables = execution_plan.get('executable_tables', [])
        table_info = execution_plan.get('table_info', {})
        dependencies = execution_plan.get('dependencies', {})
        if not executable_tables:
            logger.warning("No tables to execute")
            return None
        # Log the execution plan
        logger.info(f"Tables to execute: {executable_tables}")
        for table_name in executable_tables:
            logger.info(f"Info for table {table_name}: {table_info.get(table_name, {})}")
            logger.info(f"Dependencies of table {table_name}: {dependencies.get(table_name, [])}")
        # Create a task for each table that needs to run
        for table_name in executable_tables:
            table_data = table_info.get(table_name, {})
            execution_mode = table_data.get('execution_mode', 'append')
            # Create the processing task
            task = PythonOperator(
                task_id=f'process_{table_name}',
                python_callable=run_model_script,
                op_kwargs={
                    'table_name': table_name,
                    'execution_mode': execution_mode
                },
                dag=dag,
            )
            # Wire up dependencies: upstream tables this table depends on,
            # restricted to tables that are themselves executable
            upstream_tables = dependencies.get(table_name, [])
            upstream_tables = [t for t in upstream_tables if t in executable_tables]
            for upstream in upstream_tables:
                # Look up the upstream task if it was already created;
                # dag.get_task raises TaskNotFound instead of returning None,
                # so check with dag.has_task first
                if dag.has_task(f'process_{upstream}'):
                    upstream_task = dag.get_task(f'process_{upstream}')
                    # Set the dependency: upstream task >> current task
                    upstream_task >> task
                    logger.info(f"Set task dependency: process_{upstream} >> process_{table_name}")
            # Tables with no upstream depend directly on the analysis task
            if not upstream_tables:
                analyze_task >> task
                logger.info(f"Set task dependency: analyze_dependencies >> process_{table_name}")
        # Find entry tables (tables with no dependencies)
        entry_tables = [
            table for table in executable_tables
            if not dependencies.get(table, [])
        ]
        # Return an entry task id if one exists
        if entry_tables:
            return f'process_{entry_tables[0]}'
        else:
            # No clear entry task; fall back to the first table
            return f'process_{executable_tables[0]}'
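
    # Caveat: Airflow registers tasks when it parses this file; operators
    # instantiated while a task is already running (as in
    # determine_and_create_tasks above) are not picked up by the scheduler and
    # will not actually run. Airflow 2.3+ provides dynamic task mapping, e.g.
    # PythonOperator.partial(...).expand(op_kwargs=...), as the supported way
    # to fan out tasks from runtime data.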
    # Task that runs determine_and_create_tasks (a plain PythonOperator, not a
    # BranchPythonOperator: its return value is pushed to XCom but does not branch)
    branch_task = PythonOperator(
        task_id='branch_and_create_tasks',
        python_callable=determine_and_create_tasks,
    )
    # Base task flow
    analyze_task >> branch_task
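
# Manual trigger example (illustrative table name; assumes the standard Airflow CLI):
#   airflow dags trigger dag_dependency_analysis \
#     --conf '{"TABLE_NAME": "dm_sales_summary", "DEPENDENCY_LEVEL": "source"}'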