taskmixin.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
  17. from __future__ import annotations
  18. import warnings
  19. from abc import ABCMeta, abstractmethod
  20. from typing import TYPE_CHECKING, Any, Iterable, Sequence
  21. from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
  22. from airflow.utils.types import NOTSET
  23. if TYPE_CHECKING:
  24. from logging import Logger
  25. import pendulum
  26. from airflow.models.baseoperator import BaseOperator
  27. from airflow.models.dag import DAG
  28. from airflow.models.operator import Operator
  29. from airflow.serialization.enums import DagAttributeTypes
  30. from airflow.utils.edgemodifier import EdgeModifier
  31. from airflow.utils.task_group import TaskGroup
  32. from airflow.utils.types import ArgNotSet
  33. class DependencyMixin:
  34. """Mixing implementing common dependency setting methods like >> and <<."""
  35. @property
  36. def roots(self) -> Sequence[DependencyMixin]:
  37. """
  38. List of root nodes -- ones with no upstream dependencies.
  39. a.k.a. the "start" of this sub-graph
  40. """
  41. raise NotImplementedError()
  42. @property
  43. def leaves(self) -> Sequence[DependencyMixin]:
  44. """
  45. List of leaf nodes -- ones with only upstream dependencies.
  46. a.k.a. the "end" of this sub-graph
  47. """
  48. raise NotImplementedError()
  49. @abstractmethod
  50. def set_upstream(
  51. self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
  52. ):
  53. """Set a task or a task list to be directly upstream from the current task."""
  54. raise NotImplementedError()
  55. @abstractmethod
  56. def set_downstream(
  57. self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
  58. ):
  59. """Set a task or a task list to be directly downstream from the current task."""
  60. raise NotImplementedError()
  61. def as_setup(self) -> DependencyMixin:
  62. """Mark a task as setup task."""
  63. raise NotImplementedError()
  64. def as_teardown(
  65. self,
  66. *,
  67. setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
  68. on_failure_fail_dagrun=NOTSET,
  69. ) -> DependencyMixin:
  70. """Mark a task as teardown and set its setups as direct relatives."""
  71. raise NotImplementedError()
  72. def update_relative(
  73. self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
  74. ) -> None:
  75. """
  76. Update relationship information about another TaskMixin. Default is no-op.
  77. Override if necessary.
  78. """
  79. def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
  80. """Implement Task << Task."""
  81. self.set_upstream(other)
  82. return other
  83. def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
  84. """Implement Task >> Task."""
  85. self.set_downstream(other)
  86. return other
  87. def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
  88. """Implement Task >> [Task] because list don't have __rshift__ operators."""
  89. self.__lshift__(other)
  90. return self
  91. def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
  92. """Implement Task << [Task] because list don't have __lshift__ operators."""
  93. self.__rshift__(other)
  94. return self
  95. @classmethod
  96. def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
  97. from airflow.models.baseoperator import AbstractOperator
  98. from airflow.utils.mixins import ResolveMixin
  99. if isinstance(obj, AbstractOperator):
  100. yield obj, "operator"
  101. elif isinstance(obj, ResolveMixin):
  102. yield from obj.iter_references()
  103. elif isinstance(obj, Sequence):
  104. for o in obj:
  105. yield from cls._iter_references(o)
  106. class TaskMixin(DependencyMixin):
  107. """
  108. Mixin to provide task-related things.
  109. :meta private:
  110. """
  111. def __init_subclass__(cls) -> None:
  112. warnings.warn(
  113. f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
  114. category=RemovedInAirflow3Warning,
  115. stacklevel=2,
  116. )
  117. return super().__init_subclass__()
  118. class DAGNode(DependencyMixin, metaclass=ABCMeta):
  119. """
  120. A base class for a node in the graph of a workflow.
  121. A node may be an Operator or a Task Group, either mapped or unmapped.
  122. """
  123. dag: DAG | None = None
  124. task_group: TaskGroup | None = None
  125. """The task_group that contains this node"""
  126. @property
  127. @abstractmethod
  128. def node_id(self) -> str:
  129. raise NotImplementedError()
  130. @property
  131. def label(self) -> str | None:
  132. tg = self.task_group
  133. if tg and tg.node_id and tg.prefix_group_id:
  134. # "task_group_id.task_id" -> "task_id"
  135. return self.node_id[len(tg.node_id) + 1 :]
  136. return self.node_id
  137. start_date: pendulum.DateTime | None
  138. end_date: pendulum.DateTime | None
  139. upstream_task_ids: set[str]
  140. downstream_task_ids: set[str]
  141. def has_dag(self) -> bool:
  142. return self.dag is not None
  143. @property
  144. def dag_id(self) -> str:
  145. """Returns dag id if it has one or an adhoc/meaningless ID."""
  146. if self.dag:
  147. return self.dag.dag_id
  148. return "_in_memory_dag_"
  149. @property
  150. def log(self) -> Logger:
  151. raise NotImplementedError()
  152. @property
  153. @abstractmethod
  154. def roots(self) -> Sequence[DAGNode]:
  155. raise NotImplementedError()
  156. @property
  157. @abstractmethod
  158. def leaves(self) -> Sequence[DAGNode]:
  159. raise NotImplementedError()
  160. def _set_relatives(
  161. self,
  162. task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
  163. upstream: bool = False,
  164. edge_modifier: EdgeModifier | None = None,
  165. ) -> None:
  166. """Set relatives for the task or task list."""
  167. from airflow.models.baseoperator import BaseOperator
  168. from airflow.models.mappedoperator import MappedOperator
  169. if not isinstance(task_or_task_list, Sequence):
  170. task_or_task_list = [task_or_task_list]
  171. task_list: list[Operator] = []
  172. for task_object in task_or_task_list:
  173. task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
  174. relatives = task_object.leaves if upstream else task_object.roots
  175. for task in relatives:
  176. if not isinstance(task, (BaseOperator, MappedOperator)):
  177. raise AirflowException(
  178. f"Relationships can only be set between Operators; received {task.__class__.__name__}"
  179. )
  180. task_list.append(task)
  181. # relationships can only be set if the tasks share a single DAG. Tasks
  182. # without a DAG are assigned to that DAG.
  183. dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}
  184. if len(dags) > 1:
  185. raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
  186. elif len(dags) == 1:
  187. dag = dags.pop()
  188. else:
  189. raise AirflowException(
  190. f"Tried to create relationships between tasks that don't have DAGs yet. "
  191. f"Set the DAG for at least one task and try again: {[self, *task_list]}"
  192. )
  193. if not self.has_dag():
  194. # If this task does not yet have a dag, add it to the same dag as the other task.
  195. self.dag = dag
  196. for task in task_list:
  197. if dag and not task.has_dag():
  198. # If the other task does not yet have a dag, add it to the same dag as this task and
  199. dag.add_task(task)
  200. if upstream:
  201. task.downstream_task_ids.add(self.node_id)
  202. self.upstream_task_ids.add(task.node_id)
  203. if edge_modifier:
  204. edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id)
  205. else:
  206. self.downstream_task_ids.add(task.node_id)
  207. task.upstream_task_ids.add(self.node_id)
  208. if edge_modifier:
  209. edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id)
  210. def set_downstream(
  211. self,
  212. task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
  213. edge_modifier: EdgeModifier | None = None,
  214. ) -> None:
  215. """Set a node (or nodes) to be directly downstream from the current node."""
  216. self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)
  217. def set_upstream(
  218. self,
  219. task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
  220. edge_modifier: EdgeModifier | None = None,
  221. ) -> None:
  222. """Set a node (or nodes) to be directly upstream from the current node."""
  223. self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
  224. @property
  225. def downstream_list(self) -> Iterable[Operator]:
  226. """List of nodes directly downstream."""
  227. if not self.dag:
  228. raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
  229. return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
  230. @property
  231. def upstream_list(self) -> Iterable[Operator]:
  232. """List of nodes directly upstream."""
  233. if not self.dag:
  234. raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
  235. return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
  236. def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
  237. """Get set of the direct relative ids to the current task, upstream or downstream."""
  238. if upstream:
  239. return self.upstream_task_ids
  240. else:
  241. return self.downstream_task_ids
  242. def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
  243. """Get list of the direct relatives to the current task, upstream or downstream."""
  244. if upstream:
  245. return self.upstream_list
  246. else:
  247. return self.downstream_list
  248. def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
  249. """Serialize a task group's content; used by TaskGroupSerialization."""
  250. raise NotImplementedError()