#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import warnings
from types import GeneratorType
from typing import TYPE_CHECKING, Iterable, Sequence

from sqlalchemy import select, update

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
    from pendulum import DateTime
    from sqlalchemy.orm import Session

    from airflow.models.dagrun import DagRun
    from airflow.models.operator import Operator
    from airflow.models.taskmixin import DAGNode
    from airflow.serialization.pydantic.dag_run import DagRunPydantic
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"
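
# For illustration only, the value stored under XCOM_SKIPMIXIN_KEY is a dict keyed by exactly
# one of the two constants above (the task IDs below are hypothetical):
#
#     {XCOM_SKIPMIXIN_SKIPPED: ["branch_a", "branch_b"]}   # written by skip() / _skip()
#     {XCOM_SKIPMIXIN_FOLLOWED: ["branch_c"]}              # written by skip_all_except()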


def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]


class SkipMixin(LoggingMixin):
    """A mixin to skip task instances."""

    @staticmethod
    def _set_state_to_skipped(
        dag_run: DagRun | DagRunPydantic,
        tasks: Sequence[str] | Sequence[tuple[str, int]],
        session: Session,
    ) -> None:
        """Set state of task instances to skipped from the same dag run."""
        if tasks:
            now = timezone.utcnow()

            if isinstance(tasks[0], tuple):
                session.execute(
                    update(TaskInstance)
                    .where(
                        TaskInstance.dag_id == dag_run.dag_id,
                        TaskInstance.run_id == dag_run.run_id,
                        tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), tasks),
                    )
                    .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
                    .execution_options(synchronize_session=False)
                )
            else:
                session.execute(
                    update(TaskInstance)
                    .where(
                        TaskInstance.dag_id == dag_run.dag_id,
                        TaskInstance.run_id == dag_run.run_id,
                        TaskInstance.task_id.in_(tasks),
                    )
                    .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
                    .execution_options(synchronize_session=False)
                )
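
    # For illustration only (hypothetical task IDs): ``tasks`` may be a list of plain task IDs,
    # e.g. ["task_a", "task_b"], or a list of (task_id, map_index) tuples for mapped task
    # instances, e.g. [("task_a", 0), ("task_a", 1)]; the tuple form takes the first branch above.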

    def skip(
        self,
        dag_run: DagRun | DagRunPydantic,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        map_index: int = -1,
    ):
        """Facade for compatibility for call to internal API."""
        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        SkipMixin._skip(
            dag_run=dag_run, task_id=task_id, execution_date=execution_date, tasks=tasks, map_index=map_index
        )

    @staticmethod
    @internal_api_call
    @provide_session
    def _skip(
        dag_run: DagRun | DagRunPydantic,
        task_id: str | None,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ):
        """
        Set task instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :param map_index: map_index of the current task instance
        """
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        if execution_date and not dag_run:
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )
            dag_run = session.scalars(
                select(DagRun).where(
                    DagRun.dag_id == task_list[0].dag_id, DagRun.execution_date == execution_date
                )
            ).one()
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        task_ids_list = [d.task_id for d in task_list]
        # The following could be applied only for non-mapped tasks
        if map_index == -1:
            SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
        session.commit()

        if task_id is not None:
            from airflow.models.xcom import XCom

            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
                task_id=task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                map_index=map_index,
                session=session,
            )
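
    # A minimal usage sketch (hypothetical operator and helper names, not part of this module):
    # an operator mixing in SkipMixin can skip its downstream tasks from execute(), e.g.
    #
    #     class ShortCircuitIfEmpty(BaseOperator, SkipMixin):
    #         def execute(self, context):
    #             if not fetch_rows():  # fetch_rows is an assumed helper
    #                 dag_run = context["dag_run"]
    #                 self.skip(dag_run, dag_run.execution_date, context["task"].downstream_list)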

    @staticmethod
    def skip_all_except(
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
    ):
        """Facade for compatibility for call to internal API."""
        # Ensure we don't serialize a generator object.
        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
            branch_task_ids = list(branch_task_ids)
        SkipMixin._skip_all_except(ti=ti, branch_task_ids=branch_task_ids)

    @classmethod
    @internal_api_call
    @provide_session
    def _skip_all_except(
        cls,
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
        session: Session = NEW_SESSION,
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.
        """
        log = cls().log  # Note: need to fetch the logger from an instance; a static logger breaks pytest.
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            branch_task_id_set = set(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs to be strings. "
                    f"Invalid tasks found: {invalid_task_ids_type}."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        log.info("Following branch %s", branch_task_id_set)

        dag_run = ti.get_dagrun(session=session)
        if TYPE_CHECKING:
            assert isinstance(dag_run, DagRun)
            assert ti.task

        task = ti.task
        dag = TaskInstance.ensure_dag(ti, session=session)

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            #     branch  ----->  join
            #       \            ^
            #         v        /
            #           task1
            #
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [
                (t.task_id, downstream_ti.map_index)
                for t in downstream_tasks
                if (
                    downstream_ti := dag_run.get_task_instance(
                        t.task_id, map_index=ti.map_index, session=session
                    )
                )
                and t.task_id not in branch_task_id_set
            ]

            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
            log.info("Skipping tasks %s", skip_tasks)
            SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
            )
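
# A minimal usage sketch (hypothetical names, not part of this module): a branching operator
# built on SkipMixin picks the branch to follow and skips everything else immediately
# downstream, e.g.
#
#     class ChooseWeekdayPath(BaseOperator, SkipMixin):
#         def execute(self, context):
#             follow = "weekday_task" if context["logical_date"].weekday() < 5 else "weekend_task"
#             self.skip_all_except(context["ti"], follow)
#             return follow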