Pārlūkot izejas kodu

完成对utils.py/dag_data_resource.py/daily的优化,resource添加了并行度控制

wangxq 1 mēnesi atpakaļ
vecāks
revīzija
721fbed27d
5 mainīti faili ar 241 papildinājumiem un 62 dzēšanām
  1. 3 0
      dags/config.py
  2. 69 17
      dags/dag_data_model_daily.py
  3. 63 26
      dags/dag_data_resource.py
  4. 104 18
      dags/utils.py
  5. 2 1
      requirements.txt

+ 3 - 0
dags/config.py

@@ -29,3 +29,6 @@ SCRIPTS_BASE_PATH = "/opt/airflow/dataops/scripts"
 
 # 本地开发环境脚本路径(如果需要区分环境)
 # LOCAL_SCRIPTS_BASE_PATH = "/path/to/local/scripts"
+
+# 资源表加载并行度
+RESOURCE_LOADING_PARALLEL_DEGREE = 4  # 可根据环境调整

+ 69 - 17
dags/dag_data_model_daily.py

@@ -8,10 +8,52 @@ from utils import get_enabled_tables, is_data_model_table, run_model_script, get
 from config import NEO4J_CONFIG
 import pendulum
 import logging
+import networkx as nx
 
 # 创建日志记录器
 logger = logging.getLogger(__name__)
 
def generate_optimized_execution_order(table_names: list) -> list:
    """
    Generate an optimized execution order for the given model tables,
    tolerating circular dependencies by breaking cycles before sorting.

    Parameters:
        table_names: list of table names to order.

    Returns:
        list: table names in a dependency-respecting (topological) order;
        falls back to the input order if the sort unexpectedly fails.
    """
    # Build the dependency graph
    G = nx.DiGraph()

    # Add every table as a node so isolated tables still appear in the order
    for table_name in table_names:
        G.add_node(table_name)

    # Add dependency edges (upstream -> target)
    dependency_dict = get_model_dependency_graph(table_names)
    for target, upstreams in dependency_dict.items():
        for upstream in upstreams:
            if upstream in table_names:  # only consider tables within the target set
                G.add_edge(upstream, target)

    # Detect circular dependencies
    cycles = list(nx.simple_cycles(G))
    if cycles:
        logger.warning(f"检测到循环依赖,将尝试打破循环: {cycles}")
        # Break cycles (simple strategy: drop one edge per cycle)
        for cycle in cycles:
            # Cycles may share edges: removing an edge already removed for an
            # earlier cycle would raise networkx.NetworkXError, so guard first.
            # If the edge is already gone, this cycle is already broken.
            if G.has_edge(cycle[-1], cycle[0]):
                G.remove_edge(cycle[-1], cycle[0])
                logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")

    # Produce the topological order
    try:
        execution_order = list(nx.topological_sort(G))
        return execution_order
    except Exception as e:
        logger.error(f"生成执行顺序失败: {str(e)}")
        # Fall back to the original list order
        return table_names
+
 with DAG("dag_data_model_daily", start_date=datetime(2024, 1, 1), schedule_interval="@daily", catchup=False) as dag:
     logger.info("初始化 dag_data_model_daily DAG")
     
