123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- from typing import TYPE_CHECKING
- from airflow.models.abstractoperator import AbstractOperator
- if TYPE_CHECKING:
- from airflow.models import Operator
- from airflow.models.dag import DAG
def dag_edges(dag: DAG):
    """
    Create the list of edges needed to construct the Graph view.

    A special case is made if a TaskGroup is immediately upstream/downstream of another
    TaskGroup or task. Two proxy nodes named upstream_join_id and downstream_join_id are
    created for the TaskGroup. Instead of drawing an edge onto every task in the TaskGroup,
    all edges are directed onto the proxy nodes. This is to cut down the number of edges on
    the graph.

    For example: A DAG with TaskGroups group1 and group2:
        group1: task1, task2, task3
        group2: task4, task5, task6

    group2 is downstream of group1:
        group1 >> group2

    Edges to add (This avoids having to create edges between every task in group1 and group2):
        task1 >> downstream_join_id
        task2 >> downstream_join_id
        task3 >> downstream_join_id
        downstream_join_id >> upstream_join_id
        upstream_join_id >> task4
        upstream_join_id >> task5
        upstream_join_id >> task6

    :param dag: The DAG whose task and TaskGroup dependencies are rendered.
    :return: A list of dicts sorted by ``(source_id, target_id)``. Each dict always has
        ``source_id`` and ``target_id``. It also has ``is_setup_teardown: True`` when the
        edge goes directly from a setup task to a teardown task. It has ``label`` when
        ``dag.get_edge_info`` reports a truthy label for the edge.
    """
    edges_to_add, edges_to_skip = _collect_task_group_edges(dag)
    edges, setup_teardown_edges = _collect_task_edges(dag)

    result = []
    # Build result dicts with the two ends of the edge, plus any extra metadata
    # if we have it.
    for source_id, target_id in sorted((edges | edges_to_add) - edges_to_skip):
        record = {"source_id": source_id, "target_id": target_id}
        label = dag.get_edge_info(source_id, target_id).get("label")
        if (source_id, target_id) in setup_teardown_edges:
            record["is_setup_teardown"] = True
        if label:
            record["label"] = label
        result.append(record)
    return result


def _collect_task_group_edges(dag: DAG) -> tuple[set[tuple[str, str]], set[tuple[str, str]]]:
    """
    Compute the proxy-node edges contributed by TaskGroups (see ``dag_edges``).

    Every TaskGroup nested under ``dag.task_group`` is visited. Operators found among a
    group's children are skipped.

    :param dag: The DAG whose TaskGroup hierarchy is inspected.
    :return: ``(edges_to_add, edges_to_skip)``, both sets of ``(source_id, target_id)``
        pairs. ``edges_to_add`` routes dependencies through the groups' join nodes.
        ``edges_to_skip`` holds the edges those routes replace.
    """
    # Edges to add between TaskGroups (via their join proxy nodes).
    edges_to_add: set[tuple[str, str]] = set()
    # Edges between individual tasks that are replaced by edges_to_add.
    edges_to_skip: set[tuple[str, str]] = set()

    task_group_map = dag.task_group.get_task_group_dict()

    # Use an explicit stack rather than recursion, so deep TaskGroup nesting cannot hit
    # the interpreter's recursion limit. Visit order is irrelevant because only sets
    # are produced.
    pending = [dag.task_group]
    while pending:
        task_group = pending.pop()
        if isinstance(task_group, AbstractOperator):
            continue
        pending.extend(task_group.children.values())

        # get_leaves()/get_roots() may walk the whole group and may return one-shot
        # iterators. Materialize each at most once per group, and only when an edge needs it.
        if task_group.downstream_group_ids or task_group.downstream_task_ids:
            leaf_ids = [leaf.task_id for leaf in task_group.get_leaves()]
        else:
            leaf_ids = []
        if task_group.upstream_task_ids:
            root_ids = [root.task_id for root in task_group.get_roots()]
        else:
            root_ids = []

        # Anything downstream of the group is fed from its leaves via downstream_join_id.
        edges_to_add.update((leaf_id, task_group.downstream_join_id) for leaf_id in leaf_ids)
        # Anything upstream of the group feeds its roots via upstream_join_id.
        edges_to_add.update((task_group.upstream_join_id, root_id) for root_id in root_ids)

        # For every TaskGroup immediately downstream, add an edge between downstream_join_id
        # and upstream_join_id. Skip edges between individual tasks of the TaskGroups.
        for target_id in task_group.downstream_group_ids:
            target_group = task_group_map[target_id]
            target_root_ids = [root.task_id for root in target_group.get_roots()]
            edges_to_add.add((task_group.downstream_join_id, target_group.upstream_join_id))
            for leaf_id in leaf_ids:
                edges_to_skip.update((leaf_id, root_id) for root_id in target_root_ids)
                edges_to_skip.add((leaf_id, target_group.upstream_join_id))
            for root_id in target_root_ids:
                edges_to_add.add((target_group.upstream_join_id, root_id))
                edges_to_skip.add((task_group.downstream_join_id, root_id))

        # For every individual task immediately downstream, add an edge between
        # downstream_join_id and that task. Skip the leaf -> task edges it replaces.
        for target_id in task_group.downstream_task_ids:
            edges_to_add.add((task_group.downstream_join_id, target_id))
            edges_to_skip.update((leaf_id, target_id) for leaf_id in leaf_ids)

        # For every individual task immediately upstream, add an edge between that task and
        # upstream_join_id. Skip the task -> root edges it replaces.
        for source_id in task_group.upstream_task_ids:
            edges_to_add.add((source_id, task_group.upstream_join_id))
            edges_to_skip.update((source_id, root_id) for root_id in root_ids)

    return edges_to_add, edges_to_skip


def _collect_task_edges(dag: DAG) -> tuple[set[tuple[str, str]], set[tuple[str, str]]]:
    """
    Collect every task-to-task edge reachable from the DAG's root tasks.

    :param dag: The DAG to traverse, starting from ``dag.roots``.
    :return: ``(edges, setup_teardown_edges)``. Both are sets of ``(upstream_task_id,
        downstream_task_id)`` pairs. ``setup_teardown_edges`` is the subset going from a
        setup task directly to a teardown task.
    """
    edges: set[tuple[str, str]] = set()
    setup_teardown_edges: set[tuple[str, str]] = set()

    roots: list[Operator] = dag.roots
    # Track visited *tasks* (not just visited edges), so each task's downstream list is
    # scanned exactly once even when the task has many upstream tasks.
    seen_task_ids = {task.task_id for task in roots}
    tasks_to_trace: list[Operator] = list(roots)
    while tasks_to_trace:
        task = tasks_to_trace.pop()
        for child in task.downstream_list:
            edge = (task.task_id, child.task_id)
            edges.add(edge)
            if task.is_setup and child.is_teardown:
                setup_teardown_edges.add(edge)
            if child.task_id not in seen_task_ids:
                seen_task_ids.add(child.task_id)
                tasks_to_trace.append(child)
    return edges, setup_teardown_edges
|