utils.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889
  1. # utils.py
  2. import psycopg2
  3. from neo4j import GraphDatabase
  4. from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
  5. import logging
  6. import importlib.util
  7. from pathlib import Path
  8. import networkx as nx
  9. import os
  10. from airflow.exceptions import AirflowFailException
# Unified module-level logger; the "airflow.task" name routes records into
# the per-task Airflow log.
logger = logging.getLogger("airflow.task")
  13. def get_pg_conn():
  14. return psycopg2.connect(**PG_CONFIG)
  15. def get_subscribed_tables(freq: str) -> list[dict]:
  16. """
  17. 根据调度频率获取启用的订阅表列表,附带 execution_mode 参数
  18. 返回结果示例:
  19. [
  20. {'table_name': 'region_sales', 'execution_mode': 'append'},
  21. {'table_name': 'catalog_sales', 'execution_mode': 'full_refresh'}
  22. ]
  23. """
  24. conn = get_pg_conn()
  25. cursor = conn.cursor()
  26. cursor.execute("""
  27. SELECT table_name, execution_mode
  28. FROM table_schedule
  29. WHERE is_enabled = TRUE AND schedule_frequency = %s
  30. """, (freq,))
  31. result = cursor.fetchall()
  32. cursor.close()
  33. conn.close()
  34. return [{"table_name": r[0], "execution_mode": r[1]} for r in result]
  35. def get_neo4j_dependencies(table_name: str) -> list:
  36. """
  37. 查询 Neo4j 中某个模型的 DERIVED_FROM 依赖(上游表名)
  38. """
  39. uri = NEO4J_CONFIG['uri']
  40. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  41. driver = GraphDatabase.driver(uri, auth=auth)
  42. query = """
  43. MATCH (a:Table {name: $name})<-[:DERIVED_FROM]-(b:Table)
  44. RETURN b.name
  45. """
  46. with driver.session() as session:
  47. records = session.run(query, name=table_name)
  48. return [record["b.name"] for record in records]
  49. # def get_script_name_from_neo4j(table_name: str) -> str:
  50. # """
  51. # 从Neo4j数据库中查询表对应的脚本名称
  52. # 查询的是 DataResource 和 DataSource 之间的 ORIGINATES_FROM 关系中的 script_name 属性
  53. # 参数:
  54. # table_name (str): 数据资源表名
  55. # 返回:
  56. # str: 脚本名称,如果未找到则返回None
  57. # """
  58. # logger = logging.getLogger("airflow.task")
  59. # driver = GraphDatabase.driver(**NEO4J_CONFIG)
  60. # query = """
  61. # MATCH (dr:DataResource {en_name: $table_name})-[rel:ORIGINATES_FROM]->(ds:DataSource)
  62. # RETURN rel.script_name AS script_name
  63. # """
  64. # try:
  65. # with driver.session() as session:
  66. # result = session.run(query, table_name=table_name)
  67. # record = result.single()
  68. # if record and 'script_name' in record:
  69. # return record['script_name']
  70. # else:
  71. # logger.warning(f"没有找到表 {table_name} 对应的脚本名称")
  72. # return None
  73. # except Exception as e:
  74. # logger.error(f"从Neo4j查询脚本名称时出错: {str(e)}")
  75. # return None
  76. # finally:
  77. # driver.close()
  78. def execute_script(script_name: str, table_name: str, execution_mode: str) -> bool:
  79. """
  80. 根据脚本名称动态导入并执行对应的脚本
  81. 返回:
  82. bool: 执行成功返回True,否则返回False
  83. """
  84. if not script_name:
  85. logger.error("未提供脚本名称,无法执行")
  86. return False
  87. try:
  88. # 直接使用配置的部署路径,不考虑本地开发路径
  89. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  90. logger.info(f"使用配置的Airflow部署路径: {script_path}")
  91. # 动态导入模块
  92. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  93. module = importlib.util.module_from_spec(spec)
  94. spec.loader.exec_module(module)
  95. # 使用标准入口函数run
  96. if hasattr(module, "run"):
  97. logger.info(f"执行脚本 {script_name} 的标准入口函数 run()")
  98. module.run(table_name=table_name, execution_mode=execution_mode)
  99. return True
  100. else:
  101. logger.warning(f"脚本 {script_name} 未定义标准入口函数 run(),无法执行")
  102. return False
  103. except Exception as e:
  104. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  105. return False
  106. # def get_enabled_tables(frequency: str) -> list:
  107. # conn = get_pg_conn()
  108. # cursor = conn.cursor()
  109. # cursor.execute("""
  110. # SELECT table_name, execution_mode
  111. # FROM table_schedule
  112. # WHERE is_enabled = TRUE AND schedule_frequency = %s
  113. # """, (frequency,))
  114. # result = cursor.fetchall()
  115. # cursor.close()
  116. # conn.close()
  117. # output = []
  118. # for r in result:
  119. # output.append({"table_name": r[0], "execution_mode": r[1]})
  120. # return output
  121. # def is_data_resource_table(table_name: str) -> bool:
  122. # driver = GraphDatabase.driver(NEO4J_CONFIG['uri'], auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password']))
  123. # query = """
  124. # MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  125. # """
  126. # try:
  127. # with driver.session() as session:
  128. # result = session.run(query, table_name=table_name)
  129. # record = result.single()
  130. # return record and record["exists"]
  131. # finally:
  132. # driver.close()
  133. def get_resource_subscribed_tables(enabled_tables: list) -> list:
  134. result = []
  135. for t in enabled_tables:
  136. if is_data_resource_table(t['table_name']):
  137. result.append(t)
  138. return result
  139. # 根据目标表,递归查找其所有上游依赖的 DataResource 表(不限层级)
  140. def get_dependency_resource_tables(enabled_tables: list) -> list:
  141. uri = NEO4J_CONFIG['uri']
  142. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  143. driver = GraphDatabase.driver(uri, auth=auth)
  144. resource_set = set()
  145. try:
  146. with driver.session() as session:
  147. for t in enabled_tables:
  148. query = """
  149. MATCH (target:Table {name: $table_name})
  150. MATCH (res:DataResource)-[:ORIGINATES_FROM]->(:DataSource)
  151. WHERE (target)-[:DERIVED_FROM*1..]->(res)
  152. RETURN DISTINCT res.en_name AS name
  153. """
  154. result = session.run(query, table_name=t['table_name'])
  155. for record in result:
  156. resource_set.add(record['name'])
  157. finally:
  158. driver.close()
  159. output = []
  160. for name in resource_set:
  161. output.append({"table_name": name, "execution_mode": "append"})
  162. return output
  163. # 从 PostgreSQL 获取启用的表,按调度频率 daily/weekly/monthly 过滤
  164. def get_enabled_tables(frequency: str) -> list:
  165. conn = get_pg_conn()
  166. cursor = conn.cursor()
  167. cursor.execute("""
  168. SELECT table_name, execution_mode
  169. FROM table_schedule
  170. WHERE is_enabled = TRUE AND schedule_frequency = %s
  171. """, (frequency,))
  172. result = cursor.fetchall()
  173. cursor.close()
  174. conn.close()
  175. output = []
  176. for r in result:
  177. output.append({"table_name": r[0], "execution_mode": r[1]})
  178. return output
  179. # 判断给定表名是否是 Neo4j 中的 DataResource 类型
  180. def is_data_resource_table(table_name: str) -> bool:
  181. uri = NEO4J_CONFIG['uri']
  182. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  183. driver = GraphDatabase.driver(uri, auth=auth)
  184. query = """
  185. MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  186. """
  187. try:
  188. with driver.session() as session:
  189. result = session.run(query, table_name=table_name)
  190. record = result.single()
  191. return record and record["exists"]
  192. finally:
  193. driver.close()
  194. # 从 Neo4j 查询 DataModel 表的 DERIVED_FROM 关系上的 script_name 属性
  195. def get_script_name_from_neo4j(table_name):
  196. uri = NEO4J_CONFIG['uri']
  197. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  198. driver = GraphDatabase.driver(uri, auth=auth)
  199. logger.info(f"从Neo4j查询表 {table_name} 的脚本名称")
  200. # 检查查询的是 DERIVED_FROM 关系的方向
  201. check_query = """
  202. MATCH (a:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(b)
  203. RETURN b.en_name AS upstream_name LIMIT 5
  204. """
  205. try:
  206. with driver.session() as session:
  207. # 先检查依赖关系
  208. logger.info(f"检查表 {table_name} 的上游依赖方向")
  209. check_result = session.run(check_query, table_name=table_name)
  210. upstreams = [record['upstream_name'] for record in check_result if 'upstream_name' in record]
  211. logger.info(f"表 {table_name} 的上游依赖: {upstreams}")
  212. # 查询脚本名称
  213. query = """
  214. MATCH (target:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(n)
  215. WHERE n:DataModel OR n:DataResource
  216. RETURN r.script_name AS script_name
  217. """
  218. result = session.run(query, table_name=table_name)
  219. record = result.single()
  220. if record:
  221. try:
  222. script_name = record['script_name']
  223. logger.info(f"找到表 {table_name} 的脚本名称: {script_name}")
  224. return script_name
  225. except (KeyError, TypeError) as e:
  226. logger.warning(f"记录中不包含script_name字段: {e}")
  227. return None
  228. else:
  229. logger.warning(f"没有找到表 {table_name} 的脚本名称")
  230. return None
  231. except Exception as e:
  232. logger.error(f"查询表 {table_name} 的脚本名称时出错: {str(e)}")
  233. return None
  234. finally:
  235. driver.close()
  236. # 判断给定表名是否是 Neo4j 中的 DataModel 类型
  237. def is_data_model_table(table_name):
  238. uri = NEO4J_CONFIG['uri']
  239. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  240. driver = GraphDatabase.driver(uri, auth=auth)
  241. query = """
  242. MATCH (n:DataModel {en_name: $table_name}) RETURN count(n) > 0 AS exists
  243. """
  244. try:
  245. with driver.session() as session:
  246. result = session.run(query, table_name=table_name)
  247. record = result.single()
  248. return record and record['exists']
  249. finally:
  250. driver.close()
  251. def check_script_exists(script_name):
  252. """
  253. 检查脚本文件是否存在于配置的脚本目录中
  254. 参数:
  255. script_name (str): 脚本文件名
  256. 返回:
  257. bool: 如果脚本存在返回True,否则返回False
  258. str: 完整的脚本路径
  259. """
  260. if not script_name:
  261. logger.error("脚本名称为空,无法检查")
  262. return False, None
  263. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  264. script_path_str = str(script_path)
  265. logger.info(f"检查脚本路径: {script_path_str}")
  266. if os.path.exists(script_path_str):
  267. logger.info(f"脚本文件已找到: {script_path_str}")
  268. return True, script_path_str
  269. else:
  270. logger.error(f"脚本文件不存在: {script_path_str}")
  271. # 尝试列出目录中的文件
  272. try:
  273. base_dir = Path(SCRIPTS_BASE_PATH)
  274. if base_dir.exists():
  275. files = list(base_dir.glob("*.py"))
  276. logger.info(f"目录 {SCRIPTS_BASE_PATH} 中的Python文件: {[f.name for f in files]}")
  277. else:
  278. logger.error(f"基础目录不存在: {SCRIPTS_BASE_PATH}")
  279. except Exception as e:
  280. logger.error(f"列出目录内容时出错: {str(e)}")
  281. return False, script_path_str
