|
@@ -5,9 +5,12 @@ from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
|
|
|
import logging
|
|
|
import importlib.util
|
|
|
from pathlib import Path
|
|
|
-import sys
|
|
|
+import networkx as nx
|
|
|
import os
|
|
|
|
|
|
+# 创建统一的日志记录器
|
|
|
+logger = logging.getLogger("airflow.task")
|
|
|
+
|
|
|
def get_pg_conn():
|
|
|
return psycopg2.connect(**PG_CONFIG)
|
|
|
|
|
@@ -87,8 +90,6 @@ def execute_script(script_name: str, table_name: str, execution_mode: str) -> bo
|
|
|
返回:
|
|
|
bool: 执行成功返回True,否则返回False
|
|
|
"""
|
|
|
- logger = logging.getLogger("airflow.task")
|
|
|
-
|
|
|
if not script_name:
|
|
|
logger.error("未提供脚本名称,无法执行")
|
|
|
return False
|
|
@@ -260,7 +261,6 @@ def is_data_model_table(table_name):
|
|
|
finally:
|
|
|
driver.close()
|
|
|
|
|
|
-# 检查脚本文件是否存在于指定路径
|
|
|
def check_script_exists(script_name):
|
|
|
"""
|
|
|
检查脚本文件是否存在于配置的脚本目录中
|
|
@@ -272,12 +272,6 @@ def check_script_exists(script_name):
|
|
|
bool: 如果脚本存在返回True,否则返回False
|
|
|
str: 完整的脚本路径
|
|
|
"""
|
|
|
- from pathlib import Path
|
|
|
- import os
|
|
|
- import logging
|
|
|
-
|
|
|
- logger = logging.getLogger("airflow.task")
|
|
|
-
|
|
|
if not script_name:
|
|
|
logger.error("脚本名称为空,无法检查")
|
|
|
return False, None
|
|
@@ -306,7 +300,6 @@ def check_script_exists(script_name):
|
|
|
|
|
|
return False, script_path_str
|
|
|
|
|
|
-# 更新run_model_script函数以使用上述检查
|
|
|
def run_model_script(table_name, execution_mode):
|
|
|
"""
|
|
|
根据表名查找并执行对应的模型脚本
|
|
@@ -318,9 +311,6 @@ def run_model_script(table_name, execution_mode):
|
|
|
返回:
|
|
|
bool: 执行成功返回True,否则返回False
|
|
|
"""
|
|
|
- import logging
|
|
|
- logger = logging.getLogger("airflow.task")
|
|
|
-
|
|
|
# 从Neo4j获取脚本名称
|
|
|
script_name = get_script_name_from_neo4j(table_name)
|
|
|
if not script_name:
|
|
@@ -363,8 +353,40 @@ def run_model_script(table_name, execution_mode):
|
|
|
|
|
|
# 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
|
|
|
# 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
|
|
|
+# def get_model_dependency_graph(table_names: list) -> dict:
|
|
|
+# graph = {}
|
|
|
+# uri = NEO4J_CONFIG['uri']
|
|
|
+# auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
|
|
|
+# driver = GraphDatabase.driver(uri, auth=auth)
|
|
|
+# try:
|
|
|
+# with driver.session() as session:
|
|
|
+# for table_name in table_names:
|
|
|
+# query = """
|
|
|
+# MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
|
|
|
+# RETURN up.en_name AS upstream
|
|
|
+# """
|
|
|
+# result = session.run(query, table_name=table_name)
|
|
|
+# deps = [record['upstream'] for record in result if 'upstream' in record]
|
|
|
+# graph[table_name] = deps
|
|
|
+# finally:
|
|
|
+# driver.close()
|
|
|
+# return graph
|
|
|
def get_model_dependency_graph(table_names: list) -> dict:
|
|
|
- graph = {}
|
|
|
+ """
|
|
|
+ 使用networkx从Neo4j获取指定DataModel表之间的依赖关系图
|
|
|
+ 参数:
|
|
|
+ table_names: 表名列表
|
|
|
+ 返回:
|
|
|
+ dict: 依赖关系字典 {目标表: [上游依赖表1, 上游依赖表2, ...]}
|
|
|
+ """
|
|
|
+ # 创建有向图
|
|
|
+ G = nx.DiGraph()
|
|
|
+
|
|
|
+ # 添加所有节点
|
|
|
+ for table_name in table_names:
|
|
|
+ G.add_node(table_name)
|
|
|
+
|
|
|
+ # 从Neo4j获取依赖关系并添加边
|
|
|
uri = NEO4J_CONFIG['uri']
|
|
|
auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
|
|
|
driver = GraphDatabase.driver(uri, auth=auth)
|
|
@@ -373,11 +395,75 @@ def get_model_dependency_graph(table_names: list) -> dict:
|
|
|
for table_name in table_names:
|
|
|
query = """
|
|
|
MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
|
|
|
+ WHERE up.en_name IN $all_tables
|
|
|
RETURN up.en_name AS upstream
|
|
|
"""
|
|
|
- result = session.run(query, table_name=table_name)
|
|
|
+ result = session.run(query, table_name=table_name, all_tables=table_names)
|
|
|
deps = [record['upstream'] for record in result if 'upstream' in record]
|
|
|
- graph[table_name] = deps
|
|
|
+
|
|
|
+ # 添加依赖边
|
|
|
+ for dep in deps:
|
|
|
+ G.add_edge(dep, table_name)
|
|
|
finally:
|
|
|
driver.close()
|
|
|
- return graph
|
|
|
+
|
|
|
+ # 检测循环依赖
|
|
|
+ try:
|
|
|
+ cycles = list(nx.simple_cycles(G))
|
|
|
+ if cycles:
|
|
|
+ logger.warning(f"检测到表间循环依赖: {cycles}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"检查循环依赖失败: {str(e)}")
|
|
|
+
|
|
|
+ # 转换为字典格式返回
|
|
|
+ dependency_dict = {}
|
|
|
+ for table_name in table_names:
|
|
|
+ dependency_dict[table_name] = list(G.predecessors(table_name))
|
|
|
+
|
|
|
+ return dependency_dict
|
|
|
+
|
|
|
+
|
|
|
+def generate_optimized_execution_order(table_names: list) -> list:
|
|
|
+ """
|
|
|
+ 生成优化的执行顺序,可处理循环依赖
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ table_names: 表名列表
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ list: 优化后的执行顺序列表
|
|
|
+ """
|
|
|
+ # 创建依赖图
|
|
|
+ G = nx.DiGraph()
|
|
|
+
|
|
|
+ # 添加所有节点
|
|
|
+ for table_name in table_names:
|
|
|
+ G.add_node(table_name)
|
|
|
+
|
|
|
+ # 添加依赖边
|
|
|
+ dependency_dict = get_model_dependency_graph(table_names)
|
|
|
+ for target, upstreams in dependency_dict.items():
|
|
|
+ for upstream in upstreams:
|
|
|
+ G.add_edge(upstream, target)
|
|
|
+
|
|
|
+ # 检测循环依赖
|
|
|
+ cycles = list(nx.simple_cycles(G))
|
|
|
+ if cycles:
|
|
|
+ logger.warning(f"检测到循环依赖,将尝试打破循环: {cycles}")
|
|
|
+ # 打破循环依赖(简单策略:移除每个循环中的一条边)
|
|
|
+ for cycle in cycles:
|
|
|
+ # 移除循环中的最后一条边
|
|
|
+ G.remove_edge(cycle[-1], cycle[0])
|
|
|
+ logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")
|
|
|
+
|
|
|
+ # 生成拓扑排序
|
|
|
+ try:
|
|
|
+ execution_order = list(nx.topological_sort(G))
|
|
|
+ return execution_order
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"生成执行顺序失败: {str(e)}")
|
|
|
+ # 返回原始列表作为备选
|
|
|
+ return table_names
|
|
|
+
|
|
|
+
|
|
|
+
|