utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. # utils.py
  2. import psycopg2
  3. from neo4j import GraphDatabase
  4. from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
  5. import logging
  6. import importlib.util
  7. from pathlib import Path
  8. import networkx as nx
  9. import os
  10. from airflow.exceptions import AirflowFailException
# Module-wide logger shared by every helper in this file; it is obtained under the
# name "airflow.task" so its records go through whatever handlers that logger has.
logger = logging.getLogger("airflow.task")
  13. def get_pg_conn():
  14. return psycopg2.connect(**PG_CONFIG)
  15. def get_subscribed_tables(freq: str) -> list[dict]:
  16. """
  17. 根据调度频率获取启用的订阅表列表,附带 execution_mode 参数
  18. 返回结果示例:
  19. [
  20. {'table_name': 'region_sales', 'execution_mode': 'append'},
  21. {'table_name': 'catalog_sales', 'execution_mode': 'full_refresh'}
  22. ]
  23. """
  24. conn = get_pg_conn()
  25. cursor = conn.cursor()
  26. cursor.execute("""
  27. SELECT table_name, execution_mode
  28. FROM table_schedule
  29. WHERE is_enabled = TRUE AND schedule_frequency = %s
  30. """, (freq,))
  31. result = cursor.fetchall()
  32. cursor.close()
  33. conn.close()
  34. return [{"table_name": r[0], "execution_mode": r[1]} for r in result]
  35. def get_neo4j_dependencies(table_name: str) -> list:
  36. """
  37. 查询 Neo4j 中某个模型的 DERIVED_FROM 依赖(上游表名)
  38. """
  39. uri = NEO4J_CONFIG['uri']
  40. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  41. driver = GraphDatabase.driver(uri, auth=auth)
  42. query = """
  43. MATCH (a:Table {name: $name})<-[:DERIVED_FROM]-(b:Table)
  44. RETURN b.name
  45. """
  46. with driver.session() as session:
  47. records = session.run(query, name=table_name)
  48. return [record["b.name"] for record in records]
  49. # def get_script_name_from_neo4j(table_name: str) -> str:
  50. # """
  51. # 从Neo4j数据库中查询表对应的脚本名称
  52. # 查询的是 DataResource 和 DataSource 之间的 ORIGINATES_FROM 关系中的 script_name 属性
  53. # 参数:
  54. # table_name (str): 数据资源表名
  55. # 返回:
  56. # str: 脚本名称,如果未找到则返回None
  57. # """
  58. # logger = logging.getLogger("airflow.task")
  59. # driver = GraphDatabase.driver(**NEO4J_CONFIG)
  60. # query = """
  61. # MATCH (dr:DataResource {en_name: $table_name})-[rel:ORIGINATES_FROM]->(ds:DataSource)
  62. # RETURN rel.script_name AS script_name
  63. # """
  64. # try:
  65. # with driver.session() as session:
  66. # result = session.run(query, table_name=table_name)
  67. # record = result.single()
  68. # if record and 'script_name' in record:
  69. # return record['script_name']
  70. # else:
  71. # logger.warning(f"没有找到表 {table_name} 对应的脚本名称")
  72. # return None
  73. # except Exception as e:
  74. # logger.error(f"从Neo4j查询脚本名称时出错: {str(e)}")
  75. # return None
  76. # finally:
  77. # driver.close()
  78. def execute_script(script_name: str, table_name: str, execution_mode: str) -> bool:
  79. """
  80. 根据脚本名称动态导入并执行对应的脚本
  81. 返回:
  82. bool: 执行成功返回True,否则返回False
  83. """
  84. if not script_name:
  85. logger.error("未提供脚本名称,无法执行")
  86. return False
  87. try:
  88. # 直接使用配置的部署路径,不考虑本地开发路径
  89. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  90. logger.info(f"使用配置的Airflow部署路径: {script_path}")
  91. # 动态导入模块
  92. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  93. module = importlib.util.module_from_spec(spec)
  94. spec.loader.exec_module(module)
  95. # 使用标准入口函数run
  96. if hasattr(module, "run"):
  97. logger.info(f"执行脚本 {script_name} 的标准入口函数 run()")
  98. module.run(table_name=table_name, execution_mode=execution_mode)
  99. return True
  100. else:
  101. logger.warning(f"脚本 {script_name} 未定义标准入口函数 run(),无法执行")
  102. return False
  103. except Exception as e:
  104. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  105. return False
  106. # def get_enabled_tables(frequency: str) -> list:
  107. # conn = get_pg_conn()
  108. # cursor = conn.cursor()
  109. # cursor.execute("""
  110. # SELECT table_name, execution_mode
  111. # FROM table_schedule
  112. # WHERE is_enabled = TRUE AND schedule_frequency = %s
  113. # """, (frequency,))
  114. # result = cursor.fetchall()
  115. # cursor.close()
  116. # conn.close()
  117. # output = []
  118. # for r in result:
  119. # output.append({"table_name": r[0], "execution_mode": r[1]})
  120. # return output
  121. # def is_data_resource_table(table_name: str) -> bool:
  122. # driver = GraphDatabase.driver(NEO4J_CONFIG['uri'], auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password']))
  123. # query = """
  124. # MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  125. # """
  126. # try:
  127. # with driver.session() as session:
  128. # result = session.run(query, table_name=table_name)
  129. # record = result.single()
  130. # return record and record["exists"]
  131. # finally:
  132. # driver.close()
  133. def get_resource_subscribed_tables(enabled_tables: list) -> list:
  134. result = []
  135. for t in enabled_tables:
  136. if is_data_resource_table(t['table_name']):
  137. result.append(t)
  138. return result
  139. # 根据目标表,递归查找其所有上游依赖的 DataResource 表(不限层级)
  140. def get_dependency_resource_tables(enabled_tables: list) -> list:
  141. uri = NEO4J_CONFIG['uri']
  142. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  143. driver = GraphDatabase.driver(uri, auth=auth)
  144. resource_set = set()
  145. try:
  146. with driver.session() as session:
  147. for t in enabled_tables:
  148. query = """
  149. MATCH (target:Table {name: $table_name})
  150. MATCH (res:DataResource)-[:ORIGINATES_FROM]->(:DataSource)
  151. WHERE (target)-[:DERIVED_FROM*1..]->(res)
  152. RETURN DISTINCT res.en_name AS name
  153. """
  154. result = session.run(query, table_name=t['table_name'])
  155. for record in result:
  156. resource_set.add(record['name'])
  157. finally:
  158. driver.close()
  159. output = []
  160. for name in resource_set:
  161. output.append({"table_name": name, "execution_mode": "append"})
  162. return output
  163. # 从 PostgreSQL 获取启用的表,按调度频率 daily/weekly/monthly 过滤
  164. def get_enabled_tables(frequency: str) -> list:
  165. conn = get_pg_conn()
  166. cursor = conn.cursor()
  167. cursor.execute("""
  168. SELECT table_name, execution_mode
  169. FROM table_schedule
  170. WHERE is_enabled = TRUE AND schedule_frequency = %s
  171. """, (frequency,))
  172. result = cursor.fetchall()
  173. cursor.close()
  174. conn.close()
  175. output = []
  176. for r in result:
  177. output.append({"table_name": r[0], "execution_mode": r[1]})
  178. return output
  179. # 判断给定表名是否是 Neo4j 中的 DataResource 类型
  180. def is_data_resource_table(table_name: str) -> bool:
  181. uri = NEO4J_CONFIG['uri']
  182. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  183. driver = GraphDatabase.driver(uri, auth=auth)
  184. query = """
  185. MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  186. """
  187. try:
  188. with driver.session() as session:
  189. result = session.run(query, table_name=table_name)
  190. record = result.single()
  191. return record and record["exists"]
  192. finally:
  193. driver.close()
  194. # 从 Neo4j 查询 DataModel 表的 DERIVED_FROM 关系上的 script_name 属性
  195. def get_script_name_from_neo4j(table_name):
  196. uri = NEO4J_CONFIG['uri']
  197. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  198. driver = GraphDatabase.driver(uri, auth=auth)
  199. query = """
  200. MATCH (target:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(n)
  201. WHERE n:DataModel OR n:DataResource
  202. RETURN r.script_name AS script_name
  203. """
  204. try:
  205. with driver.session() as session:
  206. result = session.run(query, table_name=table_name)
  207. record = result.single()
  208. if record:
  209. try:
  210. script_name = record['script_name']
  211. return script_name
  212. except (KeyError, TypeError) as e:
  213. print(f"[WARN] 记录中不包含script_name字段: {e}")
  214. return None
  215. else:
  216. return None
  217. except Exception as e:
  218. print(f"[ERROR] 查询表 {table_name} 的脚本名称时出错: {str(e)}")
  219. return None
  220. finally:
  221. driver.close()
  222. # 判断给定表名是否是 Neo4j 中的 DataModel 类型
  223. def is_data_model_table(table_name):
  224. uri = NEO4J_CONFIG['uri']
  225. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  226. driver = GraphDatabase.driver(uri, auth=auth)
  227. query = """
  228. MATCH (n:DataModel {en_name: $table_name}) RETURN count(n) > 0 AS exists
  229. """
  230. try:
  231. with driver.session() as session:
  232. result = session.run(query, table_name=table_name)
  233. record = result.single()
  234. return record and record['exists']
  235. finally:
  236. driver.close()
  237. def check_script_exists(script_name):
  238. """
  239. 检查脚本文件是否存在于配置的脚本目录中
  240. 参数:
  241. script_name (str): 脚本文件名
  242. 返回:
  243. bool: 如果脚本存在返回True,否则返回False
  244. str: 完整的脚本路径
  245. """
  246. if not script_name:
  247. logger.error("脚本名称为空,无法检查")
  248. return False, None
  249. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  250. script_path_str = str(script_path)
  251. logger.info(f"检查脚本路径: {script_path_str}")
  252. if os.path.exists(script_path_str):
  253. logger.info(f"脚本文件已找到: {script_path_str}")
  254. return True, script_path_str
  255. else:
  256. logger.error(f"脚本文件不存在: {script_path_str}")
  257. # 尝试列出目录中的文件
  258. try:
  259. base_dir = Path(SCRIPTS_BASE_PATH)
  260. if base_dir.exists():
  261. files = list(base_dir.glob("*.py"))
  262. logger.info(f"目录 {SCRIPTS_BASE_PATH} 中的Python文件: {[f.name for f in files]}")
  263. else:
  264. logger.error(f"基础目录不存在: {SCRIPTS_BASE_PATH}")
  265. except Exception as e:
  266. logger.error(f"列出目录内容时出错: {str(e)}")
  267. return False, script_path_str
  268. def run_model_script(table_name, execution_mode):
  269. """
  270. 根据表名查找并执行对应的模型脚本
  271. 参数:
  272. table_name (str): 要处理的表名
  273. execution_mode (str): 执行模式 (append/full_refresh)
  274. 返回:
  275. bool: 执行成功返回True,否则返回False
  276. 抛出:
  277. AirflowFailException: 如果脚本不存在或执行失败
  278. """
  279. # 从Neo4j获取脚本名称
  280. script_name = get_script_name_from_neo4j(table_name)
  281. if not script_name:
  282. error_msg = f"未找到表 {table_name} 的脚本名称,任务失败"
  283. logger.error(error_msg)
  284. raise AirflowFailException(error_msg)
  285. logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
  286. # 检查脚本文件是否存在
  287. exists, script_path = check_script_exists(script_name)
  288. if not exists:
  289. error_msg = f"表 {table_name} 的脚本文件 {script_name} 不存在,任务失败"
  290. logger.error(error_msg)
  291. raise AirflowFailException(error_msg)
  292. # 执行脚本
  293. logger.info(f"开始执行脚本: {script_path}")
  294. try:
  295. # 动态导入模块
  296. import importlib.util
  297. import sys
  298. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  299. module = importlib.util.module_from_spec(spec)
  300. spec.loader.exec_module(module)
  301. # 检查并调用标准入口函数run
  302. if hasattr(module, "run"):
  303. logger.info(f"调用脚本 {script_name} 的标准入口函数 run()")
  304. module.run(table_name=table_name, execution_mode=execution_mode)
  305. logger.info(f"脚本 {script_name} 执行成功")
  306. return True
  307. else:
  308. error_msg = f"脚本 {script_name} 中未定义标准入口函数 run(),任务失败"
  309. logger.error(error_msg)
  310. raise AirflowFailException(error_msg)
  311. except AirflowFailException:
  312. # 直接重新抛出Airflow异常
  313. raise
  314. except Exception as e:
  315. error_msg = f"执行脚本 {script_name} 时出错: {str(e)}"
  316. logger.error(error_msg)
  317. import traceback
  318. logger.error(traceback.format_exc())
  319. raise AirflowFailException(error_msg)
  320. # 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
  321. # 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
  322. # def get_model_dependency_graph(table_names: list) -> dict:
  323. # graph = {}
  324. # uri = NEO4J_CONFIG['uri']
  325. # auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  326. # driver = GraphDatabase.driver(uri, auth=auth)
  327. # try:
  328. # with driver.session() as session:
  329. # for table_name in table_names:
  330. # query = """
  331. # MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
  332. # RETURN up.en_name AS upstream
  333. # """
  334. # result = session.run(query, table_name=table_name)
  335. # deps = [record['upstream'] for record in result if 'upstream' in record]
  336. # graph[table_name] = deps
  337. # finally:
  338. # driver.close()
  339. # return graph
  340. def get_model_dependency_graph(table_names: list) -> dict:
  341. """
  342. 使用networkx从Neo4j获取指定DataModel表之间的依赖关系图
  343. 参数:
  344. table_names: 表名列表
  345. 返回:
  346. dict: 依赖关系字典 {目标表: [上游依赖表1, 上游依赖表2, ...]}
  347. """
  348. # 创建有向图
  349. G = nx.DiGraph()
  350. # 添加所有节点
  351. for table_name in table_names:
  352. G.add_node(table_name)
  353. # 从Neo4j获取依赖关系并添加边
  354. uri = NEO4J_CONFIG['uri']
  355. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  356. driver = GraphDatabase.driver(uri, auth=auth)
  357. try:
  358. with driver.session() as session:
  359. for table_name in table_names:
  360. query = """
  361. MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
  362. WHERE up.en_name IN $all_tables
  363. RETURN up.en_name AS upstream
  364. """
  365. result = session.run(query, table_name=table_name, all_tables=table_names)
  366. deps = [record['upstream'] for record in result if 'upstream' in record]
  367. # 添加依赖边
  368. for dep in deps:
  369. G.add_edge(dep, table_name)
  370. finally:
  371. driver.close()
  372. # 检测循环依赖
  373. try:
  374. cycles = list(nx.simple_cycles(G))
  375. if cycles:
  376. logger.warning(f"检测到表间循环依赖: {cycles}")
  377. except Exception as e:
  378. logger.error(f"检查循环依赖失败: {str(e)}")
  379. # 转换为字典格式返回
  380. dependency_dict = {}
  381. for table_name in table_names:
  382. dependency_dict[table_name] = list(G.predecessors(table_name))
  383. return dependency_dict
  384. def generate_optimized_execution_order(table_names: list) -> list:
  385. """
  386. 生成优化的执行顺序,可处理循环依赖
  387. 参数:
  388. table_names: 表名列表
  389. 返回:
  390. list: 优化后的执行顺序列表
  391. """
  392. # 创建依赖图
  393. G = nx.DiGraph()
  394. # 添加所有节点
  395. for table_name in table_names:
  396. G.add_node(table_name)
  397. # 添加依赖边
  398. dependency_dict = get_model_dependency_graph(table_names)
  399. for target, upstreams in dependency_dict.items():
  400. for upstream in upstreams:
  401. G.add_edge(upstream, target)
  402. # 检测循环依赖
  403. cycles = list(nx.simple_cycles(G))
  404. if cycles:
  405. logger.warning(f"检测到循环依赖,将尝试打破循环: {cycles}")
  406. # 打破循环依赖(简单策略:移除每个循环中的一条边)
  407. for cycle in cycles:
  408. # 移除循环中的最后一条边
  409. G.remove_edge(cycle[-1], cycle[0])
  410. logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")
  411. # 生成拓扑排序
  412. try:
  413. execution_order = list(nx.topological_sort(G))
  414. return execution_order
  415. except Exception as e:
  416. logger.error(f"生成执行顺序失败: {str(e)}")
  417. # 返回原始列表作为备选
  418. return table_names
  419. def identify_common_paths(table_names: list) -> dict:
  420. """
  421. 识别多个表之间的公共执行路径
  422. 参数:
  423. table_names: 表名列表
  424. 返回:
  425. dict: 公共路径信息 {(path_tuple): 使用次数}
  426. """
  427. # 创建依赖图
  428. G = nx.DiGraph()
  429. # 添加所有节点和直接依赖边
  430. dependency_dict = get_model_dependency_graph(table_names)
  431. for target, upstreams in dependency_dict.items():
  432. G.add_node(target)
  433. for upstream in upstreams:
  434. G.add_node(upstream)
  435. G.add_edge(upstream, target)
  436. # 找出所有路径
  437. all_paths = []
  438. # 找出所有源节点(没有入边的节点)和终节点(没有出边的节点)
  439. sources = [n for n in G.nodes() if G.in_degree(n) == 0]
  440. targets = [n for n in G.nodes() if G.out_degree(n) == 0]
  441. # 获取所有源到目标的路径
  442. for source in sources:
  443. for target in targets:
  444. try:
  445. # 限制路径长度,避免组合爆炸
  446. paths = list(nx.all_simple_paths(G, source, target, cutoff=10))
  447. all_paths.extend(paths)
  448. except nx.NetworkXNoPath:
  449. continue
  450. # 统计路径段使用频率
  451. path_segments = {}
  452. for path in all_paths:
  453. # 只考虑长度>=2的路径段(至少有一条边)
  454. for i in range(len(path)-1):
  455. for j in range(i+2, min(i+6, len(path)+1)): # 限制段长,避免组合爆炸
  456. segment = tuple(path[i:j])
  457. if segment not in path_segments:
  458. path_segments[segment] = 0
  459. path_segments[segment] += 1
  460. # 过滤出重复使用的路径段
  461. common_paths = {seg: count for seg, count in path_segments.items()
  462. if count > 1 and len(seg) >= 3} # 至少3个节点,2条边
  463. # 按使用次数排序
  464. common_paths = dict(sorted(common_paths.items(), key=lambda x: x[1], reverse=True))
  465. return common_paths