def run_model_script(table_name, execution_mode):
    """
    Look up and execute the model script for the given table.

    Args:
        table_name (str): name of the table to process.
        execution_mode (str): execution mode (append / full_refresh).

    Returns:
        bool: True when the script ran successfully.

    Raises:
        AirflowFailException: when the script name cannot be resolved, the
        script file is missing, run() is absent, or execution fails.
    """
    # Resolve the script name from Neo4j.
    script_name = get_script_name_from_neo4j(table_name)
    if not script_name:
        error_msg = f"未找到表 {table_name} 的脚本名称,任务失败"
        logger.error(error_msg)
        raise AirflowFailException(error_msg)
    logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
    # Verify the script file exists on disk.
    exists, script_path = check_script_exists(script_name)
    if not exists:
        error_msg = f"表 {table_name} 的脚本文件 {script_name} 不存在,任务失败"
        logger.error(error_msg)
        raise AirflowFailException(error_msg)
    # Execute the script.
    logger.info(f"开始执行脚本: {script_path}")
    try:
        # Import the script file as a throwaway module.
        import importlib.util
        import sys
        spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Call the standard entry point run() if it is defined.
        if hasattr(module, "run"):
            logger.info(f"调用脚本 {script_name} 的标准入口函数 run()")
            module.run(table_name=table_name, execution_mode=execution_mode)
            logger.info(f"脚本 {script_name} 执行成功")
            return True
        else:
            error_msg = f"脚本 {script_name} 中未定义标准入口函数 run(),任务失败"
            logger.error(error_msg)
            raise AirflowFailException(error_msg)
    except AirflowFailException:
        # Re-raise Airflow failures untouched.
        raise
    except Exception as e:
        error_msg = f"执行脚本 {script_name} 时出错: {str(e)}"
        logger.error(error_msg)
        import traceback
        logger.error(traceback.format_exc())
        raise AirflowFailException(error_msg)
  334. # 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
  335. # 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
  336. # def get_model_dependency_graph(table_names: list) -> dict:
  337. # graph = {}
  338. # uri = NEO4J_CONFIG['uri']
  339. # auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  340. # driver = GraphDatabase.driver(uri, auth=auth)
  341. # try:
  342. # with driver.session() as session:
  343. # for table_name in table_names:
  344. # query = """
  345. # MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
  346. # RETURN up.en_name AS upstream
  347. # """
  348. # result = session.run(query, table_name=table_name)
  349. # deps = [record['upstream'] for record in result if 'upstream' in record]
  350. # graph[table_name] = deps
  351. # finally:
  352. # driver.close()
  353. # return graph