@@ -44,9 +86,15 @@ with DAG("dag_data_model_daily", start_date=datetime(2024, 1, 1), schedule_inter
             logger.info("没有找到需要处理的模型表,DAG将直接标记为完成")
             wait_for_resource >> daily_completed
         else:
-            # 获取依赖图
+            # 获取表名列表
+            table_names = [t['table_name'] for t in model_tables]
+            
+            # 使用优化函数生成执行顺序,可以处理循环依赖
+            optimized_table_order = generate_optimized_execution_order(table_names)
+            logger.info(f"生成优化执行顺序, 共 {len(optimized_table_order)} 个表")
+            
+            # 获取依赖图 (仍然需要用于设置任务依赖关系)
             try:
-                table_names = [t['table_name'] for t in model_tables]
                 dependency_graph = get_model_dependency_graph(table_names)
                 logger.info(f"构建了 {len(dependency_graph)} 个表的依赖关系图")
             except Exception as e:
@@ -57,20 +105,23 @@ with DAG("dag_data_model_daily", start_date=datetime(2024, 1, 1), schedule_inter
 
             # 构建 task 对象
             task_dict = {}
-            for item in model_tables:
-                try:
-                    task = PythonOperator(
-                        task_id=f"process_model_{item['table_name']}",
-                        python_callable=run_model_script,
-                        op_kwargs={"table_name": item['table_name'], "execution_mode": item['execution_mode']},
-                    )
-                    task_dict[item['table_name']] = task
-                    logger.info(f"创建模型处理任务: process_model_{item['table_name']}")
-                except Exception as e:
-                    logger.error(f"创建任务 process_model_{item['table_name']} 时出错: {str(e)}")
-                    # 出错时也要确保完成标记被触发
-                    wait_for_resource >> daily_completed
-                    raise
+            for table_name in optimized_table_order:
+                # 获取表的配置信息
+                table_config = next((t for t in model_tables if t['table_name'] == table_name), None)
+                if table_config:
+                    try:
+                        task = PythonOperator(
+                            task_id=f"process_model_{table_name}",
+                            python_callable=run_model_script,
+                            op_kwargs={"table_name": table_name, "execution_mode": table_config['execution_mode']},
+                        )
+                        task_dict[table_name] = task
+                        logger.info(f"创建模型处理任务: process_model_{table_name}")
+                    except Exception as e:
+                        logger.error(f"创建任务 process_model_{table_name} 时出错: {str(e)}")
+                        # 出错时也要确保完成标记被触发
+                        wait_for_resource >> daily_completed
+                        raise
 
             # 建立任务依赖(基于 DERIVED_FROM 图)
             dependency_count = 0
@@ -94,7 +145,8 @@ with DAG("dag_data_model_daily", start_date=datetime(2024, 1, 1), schedule_inter
             if top_level_tasks:
                 logger.info(f"发现 {len(top_level_tasks)} 个顶层任务: {', '.join(top_level_tasks)}")
                 for name in top_level_tasks:
-                    wait_for_resource >> task_dict[name]
+                    if name in task_dict:
+                        wait_for_resource >> task_dict[name]
             else:
                 logger.warning("没有找到顶层任务,请检查依赖关系图是否正确")
                 # 如果没有顶层任务,直接将等待任务与完成标记相连接

+ 63 - 26
dags/dag_data_resource.py

@@ -4,16 +4,20 @@ from datetime import datetime
 from utils import (
     get_enabled_tables,
     get_resource_subscribed_tables,
-    get_dependency_resource_tables
+    get_dependency_resource_tables,
+    check_script_exists
 )
 import pendulum
 import logging
 import sys
+from airflow.operators.empty import EmptyOperator
+from config import NEO4J_CONFIG, SCRIPTS_BASE_PATH, RESOURCE_LOADING_PARALLEL_DEGREE
+from neo4j import GraphDatabase
 
 # 创建日志记录器
 logger = logging.getLogger(__name__)
 
-def get_script_name_from_neo4j(table_name):
+def get_resource_script_name_from_neo4j(table_name):
     from neo4j import GraphDatabase
     from config import NEO4J_CONFIG
     
@@ -52,34 +56,44 @@ def load_table_data(table_name, execution_mode):
     import os
     import importlib.util
 
-    script_name = get_script_name_from_neo4j(table_name)
+    script_name = get_resource_script_name_from_neo4j(table_name)
     if not script_name:
         logger.warning(f"未找到表 {table_name} 的 script_name,跳过")
         return
-
-    # scripts_base_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataops", "scripts")
-    # script_path = os.path.join(scripts_base_path, script_name)
-    # 使用配置文件中的绝对路径
-    from pathlib import Path
-    from config import SCRIPTS_BASE_PATH
-    script_path = Path(SCRIPTS_BASE_PATH) / script_name
-
-    if not os.path.exists(script_path):
-        logger.error(f"脚本文件不存在: {script_path}")
-        return
-
-    logger.info(f"执行脚本: {script_path}")
+    
+    logger.info(f"从Neo4j获取到表 {table_name} 的脚本名称: {script_name}")
+    
+    # 检查脚本文件是否存在
+    exists, script_path = check_script_exists(script_name)
+    if not exists:
+        logger.error(f"表 {table_name} 的脚本文件 {script_name} 不存在,跳过处理")
+        return False
+    
+    # 执行脚本
+    logger.info(f"开始执行脚本: {script_path}")
     try:
-        spec = importlib.util.spec_from_file_location("dynamic_script", script_path)
+        # 动态导入模块
+        import importlib.util
+        import sys
+        
+        spec = importlib.util.spec_from_file_location("dynamic_module", script_path)
         module = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(module)
-
+        
+        # 检查并调用标准入口函数run
         if hasattr(module, "run"):
+            logger.info(f"调用脚本 {script_name} 的标准入口函数 run()")
             module.run(table_name=table_name, execution_mode=execution_mode)
+            logger.info(f"脚本 {script_name} 执行成功")
+            return True
         else:
-            logger.warning(f"脚本 {script_name} 中未定义 run(...) 方法,跳过")
+            logger.error(f"脚本 {script_name} 中未定义标准入口函数 run(),无法执行")
+            return False
     except Exception as e:
         logger.error(f"执行脚本 {script_name} 时出错: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
+        return False
 
 with DAG("dag_data_resource", start_date=datetime(2024, 1, 1), schedule_interval="@daily", catchup=False) as dag:
     today = pendulum.today()
@@ -113,10 +127,33 @@ with DAG("dag_data_resource", start_date=datetime(2024, 1, 1), schedule_interval
             unique_resources[name] = item
 
     resource_tables = list(unique_resources.values())
-
-    for item in resource_tables:
-        PythonOperator(
-            task_id=f"load_{item['table_name']}",
-            python_callable=load_table_data,
-            op_kwargs={"table_name": item['table_name'], "execution_mode": item['execution_mode']},
-        )
+    
+    # 创建开始任务
+    start_loading = EmptyOperator(task_id="start_resource_loading")
+    
+    # 创建结束任务
+    end_loading = EmptyOperator(task_id="finish_resource_loading")
+    
+    # 按批次分组进行并行处理
+    batch_size = RESOURCE_LOADING_PARALLEL_DEGREE
+    batched_tables = [resource_tables[i:i+batch_size] for i in range(0, len(resource_tables), batch_size)]
+    
+    logger.info(f"将 {len(resource_tables)} 个资源表分为 {len(batched_tables)} 批处理,每批最多 {batch_size} 个表")
+    
+    for batch_idx, batch in enumerate(batched_tables):
+        batch_tasks = []
+        for item in batch:
+            task = PythonOperator(
+                task_id=f"load_{item['table_name']}",
+                python_callable=load_table_data,
+                op_kwargs={"table_name": item['table_name'], "execution_mode": item['execution_mode']},
+            )
+            batch_tasks.append(task)
+            
+            # 设置起始依赖
+            start_loading >> task
+            
+            # 设置结束依赖
+            task >> end_loading
+        
+        logger.info(f"批次 {batch_idx+1}: 创建了 {len(batch_tasks)} 个表加载任务")

+ 104 - 18
dags/utils.py

@@ -5,9 +5,12 @@ from config import PG_CONFIG, NEO4J_CONFIG, SCRIPTS_BASE_PATH
 import logging
 import importlib.util
 from pathlib import Path
-import sys
+import networkx as nx
 import os
 
+# 创建统一的日志记录器
+logger = logging.getLogger("airflow.task")
+
 def get_pg_conn():
     return psycopg2.connect(**PG_CONFIG)
 
@@ -87,8 +90,6 @@ def execute_script(script_name: str, table_name: str, execution_mode: str) -> bo
     返回:
         bool: 执行成功返回True,否则返回False
     """
-    logger = logging.getLogger("airflow.task")
-    
     if not script_name:
         logger.error("未提供脚本名称,无法执行")
         return False
@@ -260,7 +261,6 @@ def is_data_model_table(table_name):
     finally:
         driver.close()
 
-# 检查脚本文件是否存在于指定路径
 def check_script_exists(script_name):
     """
     检查脚本文件是否存在于配置的脚本目录中
@@ -272,12 +272,6 @@ def check_script_exists(script_name):
         bool: 如果脚本存在返回True,否则返回False
         str: 完整的脚本路径
     """
-    from pathlib import Path
-    import os
-    import logging
-    
-    logger = logging.getLogger("airflow.task")
-    
     if not script_name:
         logger.error("脚本名称为空,无法检查")
         return False, None
@@ -306,7 +300,6 @@ def check_script_exists(script_name):
             
         return False, script_path_str
 
-# 更新run_model_script函数以使用上述检查
 def run_model_script(table_name, execution_mode):
     """
     根据表名查找并执行对应的模型脚本
@@ -318,9 +311,6 @@ def run_model_script(table_name, execution_mode):
     返回:
         bool: 执行成功返回True,否则返回False
     """
-    import logging
-    logger = logging.getLogger("airflow.task")
-    
     # 从Neo4j获取脚本名称
     script_name = get_script_name_from_neo4j(table_name)
     if not script_name:
@@ -363,8 +353,40 @@ def run_model_script(table_name, execution_mode):
 
 # 从 Neo4j 获取指定 DataModel 表之间的依赖关系图
 # 返回值为 dict:{目标表: [上游依赖表1, 上游依赖表2, ...]}
+# def get_model_dependency_graph(table_names: list) -> dict:
+#     graph = {}
+#     uri = NEO4J_CONFIG['uri']
+#     auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
+#     driver = GraphDatabase.driver(uri, auth=auth)
+#     try:
+#         with driver.session() as session:
+#             for table_name in table_names:
+#                 query = """
+#                     MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
+#                     RETURN up.en_name AS upstream
+#                 """
+#                 result = session.run(query, table_name=table_name)
+#                 deps = [record['upstream'] for record in result if 'upstream' in record]
+#                 graph[table_name] = deps
+#     finally:
+#         driver.close()
+#     return graph
 def get_model_dependency_graph(table_names: list) -> dict:
-    graph = {}
+    """
+    使用networkx从Neo4j获取指定DataModel表之间的依赖关系图    
+    参数:
+        table_names: 表名列表    
+    返回:
+        dict: 依赖关系字典 {目标表: [上游依赖表1, 上游依赖表2, ...]}
+    """
+    # 创建有向图
+    G = nx.DiGraph()
+    
+    # 添加所有节点
+    for table_name in table_names:
+        G.add_node(table_name)
+    
+    # 从Neo4j获取依赖关系并添加边
     uri = NEO4J_CONFIG['uri']
     auth = (NEO4J_CONFIG['user'], NEO4J_CONFIG['password'])
     driver = GraphDatabase.driver(uri, auth=auth)
@@ -373,11 +395,75 @@ def get_model_dependency_graph(table_names: list) -> dict:
             for table_name in table_names:
                 query = """
                     MATCH (t:DataModel {en_name: $table_name})<-[:DERIVED_FROM]-(up:DataModel)
+                    WHERE up.en_name IN $all_tables
                     RETURN up.en_name AS upstream
                 """
-                result = session.run(query, table_name=table_name)
+                result = session.run(query, table_name=table_name, all_tables=table_names)
                 deps = [record['upstream'] for record in result if 'upstream' in record]
-                graph[table_name] = deps
+                
+                # 添加依赖边
+                for dep in deps:
+                    G.add_edge(dep, table_name)
     finally:
         driver.close()
-    return graph
+    
+    # 检测循环依赖
+    try:
+        cycles = list(nx.simple_cycles(G))
+        if cycles:
+            logger.warning(f"检测到表间循环依赖: {cycles}")
+    except Exception as e:
+        logger.error(f"检查循环依赖失败: {str(e)}")
+    
+    # 转换为字典格式返回
+    dependency_dict = {}
+    for table_name in table_names:
+        dependency_dict[table_name] = list(G.predecessors(table_name))
+    
+    return dependency_dict
+
+
def generate_optimized_execution_order(table_names: list) -> list:
    """
    Generate an optimized execution order for the given tables, tolerating
    circular dependencies by breaking cycles before topological sorting.

    Parameters:
        table_names: list of table names to order.

    Returns:
        list: table names in a dependency-respecting (topological) order;
        falls back to the input order if the sort unexpectedly fails.
    """
    # Build the dependency graph
    G = nx.DiGraph()

    # Add every table as a node so isolated tables still appear in the order
    for table_name in table_names:
        G.add_node(table_name)

    # Add dependency edges (upstream -> target)
    dependency_dict = get_model_dependency_graph(table_names)
    for target, upstreams in dependency_dict.items():
        for upstream in upstreams:
            G.add_edge(upstream, target)

    # Detect circular dependencies
    cycles = list(nx.simple_cycles(G))
    if cycles:
        logger.warning(f"检测到循环依赖,将尝试打破循环: {cycles}")
        # Break cycles (simple strategy: drop one edge per cycle)
        for cycle in cycles:
            # Cycles may share edges: removing an edge already removed for an
            # earlier cycle would raise networkx.NetworkXError, so guard first.
            # If the edge is already gone, this cycle is already broken.
            if G.has_edge(cycle[-1], cycle[0]):
                G.remove_edge(cycle[-1], cycle[0])
                logger.info(f"打破循环依赖: 移除 {cycle[-1]} -> {cycle[0]} 的依赖")

    # Produce the topological order
    try:
        execution_order = list(nx.topological_sort(G))
        return execution_order
    except Exception as e:
        logger.error(f"生成执行顺序失败: {str(e)}")
        # Fall back to the original list order
        return table_names
+
+
+

+ 2 - 1
requirements.txt

@@ -1,4 +1,5 @@
 apache-airflow==2.10.5
 psycopg2-binary>=2.9.9
 neo4j>=5.19.0
-pendulum>=3.0.0
+pendulum>=3.0.0
+networkx>=3.4.2