# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.models.abstractoperator import AbstractOperator

if TYPE_CHECKING:
    from airflow.models import Operator
    from airflow.models.dag import DAG


def dag_edges(dag: DAG):
    """
    Build the edge list used to render the Graph view of ``dag``.

    Dependencies that cross a TaskGroup boundary are not drawn task-to-task.
    Each TaskGroup exposes two proxy node ids, ``upstream_join_id`` and
    ``downstream_join_id``; edges entering the group are routed through the
    former and edges leaving it through the latter, which keeps the number
    of edges on the graph down.

    Example, for ``group1 >> group2`` where::

        group1: task1, task2, task3
        group2: task4, task5, task6

    the emitted edges are (instead of one edge per task pair)::

        task1 >> downstream_join_id
        task2 >> downstream_join_id
        task3 >> downstream_join_id
        downstream_join_id >> upstream_join_id
        upstream_join_id >> task4
        upstream_join_id >> task5
        upstream_join_id >> task6

    :param dag: The DAG whose edges should be listed.
    :return: A list of dicts sorted by ``(source_id, target_id)``. Every dict
        has ``"source_id"`` and ``"target_id"``; ``"is_setup_teardown"`` is
        set to ``True`` for an edge from a setup task to a teardown task, and
        ``"label"`` is present when the DAG's edge info has a truthy label.
    """
    group_by_id = dag.task_group.get_task_group_dict()

    # Proxy edges that route TaskGroup dependencies through the join nodes.
    proxy_edges = set()
    # Direct edges that the proxy edges replace and that must not be drawn.
    hidden_edges = set()

    def link_group(group):
        """Record the proxy edges and hidden edges for a single TaskGroup."""
        # TaskGroup -> TaskGroup: connect this group's downstream join to the
        # other group's upstream join and hide the leaf -> root task edges.
        for downstream_gid in group.downstream_group_ids:
            downstream_group = group_by_id[downstream_gid]
            exit_id = group.downstream_join_id
            entry_id = downstream_group.upstream_join_id
            proxy_edges.add((exit_id, entry_id))

            downstream_root_ids = [root.task_id for root in downstream_group.get_roots()]

            for leaf in group.get_leaves():
                leaf_id = leaf.task_id
                proxy_edges.add((leaf_id, exit_id))
                hidden_edges.update((leaf_id, root_id) for root_id in downstream_root_ids)
                hidden_edges.add((leaf_id, entry_id))

            for root_id in downstream_root_ids:
                proxy_edges.add((entry_id, root_id))
                hidden_edges.add((exit_id, root_id))

        # TaskGroup -> task: connect the downstream join to the task and hide
        # the direct leaf -> task edges.
        for downstream_tid in group.downstream_task_ids:
            exit_id = group.downstream_join_id
            proxy_edges.add((exit_id, downstream_tid))

            for leaf in group.get_leaves():
                proxy_edges.add((leaf.task_id, exit_id))
                hidden_edges.add((leaf.task_id, downstream_tid))

        # task -> TaskGroup: connect the task to the upstream join and hide
        # the direct task -> root edges.
        for upstream_tid in group.upstream_task_ids:
            entry_id = group.upstream_join_id
            proxy_edges.add((upstream_tid, entry_id))

            for root in group.get_roots():
                proxy_edges.add((entry_id, root.task_id))
                hidden_edges.add((upstream_tid, root.task_id))

    # Walk the TaskGroup tree. Visit order is irrelevant because the edges are
    # accumulated in sets and sorted at the end. Operators are the tree's
    # leaves and carry no group-level edges.
    pending = [dag.task_group]
    while pending:
        node = pending.pop()
        if isinstance(node, AbstractOperator):
            continue
        link_group(node)
        pending.extend(node.children.values())

    # Collect every direct task-to-task edge, level by level from the roots.
    task_edges = set()
    setup_teardown_edges = set()

    frontier: list[Operator] = dag.roots
    while frontier:
        upcoming: list[Operator] = []
        for upstream in frontier:
            for downstream in upstream.downstream_list:
                pair = (upstream.task_id, downstream.task_id)
                if upstream.is_setup and downstream.is_teardown:
                    setup_teardown_edges.add(pair)
                if pair in task_edges:
                    # Edge already recorded; its target was queued when the
                    # edge was first seen.
                    continue
                task_edges.add(pair)
                upcoming.append(downstream)
        frontier = upcoming

    # Emit one record per remaining edge, with optional metadata attached.
    visible_edges = (task_edges | proxy_edges) - hidden_edges
    result = []
    for source_id, target_id in sorted(visible_edges):
        label = dag.get_edge_info(source_id, target_id).get("label")
        record = {"source_id": source_id, "target_id": target_id}
        if (source_id, target_id) in setup_teardown_edges:
            record["is_setup_teardown"] = True
        if label:
            record["label"] = label
        result.append(record)
    return result