|
@@ -7,6 +7,7 @@ import importlib.util
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
import networkx as nx
|
|
import networkx as nx
|
|
import os
|
|
import os
|
|
|
|
+from airflow.exceptions import AirflowFailException
|
|
|
|
|
|
# 创建统一的日志记录器
|
|
# 创建统一的日志记录器
|
|
logger = logging.getLogger("airflow.task")
|
|
logger = logging.getLogger("airflow.task")
|
|
@@ -310,20 +311,25 @@ def run_model_script(table_name, execution_mode):
|
|
|
|
|
|
返回:
|
|
返回:
|
|
bool: 执行成功返回True,否则返回False
|
|
bool: 执行成功返回True,否则返回False
|
|
|
|
+
|
|
|
|
+ 抛出:
|
|
|
|
+ AirflowFailException: 如果脚本不存在或执行失败
|
|
"""
|
|
"""
|
|
# 从Neo4j获取脚本名称
|
|
# 从Neo4j获取脚本名称
|
|
script_name = get_script_name_from_neo4j(table_name)
|
|
script_name = get_script_name_from_neo4j(table_name)
|
|
if not script_name:
|
|
if not script_name:
|
|
- logger.error(f"未找到表 {table_name} 的脚本名称,跳过处理")
|
|
|
|
- return False
|
|
|
|
|
|
+ error_msg = f"未找到表 {table_name} 的脚本名称,任务失败"
|
|
|
|
+ logger.error(error_msg)
|
|
|
|
+ raise AirflowFailException(error_msg)
|
|
|
|
|
|
logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
|
|
logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
|
|
|
|
|
|
# 检查脚本文件是否存在
|
|
# 检查脚本文件是否存在
|
|
exists, script_path = check_script_exists(script_name)
|
|
exists, script_path = check_script_exists(script_name)
|
|
if not exists:
|
|
if not exists:
|
|
- logger.error(f"表 {table_name} 的脚本文件 {script_name} 不存在,跳过处理")
|
|
|
|
- return False
|
|
|
|
|
|
+ error_msg = f"表 {table_name} 的脚本文件 {script_name} 不存在,任务失败"
|
|
|
|
+ logger.error(error_msg)
|
|
|
|
+ raise AirflowFailException(error_msg)
|
|
|
|
|
|
# 执行脚本
|
|
# 执行脚本
|
|
logger.info(f"开始执行脚本: {script_path}")
|
|
logger.info(f"开始执行脚本: {script_path}")
|
|
@@ -343,13 +349,18 @@ def run_model_script(table_name, execution_mode):
|
|
logger.info(f"脚本 {script_name} 执行成功")
|
|
logger.info(f"脚本 {script_name} 执行成功")
|
|
return True
|
|
return True
|
|
else:
|
|
else:
|
|
- logger.error(f"脚本 {script_name} 中未定义标准入口函数 run(),无法执行")
|
|
|
|
- return False
|
|
|
|
|
|
+ error_msg = f"脚本 {script_name} 中未定义标准入口函数 run(),任务失败"
|
|
|
|
+ logger.error(error_msg)
|
|
|
|
+ raise AirflowFailException(error_msg)
|
|
|
|
+ except AirflowFailException:
|
|
|
|
+ # 直接重新抛出Airflow异常
|
|
|
|
+ raise
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
|
|
|
|
|
|
+ error_msg = f"执行脚本 {script_name} 时出错: {str(e)}"
|
|
|
|
+ logger.error(error_msg)
|
|
import traceback
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
logger.error(traceback.format_exc())
|
|
- return False
|
|
|
|
|
|
+ raise AirflowFailException(error_msg)
|
|
|
|
|
|
# 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
|
|
# 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
|
|
# 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
|
|
# 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
|
|
@@ -467,3 +478,59 @@ def generate_optimized_execution_order(table_names: list) -> list:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def identify_common_paths(table_names: list) -> dict:
    """
    Identify execution-path segments that are shared between several tables.

    A directed graph is built from the mapping returned by
    ``get_model_dependency_graph(table_names)``, with one edge
    ``upstream -> target`` per dependency. Simple paths are then enumerated
    from every node without incoming edges to every node without outgoing
    edges, and each contiguous segment of those paths is counted.

    Args:
        table_names: Table names forwarded to ``get_model_dependency_graph``.

    Returns:
        dict: ``{segment_tuple: occurrence_count}`` containing only segments
        of 3 to 5 nodes that occur more than once, ordered by count in
        descending order (ties keep the order in which they were first seen).
    """
    from collections import Counter

    # Depth limit handed to networkx as ``cutoff`` to avoid a combinatorial
    # explosion of enumerated paths.
    max_path_depth = 10
    # Only segments with at least 3 nodes (2 edges) are reported, and segment
    # size is capped to keep the number of counted segments bounded.
    min_segment_nodes = 3
    max_segment_nodes = 5

    # Build the dependency graph. ``add_edge`` creates missing endpoints, so
    # only the target needs an explicit ``add_node`` (it may have no upstreams).
    graph = nx.DiGraph()
    dependency_dict = get_model_dependency_graph(table_names)
    for target, upstreams in dependency_dict.items():
        graph.add_node(target)
        for upstream in upstreams:
            graph.add_edge(upstream, target)

    # Sources have no incoming edges; sinks have no outgoing edges.
    sources = [node for node in graph.nodes() if graph.in_degree(node) == 0]
    sinks = [node for node in graph.nodes() if graph.out_degree(node) == 0]

    # Count every contiguous segment of the allowed size along every
    # source -> sink path. ``all_simple_paths`` simply yields nothing when no
    # path exists, so no exception handling is required here.
    segment_counts = Counter()
    for source in sources:
        for sink in sinks:
            for path in nx.all_simple_paths(
                graph, source, sink, cutoff=max_path_depth
            ):
                path_len = len(path)
                for start in range(path_len - min_segment_nodes + 1):
                    last_end = min(start + max_segment_nodes, path_len)
                    for end in range(start + min_segment_nodes, last_end + 1):
                        segment_counts[tuple(path[start:end])] += 1

    # ``most_common`` sorts by count descending with a stable sort, so equally
    # frequent segments stay in first-seen order. Keep only reused segments.
    return {
        segment: count
        for segment, count in segment_counts.most_common()
        if count > 1
    }
|