# dag_manual_trigger_chain.py — manual-trigger dependency-chain DAG
# dag_manual_trigger_chain_two_level.py
import importlib.util
import logging
import os
from datetime import datetime, timedelta
from pathlib import Path

import psycopg2
from airflow import DAG
from airflow.operators.python import PythonOperator
from neo4j import GraphDatabase

from config import NEO4J_CONFIG, PG_CONFIG, SCRIPTS_BASE_PATH
  12. # 设置logger
  13. logger = logging.getLogger(__name__)
  14. # DAG参数
  15. default_args = {
  16. 'owner': 'airflow',
  17. 'depends_on_past': False,
  18. 'start_date': datetime(2024, 1, 1),
  19. 'email_on_failure': False,
  20. 'email_on_retry': False,
  21. 'retries': 1,
  22. 'retry_delay': timedelta(minutes=5),
  23. }
  24. def get_pg_conn():
  25. """获取PostgreSQL连接"""
  26. return psycopg2.connect(**PG_CONFIG)
  27. def get_execution_mode(table_name):
  28. """
  29. 从PostgreSQL获取表的执行模式
  30. 参数:
  31. table_name (str): 表名
  32. 注意:
  33. "AND is_enabled = TRUE" 这个条件在这里不适用,因为这是强制执行的。
  34. 即使订阅表中没有这个表名,也会强制执行。
  35. 返回:
  36. str: 执行模式,如果未找到则返回"append"作为默认值
  37. """
  38. try:
  39. conn = get_pg_conn()
  40. cursor = conn.cursor()
  41. cursor.execute("""
  42. SELECT execution_mode
  43. FROM table_schedule
  44. WHERE table_name = %s
  45. """, (table_name,))
  46. result = cursor.fetchone()
  47. cursor.close()
  48. conn.close()
  49. if result:
  50. return result[0]
  51. else:
  52. logger.warning(f"未找到表 {table_name} 的执行模式,使用默认值 'append'")
  53. return "append"
  54. except Exception as e:
  55. logger.error(f"获取表 {table_name} 的执行模式时出错: {str(e)}")
  56. return "append"
  57. def get_dag_params(**context):
  58. """获取DAG运行参数"""
  59. params = context.get('params', {})
  60. table_name = params.get('TABLE_NAME')
  61. upper_level_stop = params.get('UPPER_LEVEL_STOP', 'DataResource') # 默认值为DataResource
  62. if not table_name:
  63. raise ValueError("必须提供TABLE_NAME参数")
  64. logger.info(f"开始处理表: {table_name}, 上游停止级别: {upper_level_stop}")
  65. return table_name, upper_level_stop
  66. def is_data_model_table(table_name):
  67. """判断表是否为DataModel类型"""
  68. driver = GraphDatabase.driver(
  69. NEO4J_CONFIG['uri'],
  70. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  71. )
  72. query = """
  73. MATCH (n:DataModel {en_name: $table_name}) RETURN count(n) > 0 AS exists
  74. """
  75. try:
  76. with driver.session() as session:
  77. result = session.run(query, table_name=table_name)
  78. record = result.single()
  79. return record and record["exists"]
  80. finally:
  81. driver.close()
  82. def is_data_resource_table(table_name):
  83. """判断表是否为DataResource类型"""
  84. driver = GraphDatabase.driver(
  85. NEO4J_CONFIG['uri'],
  86. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  87. )
  88. query = """
  89. MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  90. """
  91. try:
  92. with driver.session() as session:
  93. result = session.run(query, table_name=table_name)
  94. record = result.single()
  95. return record and record["exists"]
  96. finally:
  97. driver.close()
  98. def get_upstream_models(table_name):
  99. """获取表的上游DataModel依赖"""
  100. driver = GraphDatabase.driver(
  101. NEO4J_CONFIG['uri'],
  102. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  103. )
  104. query = """
  105. MATCH (target:DataModel {en_name: $table_name})-[:DERIVED_FROM]->(up:DataModel)
  106. RETURN up.en_name AS upstream
  107. """
  108. try:
  109. with driver.session() as session:
  110. result = session.run(query, table_name=table_name)
  111. upstream_list = [record["upstream"] for record in result]
  112. logger.info(f"表 {table_name} 的上游DataModel依赖: {upstream_list}")
  113. return upstream_list
  114. finally:
  115. driver.close()
  116. def get_upstream_resources(table_name):
  117. """获取表的上游DataResource依赖"""
  118. driver = GraphDatabase.driver(
  119. NEO4J_CONFIG['uri'],
  120. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  121. )
  122. query = """
  123. MATCH (target:DataModel {en_name: $table_name})-[:DERIVED_FROM]->(up:DataResource)
  124. RETURN up.en_name AS upstream
  125. """
  126. try:
  127. with driver.session() as session:
  128. result = session.run(query, table_name=table_name)
  129. upstream_list = [record["upstream"] for record in result]
  130. logger.info(f"表 {table_name} 的上游DataResource依赖: {upstream_list}")
  131. return upstream_list
  132. finally:
  133. driver.close()
  134. def get_data_sources(resource_table_name):
  135. """获取DataResource表的上游DataSource"""
  136. driver = GraphDatabase.driver(
  137. NEO4J_CONFIG['uri'],
  138. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  139. )
  140. query = """
  141. MATCH (dr:DataResource {en_name: $table_name})-[:ORIGINATES_FROM]->(ds:DataSource)
  142. RETURN ds.en_name AS source_name
  143. """
  144. try:
  145. with driver.session() as session:
  146. result = session.run(query, table_name=resource_table_name)
  147. return [record["source_name"] for record in result]
  148. finally:
  149. driver.close()
  150. def get_script_name_for_model(table_name):
  151. """获取DataModel表对应的脚本名称"""
  152. driver = GraphDatabase.driver(
  153. NEO4J_CONFIG['uri'],
  154. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  155. )
  156. query = """
  157. MATCH (target:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(n)
  158. WHERE n:DataModel OR n:DataResource
  159. RETURN r.script_name AS script_name
  160. """
  161. try:
  162. with driver.session() as session:
  163. result = session.run(query, table_name=table_name)
  164. record = result.single()
  165. if record:
  166. return record["script_name"]
  167. else:
  168. logger.warning(f"未找到DataModel表 {table_name} 的脚本名称")
  169. return None
  170. except Exception as e:
  171. logger.error(f"查询表 {table_name} 的脚本名称时出错: {str(e)}")
  172. return None
  173. finally:
  174. driver.close()
  175. def get_script_name_for_resource(table_name):
  176. """获取DataResource表对应的脚本名称"""
  177. driver = GraphDatabase.driver(
  178. NEO4J_CONFIG['uri'],
  179. auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  180. )
  181. query = """
  182. MATCH (dr:DataResource {en_name: $table_name})-[rel:ORIGINATES_FROM]->(ds:DataSource)
  183. RETURN rel.script_name AS script_name
  184. """
  185. try:
  186. with driver.session() as session:
  187. result = session.run(query, table_name=table_name)
  188. record = result.single()
  189. if record:
  190. return record["script_name"]
  191. else:
  192. logger.warning(f"未找到DataResource表 {table_name} 的脚本名称")
  193. return None
  194. except Exception as e:
  195. logger.error(f"查询表 {table_name} 的脚本名称时出错: {str(e)}")
  196. return None
  197. finally:
  198. driver.close()
  199. def build_dependency_chain(start_table, upper_level_stop='DataResource', visited=None):
  200. """
  201. 递归构建依赖链
  202. 参数:
  203. start_table (str): 起始表名
  204. upper_level_stop (str): 上游停止级别
  205. visited (set): 已访问的表集合,避免循环依赖
  206. 返回:
  207. list: 依赖链列表,按执行顺序排序(从上游到下游)
  208. """
  209. if visited is None:
  210. visited = set()
  211. if start_table in visited:
  212. return []
  213. visited.add(start_table)
  214. dependency_chain = []
  215. # 判断表类型
  216. if is_data_model_table(start_table):
  217. # 处理DataModel表
  218. script_name = get_script_name_for_model(start_table)
  219. execution_mode = get_execution_mode(start_table)
  220. # 获取上游DataModel
  221. upstream_models = get_upstream_models(start_table)
  222. for upstream in upstream_models:
  223. # 将上游依赖添加到链条前面
  224. upstream_chain = build_dependency_chain(upstream, upper_level_stop, visited)
  225. dependency_chain.extend(upstream_chain)
  226. # 获取上游DataResource
  227. upstream_resources = get_upstream_resources(start_table)
  228. for upstream in upstream_resources:
  229. # 将上游依赖添加到链条前面
  230. upstream_chain = build_dependency_chain(upstream, upper_level_stop, visited)
  231. dependency_chain.extend(upstream_chain)
  232. # 当前表添加到链条末尾
  233. dependency_chain.append({
  234. 'table_name': start_table,
  235. 'script_name': script_name,
  236. 'table_type': 'DataModel',
  237. 'execution_mode': execution_mode
  238. })
  239. elif is_data_resource_table(start_table):
  240. # 处理DataResource表
  241. script_name = get_script_name_for_resource(start_table)
  242. execution_mode = get_execution_mode(start_table)
  243. # 如果上游停止级别为DataSource,则继续查找DataSource并先添加
  244. if upper_level_stop == 'DataSource':
  245. data_sources = get_data_sources(start_table)
  246. for source in data_sources:
  247. dependency_chain.append({
  248. 'table_name': source,
  249. 'script_name': None, # DataSource没有脚本
  250. 'table_type': 'DataSource',
  251. 'execution_mode': None
  252. })
  253. # 当前DataResource表添加到链条末尾
  254. dependency_chain.append({
  255. 'table_name': start_table,
  256. 'script_name': script_name,
  257. 'table_type': 'DataResource',
  258. 'execution_mode': execution_mode
  259. })
  260. return dependency_chain
  261. def execute_scripts(scripts_list):
  262. """
  263. 执行指定的脚本列表
  264. 参数:
  265. scripts_list (list): 要执行的脚本信息列表,每项包含table_name, script_name, execution_mode
  266. 返回:
  267. bool: 全部执行成功返回True,任一失败返回False
  268. """
  269. if not scripts_list:
  270. logger.info("没有脚本需要执行")
  271. return True
  272. success = True
  273. for item in scripts_list:
  274. script_name = item['script_name']
  275. table_name = item['table_name']
  276. execution_mode = item['execution_mode']
  277. if not script_name:
  278. logger.warning(f"表 {table_name} 没有对应的脚本,跳过执行")
  279. continue
  280. logger.info(f"执行脚本: {script_name}, 表: {table_name}, 模式: {execution_mode}")
  281. try:
  282. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  283. if not os.path.exists(script_path):
  284. logger.error(f"脚本文件不存在: {script_path}")
  285. success = False
  286. break
  287. # 动态导入模块
  288. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  289. module = importlib.util.module_from_spec(spec)
  290. spec.loader.exec_module(module)
  291. # 使用标准入口函数run
  292. if hasattr(module, "run"):
  293. logger.info(f"执行脚本 {script_name} 的标准入口函数 run()")
  294. result = module.run(table_name=table_name, execution_mode=execution_mode)
  295. if result:
  296. logger.info(f"脚本 {script_name} 执行成功")
  297. else:
  298. logger.error(f"脚本 {script_name} 执行失败")
  299. success = False
  300. break
  301. else:
  302. logger.warning(f"脚本 {script_name} 未定义标准入口函数 run(),无法执行")
  303. success = False
  304. break
  305. except Exception as e:
  306. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  307. success = False
  308. break
  309. return success
  310. def prepare_dependency_chain(**context):
  311. """准备依赖链并保存到XCom"""
  312. # 获取参数
  313. table_name, upper_level_stop = get_dag_params(**context)
  314. # 获取依赖链
  315. dependency_chain = build_dependency_chain(table_name, upper_level_stop)
  316. if not dependency_chain:
  317. logger.warning(f"没有找到表 {table_name} 的依赖链")
  318. return False
  319. # 记录完整依赖链
  320. logger.info(f"依赖链完整列表: {[item['table_name'] for item in dependency_chain]}")
  321. # 过滤掉DataSource类型(它们没有脚本需要执行)
  322. dependency_chain = [item for item in dependency_chain if item['table_type'] != 'DataSource']
  323. # 保存依赖链到XCom以便后续任务使用
  324. ti = context['ti']
  325. ti.xcom_push(key='dependency_chain', value=dependency_chain)
  326. # 检查是否有各类型的脚本需要执行
  327. has_resource = any(item['table_type'] == 'DataResource' for item in dependency_chain)
  328. has_model = any(item['table_type'] == 'DataModel' for item in dependency_chain)
  329. logger.info(f"是否有DataResource脚本: {has_resource}, 是否有DataModel脚本: {has_model}")
  330. return True
  331. def process_resources(**context):
  332. """处理所有DataResource层的脚本"""
  333. # 获取任务间共享变量
  334. ti = context['ti']
  335. dependency_chain = ti.xcom_pull(task_ids='prepare_dependency_chain', key='dependency_chain')
  336. # 过滤出DataResource类型的表
  337. resource_scripts = [item for item in dependency_chain if item['table_type'] == 'DataResource']
  338. logger.info(f"要执行的DataResource脚本: {[item['table_name'] for item in resource_scripts]}")
  339. # 执行所有DataResource脚本
  340. return execute_scripts(resource_scripts)
  341. def process_models(**context):
  342. """处理所有DataModel层的脚本"""
  343. # 获取任务间共享变量
  344. ti = context['ti']
  345. dependency_chain = ti.xcom_pull(task_ids='prepare_dependency_chain', key='dependency_chain')
  346. # 过滤出DataModel类型的表
  347. model_scripts = [item for item in dependency_chain if item['table_type'] == 'DataModel']
  348. logger.info(f"要执行的DataModel脚本: {[item['table_name'] for item in model_scripts]}")
  349. # 执行所有DataModel脚本
  350. return execute_scripts(model_scripts)
  351. # 创建DAG
  352. with DAG(
  353. 'dag_manual_trigger_chain',
  354. default_args=default_args,
  355. description='手动触发指定表的依赖链执行(两级任务)',
  356. schedule_interval=None, # 设置为None表示只能手动触发
  357. catchup=False,
  358. is_paused_upon_creation=False, # 添加这一行,使DAG创建时不处于暂停状态
  359. ) as dag:
  360. # 第一个任务:准备依赖链
  361. prepare_task = PythonOperator(
  362. task_id='prepare_dependency_chain',
  363. python_callable=prepare_dependency_chain,
  364. provide_context=True,
  365. )
  366. # 第二个任务:执行DataResource脚本
  367. resource_task = PythonOperator(
  368. task_id='process_resources',
  369. python_callable=process_resources,
  370. provide_context=True,
  371. )
  372. # 第三个任务:执行DataModel脚本
  373. model_task = PythonOperator(
  374. task_id='process_models',
  375. python_callable=process_models,
  376. provide_context=True,
  377. )
  378. # 设置任务依赖关系
  379. prepare_task >> resource_task >> model_task