dag_edges.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. from typing import TYPE_CHECKING
  19. from airflow.models.abstractoperator import AbstractOperator
  20. if TYPE_CHECKING:
  21. from airflow.models import Operator
  22. from airflow.models.dag import DAG
  23. def dag_edges(dag: DAG):
  24. """
  25. Create the list of edges needed to construct the Graph view.
  26. A special case is made if a TaskGroup is immediately upstream/downstream of another
  27. TaskGroup or task. Two proxy nodes named upstream_join_id and downstream_join_id are
  28. created for the TaskGroup. Instead of drawing an edge onto every task in the TaskGroup,
  29. all edges are directed onto the proxy nodes. This is to cut down the number of edges on
  30. the graph.
  31. For example: A DAG with TaskGroups group1 and group2:
  32. group1: task1, task2, task3
  33. group2: task4, task5, task6
  34. group2 is downstream of group1:
  35. group1 >> group2
  36. Edges to add (This avoids having to create edges between every task in group1 and group2):
  37. task1 >> downstream_join_id
  38. task2 >> downstream_join_id
  39. task3 >> downstream_join_id
  40. downstream_join_id >> upstream_join_id
  41. upstream_join_id >> task4
  42. upstream_join_id >> task5
  43. upstream_join_id >> task6
  44. """
  45. # Edges to add between TaskGroup
  46. edges_to_add = set()
  47. # Edges to remove between individual tasks that are replaced by edges_to_add.
  48. edges_to_skip = set()
  49. task_group_map = dag.task_group.get_task_group_dict()
  50. def collect_edges(task_group):
  51. """Update edges_to_add and edges_to_skip according to TaskGroups."""
  52. if isinstance(task_group, AbstractOperator):
  53. return
  54. for target_id in task_group.downstream_group_ids:
  55. # For every TaskGroup immediately downstream, add edges between downstream_join_id
  56. # and upstream_join_id. Skip edges between individual tasks of the TaskGroups.
  57. target_group = task_group_map[target_id]
  58. edges_to_add.add((task_group.downstream_join_id, target_group.upstream_join_id))
  59. for child in task_group.get_leaves():
  60. edges_to_add.add((child.task_id, task_group.downstream_join_id))
  61. for target in target_group.get_roots():
  62. edges_to_skip.add((child.task_id, target.task_id))
  63. edges_to_skip.add((child.task_id, target_group.upstream_join_id))
  64. for child in target_group.get_roots():
  65. edges_to_add.add((target_group.upstream_join_id, child.task_id))
  66. edges_to_skip.add((task_group.downstream_join_id, child.task_id))
  67. # For every individual task immediately downstream, add edges between downstream_join_id and
  68. # the downstream task. Skip edges between individual tasks of the TaskGroup and the
  69. # downstream task.
  70. for target_id in task_group.downstream_task_ids:
  71. edges_to_add.add((task_group.downstream_join_id, target_id))
  72. for child in task_group.get_leaves():
  73. edges_to_add.add((child.task_id, task_group.downstream_join_id))
  74. edges_to_skip.add((child.task_id, target_id))
  75. # For every individual task immediately upstream, add edges between the upstream task
  76. # and upstream_join_id. Skip edges between the upstream task and individual tasks
  77. # of the TaskGroup.
  78. for source_id in task_group.upstream_task_ids:
  79. edges_to_add.add((source_id, task_group.upstream_join_id))
  80. for child in task_group.get_roots():
  81. edges_to_add.add((task_group.upstream_join_id, child.task_id))
  82. edges_to_skip.add((source_id, child.task_id))
  83. for child in task_group.children.values():
  84. collect_edges(child)
  85. collect_edges(dag.task_group)
  86. # Collect all the edges between individual tasks
  87. edges = set()
  88. setup_teardown_edges = set()
  89. tasks_to_trace: list[Operator] = dag.roots
  90. while tasks_to_trace:
  91. tasks_to_trace_next: list[Operator] = []
  92. for task in tasks_to_trace:
  93. for child in task.downstream_list:
  94. edge = (task.task_id, child.task_id)
  95. if task.is_setup and child.is_teardown:
  96. setup_teardown_edges.add(edge)
  97. if edge not in edges:
  98. edges.add(edge)
  99. tasks_to_trace_next.append(child)
  100. tasks_to_trace = tasks_to_trace_next
  101. result = []
  102. # Build result dicts with the two ends of the edge, plus any extra metadata
  103. # if we have it.
  104. for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip):
  105. record = {"source_id": source_id, "target_id": target_id}
  106. label = dag.get_edge_info(source_id, target_id).get("label")
  107. if (source_id, target_id) in setup_teardown_edges:
  108. record["is_setup_teardown"] = True
  109. if label:
  110. record["label"] = label
  111. result.append(record)
  112. return result