def get_model_dependency_graph(table_names: list) -> dict:
    """
    Build the dependency map between the given DataModel tables from Neo4j,
    using networkx only for cycle detection.

    Args:
        table_names: list of table names.

    Returns:
        dict: {target table: [upstream table, ...]}
    """
    logger.info(f"开始构建依赖关系图,表列表: {table_names}")
    # Directed graph used for cycle detection; edges run from the
    # depended-on table to the dependent one (i.e. execution order).
    G = nx.DiGraph()
    # Register every table as a node.
    for table_name in table_names:
        G.add_node(table_name)
    # Fetch the dependency edges from Neo4j.
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    try:
        with driver.session() as session:
            # One query fetches every dependency among the listed tables.
            # NOTE: A-[:DERIVED_FROM]->B means A depends on B.
            # Log the raw parameters for debugging.
            logger.info(f"查询参数 table_names: {table_names}, 类型: {type(table_names)}")
            # First-pass query — explicit pairwise form.
            query = """
                MATCH (source)-[r:DERIVED_FROM]->(target)
                WHERE source.en_name IN $table_names AND target.en_name IN $table_names
                RETURN source.en_name AS source, target.en_name AS target, r.script_name AS script_name
            """
            logger.info(f"执行Neo4j查询: 查找所有表之间的依赖关系")
            result = session.run(query, table_names=table_names)
            # Materialize the cursor so the records are fully consumed.
            result_records = list(result)
            logger.info(f"第一层查询返回记录数: {len(result_records)}")
            found_deps = 0
            # Seed the dependency map with an empty list per table.
            dependency_dict = {name: [] for name in table_names}
            for record in result_records:
                # Convert each record to a plain dict to avoid field-access
                # quirks of the driver's Record type.
                record_dict = dict(record)
                source = record_dict.get('source')
                target = record_dict.get('target')
                script_name = record_dict.get('script_name', 'unknown_script')
                # Only act when both endpoints are present.
                if source and target:
                    logger.info(f"发现依赖关系: {source} -[:DERIVED_FROM]-> {target}, 脚本: {script_name}")
                    if source in dependency_dict:
                        dependency_dict[source].append(target)
                        found_deps += 1
                        # Edge target -> source encodes execution order
                        # (the depended-on table runs first).
                        G.add_edge(target, source)
                        logger.info(f"添加执行顺序边: {target} -> {source} (因为{source}依赖{target})")
            logger.info(f"总共发现 {found_deps} 个依赖关系")
            # Fallback: when the bulk query found nothing, probe every
            # ordered table pair individually.
            if found_deps == 0:
                logger.warning("仍未找到依赖关系,尝试检查所有表对之间的关系")
                logger.info("第三层查询: 开始表对之间的循环检查")
                logger.info(f"要检查的表对数量: {len(table_names) * (len(table_names) - 1)}")
                pair_count = 0
                for source_table in table_names:
                    for target_table in table_names:
                        if source_table != target_table:
                            pair_count += 1
                            logger.info(f"检查表对[{pair_count}]: {source_table} -> {target_table}")
                            check_result = check_table_relationship(source_table, target_table)
                            # Only the forward direction (source depends on
                            # target) contributes an edge.
                            if 'forward' in check_result and check_result['forward']['exists']:
                                script_name = check_result['forward'].get('script_name', 'unknown_script')
                                logger.info(f"表对检查发现关系: {source_table} -[:DERIVED_FROM]-> {target_table}, 脚本: {script_name}")
                                dependency_dict[source_table].append(target_table)
                                G.add_edge(target_table, source_table)
                                found_deps += 1
                logger.info(f"表对检查后找到 {found_deps} 个依赖关系")
    finally:
        driver.close()
    # Cycle detection — informational only, the graph is not modified.
    try:
        cycles = list(nx.simple_cycles(G))
        if cycles:
            logger.warning(f"检测到表间循环依赖: {cycles}")
    except Exception as e:
        logger.error(f"检查循环依赖失败: {str(e)}")
    # Re-emit the map keyed strictly by the requested table list.
    final_dependency_dict = {}
    for table_name in table_names:
        final_dependency_dict[table_name] = dependency_dict.get(table_name, [])
        logger.info(f"最终依赖关系 - 表 {table_name} 依赖于: {final_dependency_dict[table_name]}")
    logger.info(f"完整依赖图: {final_dependency_dict}")
    return final_dependency_dict
  448. def generate_optimized_execution_order(table_names: list) -> list:
  449. """
  450. 生成优化的执行顺序,可处理循环依赖
  451. 参数:
  452. table_names: 表名列表
  453. 返回:
  454. list: 优化后的执行顺序列表
  455. """
  456. # 创建依赖图
  457. G = nx.DiGraph()
  458. # 添加所有节点
  459. for table_name in table_names:
  460. G.add_node(table_name)
  461. # 添加依赖边
  462. dependency_dict = get_model_dependency_graph(table_names)
  463. for target, upstreams in dependency_dict.items():
  464. for upstream in upstreams:
  465. G.add_edge(upstream, target)
  466. # 检测循环依赖
  467. cycles = list(nx.simple_cycles(G))
  468. if cycles:
  469. logger.warning(f"检测到循环依赖,将尝试打破循环: {cycles}")
  470. # 打破循环依赖(简单策略:移除每个循环中的一条边)
  471. for cycle in cycles:
  472. # 移除循环中的最后一条边
  473. G.remove_edge(cycle[-1], cycle[0])
  474. logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")
  475. # 生成拓扑排序
  476. try:
  477. execution_order = list(nx.topological_sort(G))
  478. return execution_order
  479. except Exception as e:
  480. logger.error(f"生成执行顺序失败: {str(e)}")
  481. # 返回原始列表作为备选
  482. return table_names
  483. def identify_common_paths(table_names: list) -> dict:
  484. """
  485. 识别多个表之间的公共执行路径
  486. 参数:
  487. table_names: 表名列表
  488. 返回:
  489. dict: 公共路径信息 {(path_tuple): 使用次数}
  490. """
  491. # 创建依赖图
  492. G = nx.DiGraph()
  493. # 添加所有节点和直接依赖边
  494. dependency_dict = get_model_dependency_graph(table_names)
  495. for target, upstreams in dependency_dict.items():
  496. G.add_node(target)
  497. for upstream in upstreams:
  498. G.add_node(upstream)
  499. G.add_edge(upstream, target)
  500. # 找出所有路径
  501. all_paths = []
  502. # 找出所有源节点(没有入边的节点)和终节点(没有出边的节点)
  503. sources = [n for n in G.nodes() if G.in_degree(n) == 0]
  504. targets = [n for n in G.nodes() if G.out_degree(n) == 0]
  505. # 获取所有源到目标的路径
  506. for source in sources:
  507. for target in targets:
  508. try:
  509. # 限制路径长度,避免组合爆炸
  510. paths = list(nx.all_simple_paths(G, source, target, cutoff=10))
  511. all_paths.extend(paths)
  512. except nx.NetworkXNoPath:
  513. continue
  514. # 统计路径段使用频率
  515. path_segments = {}
  516. for path in all_paths:
  517. # 只考虑长度>=2的路径段(至少有一条边)
  518. for i in range(len(path)-1):
  519. for j in range(i+2, min(i+6, len(path)+1)): # 限制段长,避免组合爆炸
  520. segment = tuple(path[i:j])
  521. if segment not in path_segments:
  522. path_segments[segment] = 0
  523. path_segments[segment] += 1
  524. # 过滤出重复使用的路径段
  525. common_paths = {seg: count for seg, count in path_segments.items()
  526. if count > 1 and len(seg) >= 3} # 至少3个节点,2条边
  527. # 按使用次数排序
  528. common_paths = dict(sorted(common_paths.items(), key=lambda x: x[1], reverse=True))
  529. return common_paths
