utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # utils.py
  2. import psycopg2
  3. from neo4j import GraphDatabase
  4. from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
  5. import logging
  6. import importlib.util
  7. from pathlib import Path
  8. import networkx as nx
  9. import os
  10. # 创建统一的日志记录器
  11. logger = logging.getLogger("airflow.task")
  12. def get_pg_conn():
  13. return psycopg2.connect(**PG_CONFIG)
  14. def get_subscribed_tables(freq: str) -> list[dict]:
  15. """
  16. 根据调度频率获取启用的订阅表列表,附带 execution_mode 参数
  17. 返回结果示例:
  18. [
  19. {'table_name': 'region_sales', 'execution_mode': 'append'},
  20. {'table_name': 'catalog_sales', 'execution_mode': 'full_refresh'}
  21. ]
  22. """
  23. conn = get_pg_conn()
  24. cursor = conn.cursor()
  25. cursor.execute("""
  26. SELECT table_name, execution_mode
  27. FROM table_schedule
  28. WHERE is_enabled = TRUE AND schedule_frequency = %s
  29. """, (freq,))
  30. result = cursor.fetchall()
  31. cursor.close()
  32. conn.close()
  33. return [{"table_name": r[0], "execution_mode": r[1]} for r in result]
  34. def get_neo4j_dependencies(table_name: str) -> list:
  35. """
  36. 查询 Neo4j 中某个模型的 DERIVED_FROM 依赖(上游表名)
  37. """
  38. uri = NEO4J_CONFIG['uri']
  39. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  40. driver = GraphDatabase.driver(uri, auth=auth)
  41. query = """
  42. MATCH (a:Table {name: $name})<-[:DERIVED_FROM]-(b:Table)
  43. RETURN b.name
  44. """
  45. with driver.session() as session:
  46. records = session.run(query, name=table_name)
  47. return [record["b.name"] for record in records]
  48. # def get_script_name_from_neo4j(table_name: str) -> str:
  49. # """
  50. # 从Neo4j数据库中查询表对应的脚本名称
  51. # 查询的是 DataResource 和 DataSource 之间的 ORIGINATES_FROM 关系中的 script_name 属性
  52. # 参数:
  53. # table_name (str): 数据资源表名
  54. # 返回:
  55. # str: 脚本名称,如果未找到则返回None
  56. # """
  57. # logger = logging.getLogger("airflow.task")
  58. # driver = GraphDatabase.driver(**NEO4J_CONFIG)
  59. # query = """
  60. # MATCH (dr:DataResource {en_name: $table_name})-[rel:ORIGINATES_FROM]->(ds:DataSource)
  61. # RETURN rel.script_name AS script_name
  62. # """
  63. # try:
  64. # with driver.session() as session:
  65. # result = session.run(query, table_name=table_name)
  66. # record = result.single()
  67. # if record and 'script_name' in record:
  68. # return record['script_name']
  69. # else:
  70. # logger.warning(f"没有找到表 {table_name} 对应的脚本名称")
  71. # return None
  72. # except Exception as e:
  73. # logger.error(f"从Neo4j查询脚本名称时出错: {str(e)}")
  74. # return None
  75. # finally:
  76. # driver.close()
  77. def execute_script(script_name: str, table_name: str, execution_mode: str) -> bool:
  78. """
  79. 根据脚本名称动态导入并执行对应的脚本
  80. 返回:
  81. bool: 执行成功返回True,否则返回False
  82. """
  83. if not script_name:
  84. logger.error("未提供脚本名称,无法执行")
  85. return False
  86. try:
  87. # 直接使用配置的部署路径,不考虑本地开发路径
  88. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  89. logger.info(f"使用配置的Airflow部署路径: {script_path}")
  90. # 动态导入模块
  91. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  92. module = importlib.util.module_from_spec(spec)
  93. spec.loader.exec_module(module)
  94. # 使用标准入口函数run
  95. if hasattr(module, "run"):
  96. logger.info(f"执行脚本 {script_name} 的标准入口函数 run()")
  97. module.run(table_name=table_name, execution_mode=execution_mode)
  98. return True
  99. else:
  100. logger.warning(f"脚本 {script_name} 未定义标准入口函数 run(),无法执行")
  101. return False
  102. except Exception as e:
  103. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  104. return False
  105. # def get_enabled_tables(frequency: str) -> list:
  106. # conn = get_pg_conn()
  107. # cursor = conn.cursor()
  108. # cursor.execute("""
  109. # SELECT table_name, execution_mode
  110. # FROM table_schedule
  111. # WHERE is_enabled = TRUE AND schedule_frequency = %s
  112. # """, (frequency,))
  113. # result = cursor.fetchall()
  114. # cursor.close()
  115. # conn.close()
  116. # output = []
  117. # for r in result:
  118. # output.append({"table_name": r[0], "execution_mode": r[1]})
  119. # return output
  120. # def is_data_resource_table(table_name: str) -> bool:
  121. # driver = GraphDatabase.driver(NEO4J_CONFIG['uri'], auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password']))
  122. # query = """
  123. # MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  124. # """
  125. # try:
  126. # with driver.session() as session:
  127. # result = session.run(query, table_name=table_name)
  128. # record = result.single()
  129. # return record and record["exists"]
  130. # finally:
  131. # driver.close()
  132. def get_resource_subscribed_tables(enabled_tables: list) -> list:
  133. result = []
  134. for t in enabled_tables:
  135. if is_data_resource_table(t['table_name']):
  136. result.append(t)
  137. return result
  138. # 根据目标表,递归查找其所有上游依赖的 DataResource 表(不限层级)
  139. def get_dependency_resource_tables(enabled_tables: list) -> list:
  140. uri = NEO4J_CONFIG['uri']
  141. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  142. driver = GraphDatabase.driver(uri, auth=auth)
  143. resource_set = set()
  144. try:
  145. with driver.session() as session:
  146. for t in enabled_tables:
  147. query = """
  148. MATCH (target:Table {name: $table_name})
  149. MATCH (res:DataResource)-[:ORIGINATES_FROM]->(:DataSource)
  150. WHERE (target)-[:DERIVED_FROM*1..]->(res)
  151. RETURN DISTINCT res.en_name AS name
  152. """
  153. result = session.run(query, table_name=t['table_name'])
  154. for record in result:
  155. resource_set.add(record['name'])
  156. finally:
  157. driver.close()
  158. output = []
  159. for name in resource_set:
  160. output.append({"table_name": name, "execution_mode": "append"})
  161. return output
  162. # 从 PostgreSQL 获取启用的表,按调度频率 daily/weekly/monthly 过滤
  163. def get_enabled_tables(frequency: str) -> list:
  164. conn = get_pg_conn()
  165. cursor = conn.cursor()
  166. cursor.execute("""
  167. SELECT table_name, execution_mode
  168. FROM table_schedule
  169. WHERE is_enabled = TRUE AND schedule_frequency = %s
  170. """, (frequency,))
  171. result = cursor.fetchall()
  172. cursor.close()
  173. conn.close()
  174. output = []
  175. for r in result:
  176. output.append({"table_name": r[0], "execution_mode": r[1]})
  177. return output
  178. # 判断给定表名是否是 Neo4j 中的 DataResource 类型
  179. def is_data_resource_table(table_name: str) -> bool:
  180. uri = NEO4J_CONFIG['uri']
  181. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  182. driver = GraphDatabase.driver(uri, auth=auth)
  183. query = """
  184. MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  185. """
  186. try:
  187. with driver.session() as session:
  188. result = session.run(query, table_name=table_name)
  189. record = result.single()
  190. return record and record["exists"]
  191. finally:
  192. driver.close()
  193. # 从 Neo4j 查询 DataModel 表的 DERIVED_FROM 关系上的 script_name 属性
  194. def get_script_name_from_neo4j(table_name):
  195. uri = NEO4J_CONFIG['uri']
  196. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  197. driver = GraphDatabase.driver(uri, auth=auth)
  198. query = """
  199. MATCH (target:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(n)
  200. WHERE n:DataModel OR n:DataResource
  201. RETURN r.script_name AS script_name
  202. """
  203. try:
  204. with driver.session() as session:
  205. result = session.run(query, table_name=table_name)
  206. record = result.single()
  207. if record:
  208. try:
  209. script_name = record['script_name']
  210. return script_name
  211. except (KeyError, TypeError) as e:
  212. print(f"[WARN] 记录中不包含script_name字段: {e}")
  213. return None
  214. else:
  215. return None
  216. except Exception as e:
  217. print(f"[ERROR] 查询表 {table_name} 的脚本名称时出错: {str(e)}")
  218. return None
  219. finally:
  220. driver.close()
  221. # 判断给定表名是否是 Neo4j 中的 DataModel 类型
  222. def is_data_model_table(table_name):
  223. uri = NEO4J_CONFIG['uri']
  224. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  225. driver = GraphDatabase.driver(uri, auth=auth)
  226. query = """
  227. MATCH (n:DataModel {en_name: $table_name}) RETURN count(n) > 0 AS exists
  228. """
  229. try:
  230. with driver.session() as session:
  231. result = session.run(query, table_name=table_name)
  232. record = result.single()
  233. return record and record['exists']
  234. finally:
  235. driver.close()
  236. def check_script_exists(script_name):
  237. """
  238. 检查脚本文件是否存在于配置的脚本目录中
  239. 参数:
  240. script_name (str): 脚本文件名
  241. 返回:
  242. bool: 如果脚本存在返回True,否则返回False
  243. str: 完整的脚本路径
  244. """
  245. if not script_name:
  246. logger.error("脚本名称为空,无法检查")
  247. return False, None
  248. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  249. script_path_str = str(script_path)
  250. logger.info(f"检查脚本路径: {script_path_str}")
  251. if os.path.exists(script_path_str):
  252. logger.info(f"脚本文件已找到: {script_path_str}")
  253. return True, script_path_str
  254. else:
  255. logger.error(f"脚本文件不存在: {script_path_str}")
  256. # 尝试列出目录中的文件
  257. try:
  258. base_dir = Path(SCRIPTS_BASE_PATH)
  259. if base_dir.exists():
  260. files = list(base_dir.glob("*.py"))
  261. logger.info(f"目录 {SCRIPTS_BASE_PATH} 中的Python文件: {[f.name for f in files]}")
  262. else:
  263. logger.error(f"基础目录不存在: {SCRIPTS_BASE_PATH}")
  264. except Exception as e:
  265. logger.error(f"列出目录内容时出错: {str(e)}")
  266. return False, script_path_str
  267. def run_model_script(table_name, execution_mode):
  268. """
  269. 根据表名查找并执行对应的模型脚本
  270. 参数:
  271. table_name (str): 要处理的表名
  272. execution_mode (str): 执行模式 (append/full_refresh)
  273. 返回:
  274. bool: 执行成功返回True,否则返回False
  275. """
  276. # 从Neo4j获取脚本名称
  277. script_name = get_script_name_from_neo4j(table_name)
  278. if not script_name:
  279. logger.error(f"未找到表 {table_name} 的脚本名称,跳过处理")
  280. return False
  281. logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
  282. # 检查脚本文件是否存在
  283. exists, script_path = check_script_exists(script_name)
  284. if not exists:
  285. logger.error(f"表 {table_name} 的脚本文件 {script_name} 不存在,跳过处理")
  286. return False
  287. # 执行脚本
  288. logger.info(f"开始执行脚本: {script_path}")
  289. try:
  290. # 动态导入模块
  291. import importlib.util
  292. import sys
  293. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  294. module = importlib.util.module_from_spec(spec)
  295. spec.loader.exec_module(module)
  296. # 检查并调用标准入口函数run
  297. if hasattr(module, "run"):
  298. logger.info(f"调用脚本 {script_name} 的标准入口函数 run()")
  299. module.run(table_name=table_name, execution_mode=execution_mode)
  300. logger.info(f"脚本 {script_name} 执行成功")
  301. return True
  302. else:
  303. logger.error(f"脚本 {script_name} 中未定义标准入口函数 run(),无法执行")
  304. return False
  305. except Exception as e:
  306. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  307. import traceback
  308. logger.error(traceback.format_exc())
  309. return False
  310. # 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
  311. # 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
  312. # def get_model_dependency_graph(table_names: list) -> dict:
  313. # graph = {}
  314. # uri = NEO4J_CONFIG['uri']
  315. # auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  316. # driver = GraphDatabase.driver(uri, auth=auth)
  317. # try:
  318. # with driver.session() as session:
  319. # for table_name in table_names:
  320. # query = """
  321. # MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
  322. # RETURN up.en_name AS upstream
  323. # """
  324. # result = session.run(query, table_name=table_name)
  325. # deps = [record['upstream'] for record in result if 'upstream' in record]
  326. # graph[table_name] = deps
  327. # finally:
  328. # driver.close()
  329. # return graph
  330. def get_model_dependency_graph(table_names: list) -> dict:
  331. """
  332. 使用networkx从Neo4j获取指定DataModel表之间的依赖关系图
  333. 参数:
  334. table_names: 表名列表
  335. 返回:
  336. dict: 依赖关系字典 {目标表: [上游依赖表1, 上游依赖表2, ...]}
  337. """
  338. # 创建有向图
  339. G = nx.DiGraph()
  340. # 添加所有节点
  341. for table_name in table_names:
  342. G.add_node(table_name)
  343. # 从Neo4j获取依赖关系并添加边
  344. uri = NEO4J_CONFIG['uri']
  345. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  346. driver = GraphDatabase.driver(uri, auth=auth)
  347. try:
  348. with driver.session() as session:
  349. for table_name in table_names:
  350. query = """
  351. MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
  352. WHERE up.en_name IN $all_tables
  353. RETURN up.en_name AS upstream
  354. """
  355. result = session.run(query, table_name=table_name, all_tables=table_names)
  356. deps = [record['upstream'] for record in result if 'upstream' in record]
  357. # 添加依赖边
  358. for dep in deps:
  359. G.add_edge(dep, table_name)
  360. finally:
  361. driver.close()
  362. # 检测循环依赖
  363. try:
  364. cycles = list(nx.simple_cycles(G))
  365. if cycles:
  366. logger.warning(f"检测到表间循环依赖: {cycles}")
  367. except Exception as e:
  368. logger.error(f"检查循环依赖失败: {str(e)}")
  369. # 转换为字典格式返回
  370. dependency_dict = {}
  371. for table_name in table_names:
  372. dependency_dict[table_name] = list(G.predecessors(table_name))
  373. return dependency_dict
  374. def generate_optimized_execution_order(table_names: list) -> list:
  375. """
  376. 生成优化的执行顺序,可处理循环依赖
  377. 参数:
  378. table_names: 表名列表
  379. 返回:
  380. list: 优化后的执行顺序列表
  381. """
  382. # 创建依赖图
  383. G = nx.DiGraph()
  384. # 添加所有节点
  385. for table_name in table_names:
  386. G.add_node(table_name)
  387. # 添加依赖边
  388. dependency_dict = get_model_dependency_graph(table_names)
  389. for target, upstreams in dependency_dict.items():
  390. for upstream in upstreams:
  391. G.add_edge(upstream, target)
  392. # 检测循环依赖
  393. cycles = list(nx.simple_cycles(G))
  394. if cycles:
  395. logger.warning(f"检测到循环依赖,将尝试打破循环: {cycles}")
  396. # 打破循环依赖(简单策略:移除每个循环中的一条边)
  397. for cycle in cycles:
  398. # 移除循环中的最后一条边
  399. G.remove_edge(cycle[-1], cycle[0])
  400. logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")
  401. # 生成拓扑排序
  402. try:
  403. execution_order = list(nx.topological_sort(G))
  404. return execution_order
  405. except Exception as e:
  406. logger.error(f"生成执行顺序失败: {str(e)}")
  407. # 返回原始列表作为备选
  408. return table_names