123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- import warnings
- from types import GeneratorType
- from typing import TYPE_CHECKING, Iterable, Sequence
- from sqlalchemy import select, update
- from airflow.api_internal.internal_api_call import internal_api_call
- from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
- from airflow.models.taskinstance import TaskInstance
- from airflow.utils import timezone
- from airflow.utils.log.logging_mixin import LoggingMixin
- from airflow.utils.session import NEW_SESSION, provide_session
- from airflow.utils.sqlalchemy import tuple_in_condition
- from airflow.utils.state import TaskInstanceState
- if TYPE_CHECKING:
- from pendulum import DateTime
- from sqlalchemy import Session
- from airflow.models.dagrun import DagRun
- from airflow.models.operator import Operator
- from airflow.models.taskmixin import DAGNode
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
- from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"
def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """Filter *nodes* down to concrete operators (plain or mapped), dropping other DAG nodes."""
    # Imported locally to avoid import cycles at module load time.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operator_types = (BaseOperator, MappedOperator)
    tasks: list[Operator] = []
    for node in nodes:
        if isinstance(node, operator_types):
            tasks.append(node)
    return tasks
class SkipMixin(LoggingMixin):
    """A Mixin to skip Tasks Instances."""

    @staticmethod
    def _set_state_to_skipped(
        dag_run: DagRun | DagRunPydantic,
        tasks: Sequence[str] | Sequence[tuple[str, int]],
        session: Session,
    ) -> None:
        """
        Set state of task instances to skipped from the same dag run.

        :param dag_run: the DagRun whose task instances are being skipped
        :param tasks: either plain task IDs, or ``(task_id, map_index)`` pairs
            when mapped task instances must be targeted individually
        :param session: SQLAlchemy session used to issue the bulk UPDATE
        """
        if tasks:
            now = timezone.utcnow()

            # The list is assumed homogeneous: inspect the first element to decide
            # whether to match on (task_id, map_index) pairs or on bare task IDs.
            if isinstance(tasks[0], tuple):
                session.execute(
                    update(TaskInstance)
                    .where(
                        TaskInstance.dag_id == dag_run.dag_id,
                        TaskInstance.run_id == dag_run.run_id,
                        tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), tasks),
                    )
                    # start_date == end_date: a skipped task never actually ran.
                    .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
                    # Bulk UPDATE; skip synchronizing in-memory session state.
                    .execution_options(synchronize_session=False)
                )
            else:
                session.execute(
                    update(TaskInstance)
                    .where(
                        TaskInstance.dag_id == dag_run.dag_id,
                        TaskInstance.run_id == dag_run.run_id,
                        TaskInstance.task_id.in_(tasks),
                    )
                    .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
                    .execution_options(synchronize_session=False)
                )

    def skip(
        self,
        dag_run: DagRun | DagRunPydantic,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        map_index: int = -1,
    ):
        """Facade for compatibility for call to internal API."""
        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        SkipMixin._skip(
            dag_run=dag_run, task_id=task_id, execution_date=execution_date, tasks=tasks, map_index=map_index
        )

    @staticmethod
    @internal_api_call
    @provide_session
    def _skip(
        dag_run: DagRun | DagRunPydantic,
        task_id: str | None,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ):
        """
        Set tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param task_id: task ID under which the skipped IDs are recorded in XCom,
            or None to skip recording
        :param execution_date: execution_date (deprecated way of identifying the DagRun)
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :param map_index: map_index of the current task instance
        """
        # Drop anything that is not a concrete operator (e.g. task groups).
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        if execution_date and not dag_run:
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                # Attribute the warning to the caller of skip(), not this frame.
                stacklevel=2,
            )
            # Deprecated path: resolve the DagRun from the execution date.
            # .one() raises if no (or more than one) matching run exists.
            dag_run = session.scalars(
                select(DagRun).where(
                    DagRun.dag_id == task_list[0].dag_id, DagRun.execution_date == execution_date
                )
            ).one()
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        task_ids_list = [d.task_id for d in task_list]
        # The following could be applied only for non-mapped tasks
        if map_index == -1:
            SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
        session.commit()

        if task_id is not None:
            from airflow.models.xcom import XCom

            # Record the skipped IDs so NotPreviouslySkippedDep can re-skip them
            # if these task instances are later cleared.
            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
                task_id=task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                map_index=map_index,
                session=session,
            )

    @staticmethod
    def skip_all_except(
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
    ):
        """Facade for compatibility for call to internal API."""
        # Ensure we don't serialize a generator object
        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
            branch_task_ids = list(branch_task_ids)
        SkipMixin._skip_all_except(ti=ti, branch_task_ids=branch_task_ids)

    @classmethod
    @internal_api_call
    @provide_session
    def _skip_all_except(
        cls,
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
        session: Session = NEW_SESSION,
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the task instance of the branching operator itself
        :param branch_task_ids: None, a single task ID, or an iterable of task IDs to follow
        :param session: db session to use
        :raises AirflowException: if branch_task_ids has the wrong type or names unknown tasks
        """
        log = cls().log  # Note: need to catch logger form instance, static logger breaks pytest

        # Normalize branch_task_ids into a set of task-ID strings, validating type.
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            branch_task_id_set = set(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {invalid_task_ids_type}."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        log.info("Following branch %s", branch_task_id_set)

        dag_run = ti.get_dagrun(session=session)
        if TYPE_CHECKING:
            assert isinstance(dag_run, DagRun)
            assert ti.task

        task = ti.task
        dag = TaskInstance.ensure_dag(ti, session=session)

        # Every followed ID must name a task that exists in the DAG.
        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch -----> join
            #   \          ^
            #    v        /
            #     task1
            #
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            # Build (task_id, map_index) pairs for downstream TIs that are not followed.
            # The walrus binds the fetched TI; a falsy (missing) TI is excluded entirely.
            skip_tasks = [
                (t.task_id, downstream_ti.map_index)
                for t in downstream_tasks
                if (
                    downstream_ti := dag_run.get_task_instance(
                        t.task_id, map_index=ti.map_index, session=session
                    )
                )
                and t.task_id not in branch_task_id_set
            ]

            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
            log.info("Skipping tasks %s", skip_tasks)
            SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
            # Record the *followed* IDs; NotPreviouslySkippedDep derives skip decisions from them.
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
            )
|