def check_table_relationship(table1, table2):
    """
    Inspect Neo4j directly for DERIVED_FROM relationships between two tables,
    checking both directions.

    Args:
        table1: first table name.
        table2: second table name.

    Returns:
        dict with 'forward' and 'backward' entries, each either
        {'exists': False} or {'exists': True, 'direction': ..., 'script_name': ...};
        an 'error' key is added when the lookup fails.
    """
    uri = NEO4J_CONFIG['uri']
    auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
    driver = GraphDatabase.driver(uri, auth=auth)
    relationship_info = {}
    try:
        with driver.session() as session:
            # Check the table1 -> table2 direction.
            forward_query = """
            MATCH (a:DataModel {en_name: $table1})-[r:DERIVED_FROM]->(b:DataModel {en_name: $table2})
            RETURN count(r) > 0 AS has_relationship, r.script_name AS script_name
            """
            forward_result = session.run(forward_query, table1=table1, table2=table2)
            forward_record = forward_result.single()
            if forward_record and forward_record['has_relationship']:
                relationship_info['forward'] = {
                    'exists': True,
                    'direction': f"{table1} -> {table2}",
                    'script_name': forward_record.get('script_name')
                }
                logger.info(f"发现关系: {table1} -[:DERIVED_FROM]-> {table2}, 脚本: {forward_record.get('script_name')}")
            else:
                relationship_info['forward'] = {'exists': False}
            # Check the table2 -> table1 direction.
            backward_query = """
            MATCH (a:DataModel {en_name: $table2})-[r:DERIVED_FROM]->(b:DataModel {en_name: $table1})
            RETURN count(r) > 0 AS has_relationship, r.script_name AS script_name
            """
            backward_result = session.run(backward_query, table1=table1, table2=table2)
            backward_record = backward_result.single()
            if backward_record and backward_record['has_relationship']:
                relationship_info['backward'] = {
                    'exists': True,
                    'direction': f"{table2} -> {table1}",
                    'script_name': backward_record.get('script_name')
                }
                logger.info(f"发现关系: {table2} -[:DERIVED_FROM]-> {table1}, 脚本: {backward_record.get('script_name')}")
            else:
                relationship_info['backward'] = {'exists': False}
    except Exception as e:
        logger.error(f"检查表关系时出错: {str(e)}")
        relationship_info['error'] = str(e)
    finally:
        driver.close()
    return relationship_info
  583. def build_model_dependency_dag(table_names, model_tables):
  584. """
  585. 基于表名列表构建模型依赖DAG,返回优化后的执行顺序和依赖关系图
  586. 参数:
  587. table_names: 表名列表
  588. model_tables: 表配置列表
  589. 返回:
  590. tuple: (优化后的表执行顺序, 依赖关系图)
  591. """
  592. # 使用优化函数生成执行顺序,可以处理循环依赖
  593. optimized_table_order = generate_optimized_execution_order(table_names)
  594. logger.info(f"生成优化执行顺序, 共 {len(optimized_table_order)} 个表")
  595. # 获取依赖图
  596. dependency_graph = get_model_dependency_graph(table_names)
  597. logger.info(f"构建了 {len(dependency_graph)} 个表的依赖关系图")
  598. return optimized_table_order, dependency_graph
  599. def create_task_dict(optimized_table_order, model_tables, dag, execution_type, **task_options):
  600. """
  601. 根据优化后的表执行顺序创建任务字典
  602. 参数:
  603. optimized_table_order: 优化后的表执行顺序
  604. model_tables: 表配置列表
  605. dag: Airflow DAG对象
  606. execution_type: 执行类型(daily, monthly等)
  607. task_options: 任务创建的额外选项
  608. 返回:
  609. dict: 任务字典 {表名: 任务对象}
  610. """
  611. from airflow.operators.python import PythonOperator
  612. task_dict = {}
  613. for table_name in optimized_table_order:
  614. # 获取表的配置信息
  615. table_config = next((t for t in model_tables if t['table_name'] == table_name), None)
  616. if table_config:
  617. try:
  618. # 构建基础参数
  619. task_params = {
  620. "task_id": f"process_{execution_type}_{table_name}",
  621. "python_callable": run_model_script,
  622. "op_kwargs": {"table_name": table_name, "execution_mode": table_config['execution_mode']},
  623. "dag": dag
  624. }
  625. # 添加额外选项
  626. if task_options:
  627. # 如果有表特定的选项,使用它们
  628. if table_name in task_options:
  629. task_params.update(task_options[table_name])
  630. # 如果有全局选项,使用它们
  631. elif 'default' in task_options:
  632. task_params.update(task_options['default'])
  633. task = PythonOperator(**task_params)
  634. task_dict[table_name] = task
  635. logger.info(f"创建模型处理任务: {task_params['task_id']}")
  636. except Exception as e:
  637. logger.error(f"创建任务 process_{execution_type}_{table_name} 时出错: {str(e)}")
  638. raise
  639. return task_dict
  640. def build_task_dependencies(task_dict, dependency_graph):
  641. """
  642. 根据依赖图设置任务间的依赖关系
  643. 参数:
  644. task_dict: 任务字典
  645. dependency_graph: 依赖关系图
  646. 返回:
  647. tuple: (tasks_with_upstream, tasks_with_downstream, dependency_count)
  648. """
  649. tasks_with_upstream = set() # 用于跟踪已经有上游任务的节点
  650. dependency_count = 0
  651. for target, upstream_list in dependency_graph.items():
  652. if target in task_dict:
  653. for upstream in upstream_list:
  654. if upstream in task_dict:
  655. logger.info(f"建立任务依赖: {upstream} >> {target}")
  656. task_dict[upstream] >> task_dict[target]
  657. tasks_with_upstream.add(target) # 记录此任务已有上游
  658. dependency_count += 1
  659. # 找出有下游任务的节点
  660. tasks_with_downstream = set()
  661. for target, upstream_list in dependency_graph.items():
  662. if target in task_dict: # 目标任务在当前DAG中
  663. for upstream in upstream_list:
  664. if upstream in task_dict: # 上游任务也在当前DAG中
  665. tasks_with_downstream.add(upstream) # 这个上游任务有下游
  666. logger.info(f"总共建立了 {dependency_count} 个任务之间的依赖关系")
  667. logger.info(f"已有上游任务的节点: {tasks_with_upstream}")
  668. return tasks_with_upstream, tasks_with_downstream, dependency_count
