edgemodifier.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. from typing import Sequence
  19. from airflow.models.taskmixin import DAGNode, DependencyMixin
  20. from airflow.utils.task_group import TaskGroup
  21. class EdgeModifier(DependencyMixin):
  22. """
  23. Class that represents edge information to be added between two tasks/operators.
  24. Has shorthand factory functions, like Label("hooray").
  25. Current implementation supports
  26. t1 >> Label("Success route") >> t2
  27. t2 << Label("Success route") << t2
  28. Note that due to the potential for use in either direction, this waits
  29. to make the actual connection between both sides until both are declared,
  30. and will do so progressively if multiple ups/downs are added.
  31. This and EdgeInfo are related - an EdgeModifier is the Python object you
  32. use to add information to (potentially multiple) edges, and EdgeInfo
  33. is the representation of the information for one specific edge.
  34. """
  35. def __init__(self, label: str | None = None):
  36. self.label = label
  37. self._upstream: list[DependencyMixin] = []
  38. self._downstream: list[DependencyMixin] = []
  39. @property
  40. def roots(self):
  41. return self._downstream
  42. @property
  43. def leaves(self):
  44. return self._upstream
  45. @staticmethod
  46. def _make_list(item_or_list: DependencyMixin | Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
  47. if not isinstance(item_or_list, Sequence):
  48. return [item_or_list]
  49. return item_or_list
  50. def _save_nodes(
  51. self,
  52. nodes: DependencyMixin | Sequence[DependencyMixin],
  53. stream: list[DependencyMixin],
  54. ):
  55. from airflow.models.xcom_arg import XComArg
  56. for node in self._make_list(nodes):
  57. if isinstance(node, (TaskGroup, XComArg, DAGNode)):
  58. stream.append(node)
  59. else:
  60. raise TypeError(
  61. f"Cannot use edge labels with {type(node).__name__}, "
  62. f"only tasks, XComArg or TaskGroups"
  63. )
  64. def _convert_streams_to_task_groups(self):
  65. """
  66. Convert a node to a TaskGroup or leave it as a DAGNode.
  67. Requires both self._upstream and self._downstream.
  68. To do this, we keep a set of group_ids seen among the streams. If we find that
  69. the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
  70. convert them to TaskGroups
  71. """
  72. from airflow.models.xcom_arg import XComArg
  73. group_ids = set()
  74. for node in [*self._upstream, *self._downstream]:
  75. if isinstance(node, DAGNode) and node.task_group:
  76. if node.task_group.is_root:
  77. group_ids.add("root")
  78. else:
  79. group_ids.add(node.task_group.group_id)
  80. elif isinstance(node, TaskGroup):
  81. group_ids.add(node.group_id)
  82. elif isinstance(node, XComArg):
  83. if isinstance(node.operator, DAGNode) and node.operator.task_group:
  84. if node.operator.task_group.is_root:
  85. group_ids.add("root")
  86. else:
  87. group_ids.add(node.operator.task_group.group_id)
  88. # If all nodes originate from the same TaskGroup, we will not convert them
  89. if len(group_ids) != 1:
  90. self._upstream = self._convert_stream_to_task_groups(self._upstream)
  91. self._downstream = self._convert_stream_to_task_groups(self._downstream)
  92. def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
  93. return [
  94. node.task_group
  95. if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
  96. else node
  97. for node in stream
  98. ]
  99. def set_upstream(
  100. self,
  101. other: DependencyMixin | Sequence[DependencyMixin],
  102. edge_modifier: EdgeModifier | None = None,
  103. ):
  104. """
  105. Set the given task/list onto the upstream attribute, then attempt to resolve the relationship.
  106. Providing this also provides << via DependencyMixin.
  107. """
  108. self._save_nodes(other, self._upstream)
  109. if self._upstream and self._downstream:
  110. # Convert _upstream and _downstream to task_groups only after both are set
  111. self._convert_streams_to_task_groups()
  112. for node in self._downstream:
  113. node.set_upstream(other, edge_modifier=self)
  114. def set_downstream(
  115. self,
  116. other: DependencyMixin | Sequence[DependencyMixin],
  117. edge_modifier: EdgeModifier | None = None,
  118. ):
  119. """
  120. Set the given task/list onto the downstream attribute, then attempt to resolve the relationship.
  121. Providing this also provides >> via DependencyMixin.
  122. """
  123. self._save_nodes(other, self._downstream)
  124. if self._upstream and self._downstream:
  125. # Convert _upstream and _downstream to task_groups only after both are set
  126. self._convert_streams_to_task_groups()
  127. for node in self._upstream:
  128. node.set_downstream(other, edge_modifier=self)
  129. def update_relative(
  130. self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
  131. ) -> None:
  132. """Update relative if we're not the "main" side of a relationship; still run the same logic."""
  133. if upstream:
  134. self.set_upstream(other)
  135. else:
  136. self.set_downstream(other)
  137. def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
  138. """
  139. Add or update task info on the DAG for this specific pair of tasks.
  140. Called either from our relationship trigger methods above, or directly
  141. by set_upstream/set_downstream in operators.
  142. """
  143. dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})
  144. # Factory functions
  145. def Label(label: str):
  146. """Create an EdgeModifier that sets a human-readable label on the edge."""
  147. return EdgeModifier(label=label)