utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # utils.py
  2. import psycopg2
  3. from neo4j import GraphDatabase
  4. from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
  5. import logging
  6. import importlib.util
  7. from pathlib import Path
  8. import sys
  9. import os
  10. def get_pg_conn():
  11. return psycopg2.connect(**PG_CONFIG)
  12. def get_subscribed_tables(freq: str) -> list[dict]:
  13. """
  14. 根据调度频率获取启用的订阅表列表,附带 execution_mode 参数
  15. 返回结果示例:
  16. [
  17. {'table_name': 'region_sales', 'execution_mode': 'append'},
  18. {'table_name': 'catalog_sales', 'execution_mode': 'full_refresh'}
  19. ]
  20. """
  21. conn = get_pg_conn()
  22. cursor = conn.cursor()
  23. cursor.execute("""
  24. SELECT table_name, execution_mode
  25. FROM table_schedule
  26. WHERE is_enabled = TRUE AND schedule_frequency = %s
  27. """, (freq,))
  28. result = cursor.fetchall()
  29. cursor.close()
  30. conn.close()
  31. return [{"table_name": r[0], "execution_mode": r[1]} for r in result]
  32. def get_neo4j_dependencies(table_name: str) -> list:
  33. """
  34. 查询 Neo4j 中某个模型的 DERIVED_FROM 依赖(上游表名)
  35. """
  36. uri = NEO4J_CONFIG['uri']
  37. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  38. driver = GraphDatabase.driver(uri, auth=auth)
  39. query = """
  40. MATCH (a:Table {name: $name})<-[:DERIVED_FROM]-(b:Table)
  41. RETURN b.name
  42. """
  43. with driver.session() as session:
  44. records = session.run(query, name=table_name)
  45. return [record["b.name"] for record in records]
  46. # def get_script_name_from_neo4j(table_name: str) -> str:
  47. # """
  48. # 从Neo4j数据库中查询表对应的脚本名称
  49. # 查询的是 DataResource 和 DataSource 之间的 ORIGINATES_FROM 关系中的 script_name 属性
  50. # 参数:
  51. # table_name (str): 数据资源表名
  52. # 返回:
  53. # str: 脚本名称,如果未找到则返回None
  54. # """
  55. # logger = logging.getLogger("airflow.task")
  56. # driver = GraphDatabase.driver(**NEO4J_CONFIG)
  57. # query = """
  58. # MATCH (dr:DataResource {en_name: $table_name})-[rel:ORIGINATES_FROM]->(ds:DataSource)
  59. # RETURN rel.script_name AS script_name
  60. # """
  61. # try:
  62. # with driver.session() as session:
  63. # result = session.run(query, table_name=table_name)
  64. # record = result.single()
  65. # if record and 'script_name' in record:
  66. # return record['script_name']
  67. # else:
  68. # logger.warning(f"没有找到表 {table_name} 对应的脚本名称")
  69. # return None
  70. # except Exception as e:
  71. # logger.error(f"从Neo4j查询脚本名称时出错: {str(e)}")
  72. # return None
  73. # finally:
  74. # driver.close()
  75. def execute_script(script_name: str, table_name: str, execution_mode: str) -> bool:
  76. """
  77. 根据脚本名称动态导入并执行对应的脚本
  78. 返回:
  79. bool: 执行成功返回True,否则返回False
  80. """
  81. logger = logging.getLogger("airflow.task")
  82. if not script_name:
  83. logger.error("未提供脚本名称,无法执行")
  84. return False
  85. try:
  86. # 直接使用配置的部署路径,不考虑本地开发路径
  87. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  88. logger.info(f"使用配置的Airflow部署路径: {script_path}")
  89. # 动态导入模块
  90. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  91. module = importlib.util.module_from_spec(spec)
  92. spec.loader.exec_module(module)
  93. # 使用标准入口函数run
  94. if hasattr(module, "run"):
  95. logger.info(f"执行脚本 {script_name} 的标准入口函数 run()")
  96. module.run(table_name=table_name, execution_mode=execution_mode)
  97. return True
  98. else:
  99. logger.warning(f"脚本 {script_name} 未定义标准入口函数 run(),无法执行")
  100. return False
  101. except Exception as e:
  102. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  103. return False
  104. # def get_enabled_tables(frequency: str) -> list:
  105. # conn = get_pg_conn()
  106. # cursor = conn.cursor()
  107. # cursor.execute("""
  108. # SELECT table_name, execution_mode
  109. # FROM table_schedule
  110. # WHERE is_enabled = TRUE AND schedule_frequency = %s
  111. # """, (frequency,))
  112. # result = cursor.fetchall()
  113. # cursor.close()
  114. # conn.close()
  115. # output = []
  116. # for r in result:
  117. # output.append({"table_name": r[0], "execution_mode": r[1]})
  118. # return output
  119. # def is_data_resource_table(table_name: str) -> bool:
  120. # driver = GraphDatabase.driver(NEO4J_CONFIG['uri'], auth=(NEO4J_CONFIG['user'], NEO4J_CONFIG['password']))
  121. # query = """
  122. # MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  123. # """
  124. # try:
  125. # with driver.session() as session:
  126. # result = session.run(query, table_name=table_name)
  127. # record = result.single()
  128. # return record and record["exists"]
  129. # finally:
  130. # driver.close()
  131. def get_resource_subscribed_tables(enabled_tables: list) -> list:
  132. result = []
  133. for t in enabled_tables:
  134. if is_data_resource_table(t['table_name']):
  135. result.append(t)
  136. return result
  137. # 根据目标表,递归查找其所有上游依赖的 DataResource 表(不限层级)
  138. def get_dependency_resource_tables(enabled_tables: list) -> list:
  139. uri = NEO4J_CONFIG['uri']
  140. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  141. driver = GraphDatabase.driver(uri, auth=auth)
  142. resource_set = set()
  143. try:
  144. with driver.session() as session:
  145. for t in enabled_tables:
  146. query = """
  147. MATCH (target:Table {name: $table_name})
  148. MATCH (res:DataResource)-[:ORIGINATES_FROM]->(:DataSource)
  149. WHERE (target)-[:DERIVED_FROM*1..]->(res)
  150. RETURN DISTINCT res.en_name AS name
  151. """
  152. result = session.run(query, table_name=t['table_name'])
  153. for record in result:
  154. resource_set.add(record['name'])
  155. finally:
  156. driver.close()
  157. output = []
  158. for name in resource_set:
  159. output.append({"table_name": name, "execution_mode": "append"})
  160. return output
  161. # 从 PostgreSQL 获取启用的表,按调度频率 daily/weekly/monthly 过滤
  162. def get_enabled_tables(frequency: str) -> list:
  163. conn = get_pg_conn()
  164. cursor = conn.cursor()
  165. cursor.execute("""
  166. SELECT table_name, execution_mode
  167. FROM table_schedule
  168. WHERE is_enabled = TRUE AND schedule_frequency = %s
  169. """, (frequency,))
  170. result = cursor.fetchall()
  171. cursor.close()
  172. conn.close()
  173. output = []
  174. for r in result:
  175. output.append({"table_name": r[0], "execution_mode": r[1]})
  176. return output
  177. # 判断给定表名是否是 Neo4j 中的 DataResource 类型
  178. def is_data_resource_table(table_name: str) -> bool:
  179. uri = NEO4J_CONFIG['uri']
  180. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  181. driver = GraphDatabase.driver(uri, auth=auth)
  182. query = """
  183. MATCH (n:DataResource {en_name: $table_name}) RETURN count(n) > 0 AS exists
  184. """
  185. try:
  186. with driver.session() as session:
  187. result = session.run(query, table_name=table_name)
  188. record = result.single()
  189. return record and record["exists"]
  190. finally:
  191. driver.close()
  192. # 从 Neo4j 查询 DataModel 表的 DERIVED_FROM 关系上的 script_name 属性
  193. def get_script_name_from_neo4j(table_name):
  194. uri = NEO4J_CONFIG['uri']
  195. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  196. driver = GraphDatabase.driver(uri, auth=auth)
  197. query = """
  198. MATCH (target:DataModel {en_name: $table_name})-[r:DERIVED_FROM]->(n)
  199. WHERE n:DataModel OR n:DataResource
  200. RETURN r.script_name AS script_name
  201. """
  202. try:
  203. with driver.session() as session:
  204. result = session.run(query, table_name=table_name)
  205. record = result.single()
  206. if record:
  207. try:
  208. script_name = record['script_name']
  209. return script_name
  210. except (KeyError, TypeError) as e:
  211. print(f"[WARN] 记录中不包含script_name字段: {e}")
  212. return None
  213. else:
  214. return None
  215. except Exception as e:
  216. print(f"[ERROR] 查询表 {table_name} 的脚本名称时出错: {str(e)}")
  217. return None
  218. finally:
  219. driver.close()
  220. # 判断给定表名是否是 Neo4j 中的 DataModel 类型
  221. def is_data_model_table(table_name):
  222. uri = NEO4J_CONFIG['uri']
  223. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  224. driver = GraphDatabase.driver(uri, auth=auth)
  225. query = """
  226. MATCH (n:DataModel {en_name: $table_name}) RETURN count(n) > 0 AS exists
  227. """
  228. try:
  229. with driver.session() as session:
  230. result = session.run(query, table_name=table_name)
  231. record = result.single()
  232. return record and record['exists']
  233. finally:
  234. driver.close()
  235. # 检查脚本文件是否存在于指定路径
  236. def check_script_exists(script_name):
  237. """
  238. 检查脚本文件是否存在于配置的脚本目录中
  239. 参数:
  240. script_name (str): 脚本文件名
  241. 返回:
  242. bool: 如果脚本存在返回True,否则返回False
  243. str: 完整的脚本路径
  244. """
  245. from pathlib import Path
  246. import os
  247. import logging
  248. logger = logging.getLogger("airflow.task")
  249. if not script_name:
  250. logger.error("脚本名称为空,无法检查")
  251. return False, None
  252. script_path = Path(SCRIPTS_BASE_PATH) / script_name
  253. script_path_str = str(script_path)
  254. logger.info(f"检查脚本路径: {script_path_str}")
  255. if os.path.exists(script_path_str):
  256. logger.info(f"脚本文件已找到: {script_path_str}")
  257. return True, script_path_str
  258. else:
  259. logger.error(f"脚本文件不存在: {script_path_str}")
  260. # 尝试列出目录中的文件
  261. try:
  262. base_dir = Path(SCRIPTS_BASE_PATH)
  263. if base_dir.exists():
  264. files = list(base_dir.glob("*.py"))
  265. logger.info(f"目录 {SCRIPTS_BASE_PATH} 中的Python文件: {[f.name for f in files]}")
  266. else:
  267. logger.error(f"基础目录不存在: {SCRIPTS_BASE_PATH}")
  268. except Exception as e:
  269. logger.error(f"列出目录内容时出错: {str(e)}")
  270. return False, script_path_str
  271. # 更新run_model_script函数以使用上述检查
  272. def run_model_script(table_name, execution_mode):
  273. """
  274. 根据表名查找并执行对应的模型脚本
  275. 参数:
  276. table_name (str): 要处理的表名
  277. execution_mode (str): 执行模式 (append/full_refresh)
  278. 返回:
  279. bool: 执行成功返回True,否则返回False
  280. """
  281. import logging
  282. logger = logging.getLogger("airflow.task")
  283. # 从Neo4j获取脚本名称
  284. script_name = get_script_name_from_neo4j(table_name)
  285. if not script_name:
  286. logger.error(f"未找到表 {table_name} 的脚本名称,跳过处理")
  287. return False
  288. logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
  289. # 检查脚本文件是否存在
  290. exists, script_path = check_script_exists(script_name)
  291. if not exists:
  292. logger.error(f"表 {table_name} 的脚本文件 {script_name} 不存在,跳过处理")
  293. return False
  294. # 执行脚本
  295. logger.info(f"开始执行脚本: {script_path}")
  296. try:
  297. # 动态导入模块
  298. import importlib.util
  299. import sys
  300. spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
  301. module = importlib.util.module_from_spec(spec)
  302. spec.loader.exec_module(module)
  303. # 检查并调用标准入口函数run
  304. if hasattr(module, "run"):
  305. logger.info(f"调用脚本 {script_name} 的标准入口函数 run()")
  306. module.run(table_name=table_name, execution_mode=execution_mode)
  307. logger.info(f"脚本 {script_name} 执行成功")
  308. return True
  309. else:
  310. logger.error(f"脚本 {script_name} 中未定义标准入口函数 run(),无法执行")
  311. return False
  312. except Exception as e:
  313. logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
  314. import traceback
  315. logger.error(traceback.format_exc())
  316. return False
  317. # 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
  318. # 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
  319. def get_model_dependency_graph(table_names: list) -> dict:
  320. graph = {}
  321. uri = NEO4J_CONFIG['uri']
  322. auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
  323. driver = GraphDatabase.driver(uri, auth=auth)
  324. try:
  325. with driver.session() as session:
  326. for table_name in table_names:
  327. query = """
  328. MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
  329. RETURN up.en_name AS upstream
  330. """
  331. result = session.run(query, table_name=table_name)
  332. deps = [record['upstream'] for record in result if 'upstream' in record]
  333. graph[table_name] = deps
  334. finally:
  335. driver.close()
  336. return graph