def connect_start_and_end_tasks(task_dict, tasks_with_upstream, tasks_with_downstream,
                                wait_task, completed_task, dag_type):
    """
    Wire start nodes to the wait task and terminal nodes to the completion
    marker.

    Args:
        task_dict: task dictionary.
        tasks_with_upstream: set of tasks that already have an upstream task.
        tasks_with_downstream: set of tasks that have a downstream task.
        wait_task: the wait task.
        completed_task: the completion-marker task.
        dag_type: DAG type name (used only for logging).

    Returns:
        tuple: (start_tasks, end_tasks)
    """
    # Connect the start nodes (no upstream inside this DAG).
    start_tasks = []
    for table_name, task in task_dict.items():
        if table_name not in tasks_with_upstream:
            start_tasks.append(table_name)
            logger.info(f"任务 {table_name} 没有上游任务,应该连接到{dag_type}等待任务")
    logger.info(f"需要连接到{dag_type}等待任务的任务: {start_tasks}")
    for task_name in start_tasks:
        wait_task >> task_dict[task_name]
        logger.info(f"连接 {wait_task.task_id} >> {task_name}")
    # Connect the terminal nodes (no downstream inside this DAG).
    end_tasks = []
    for table_name, task in task_dict.items():
        if table_name not in tasks_with_downstream:
            end_tasks.append(table_name)
            logger.info(f"任务 {table_name} 没有下游任务,是末端任务")
    logger.info(f"需要连接到{dag_type}完成标记的末端任务: {end_tasks}")
    for end_task in end_tasks:
        task_dict[end_task] >> completed_task
        logger.info(f"连接 {end_task} >> {completed_task.task_id}")
    # Handle the degenerate cases so the DAG always reaches the marker.
    logger.info("处理特殊情况")
    if not start_tasks:
        logger.warning(f"没有找到开始任务,将{dag_type}等待任务直接连接到完成标记")
        wait_task >> completed_task
    if not end_tasks:
        logger.warning(f"没有找到末端任务,将所有任务连接到{dag_type}完成标记")
        for table_name, task in task_dict.items():
            task >> completed_task
            logger.info(f"直接连接任务到完成标记: {table_name} >> {completed_task.task_id}")
    return start_tasks, end_tasks
  714. def process_model_tables(enabled_tables, dag_type, wait_task, completed_task, dag, **task_options):
  715. """
  716. 处理模型表并构建DAG
  717. 参数:
  718. enabled_tables: 已启用的表列表
  719. dag_type: DAG类型 (daily, monthly等)
  720. wait_task: 等待任务
  721. completed_task: 完成标记任务
  722. dag: Airflow DAG对象
  723. task_options: 创建任务的额外选项
  724. """
  725. model_tables = [t for t in enabled_tables if is_data_model_table(t['table_name'])]
  726. logger.info(f"获取到 {len(model_tables)} 个启用的 {dag_type} 模型表")
  727. if not model_tables:
  728. # 如果没有模型表需要处理,直接将等待任务与完成标记相连接
  729. logger.info(f"没有找到需要处理的{dag_type}模型表,DAG将直接标记为完成")
  730. wait_task >> completed_task
  731. return
  732. # 获取表名列表
  733. table_names = [t['table_name'] for t in model_tables]
  734. try:
  735. # 构建模型依赖DAG
  736. optimized_table_order, dependency_graph = build_model_dependency_dag(table_names, model_tables)
  737. # 创建任务字典
  738. task_dict = create_task_dict(optimized_table_order, model_tables, dag, dag_type, **task_options)
  739. # 建立任务依赖关系
  740. tasks_with_upstream, tasks_with_downstream, _ = build_task_dependencies(task_dict, dependency_graph)
  741. # 连接开始节点和末端节点
  742. connect_start_and_end_tasks(task_dict, tasks_with_upstream, tasks_with_downstream,
  743. wait_task, completed_task, dag_type)
  744. except Exception as e:
  745. logger.error(f"处理{dag_type}模型表时出错: {str(e)}")
  746. # 出错时也要确保完成标记被触发
  747. wait_task >> completed_task
  748. raise