123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623 |
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
- """Marks tasks APIs."""
- from __future__ import annotations
- from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
- from sqlalchemy import and_, or_, select
- from sqlalchemy.orm import lazyload
- from airflow.models.dagrun import DagRun
- from airflow.models.taskinstance import TaskInstance
- from airflow.operators.subdag import SubDagOperator
- from airflow.utils import timezone
- from airflow.utils.helpers import exactly_one
- from airflow.utils.session import NEW_SESSION, provide_session
- from airflow.utils.state import DagRunState, State, TaskInstanceState
- from airflow.utils.types import DagRunType
- if TYPE_CHECKING:
- from datetime import datetime
- from sqlalchemy.orm import Session as SASession
- from airflow.models.dag import DAG
- from airflow.models.operator import Operator
class _DagRunInfo(NamedTuple):
    """Lightweight description of a dag run: its logical date and the data interval it covers."""

    logical_date: datetime
    # (start, end) of the run's data interval.
    data_interval: tuple[datetime, datetime]
def _create_dagruns(
    dag: DAG,
    infos: Iterable[_DagRunInfo],
    state: DagRunState,
    run_type: DagRunType,
) -> Iterable[DagRun]:
    """
    Infers from data intervals which DAG runs need to be created and does so.

    :param dag: The DAG to create runs for.
    :param infos: List of logical dates and data intervals to evaluate.
    :param state: The state to set the dag run to
    :param run_type: The prefix will be used to construct dag run id: ``{run_id_prefix}__{execution_date}``.
    :return: Newly created and existing dag runs for the execution dates supplied.
    """
    # Materialize up front: ``infos`` is consumed twice below, and a one-shot
    # iterator (e.g. a generator) would be exhausted after the first pass,
    # silently skipping the creation loop.
    infos = list(infos)
    # Find out existing DAG runs that we don't need to create.
    dag_runs = {
        run.logical_date: run
        for run in DagRun.find(dag_id=dag.dag_id, execution_date=[info.logical_date for info in infos])
    }
    for info in infos:
        if info.logical_date not in dag_runs:
            dag_runs[info.logical_date] = dag.create_dagrun(
                execution_date=info.logical_date,
                data_interval=info.data_interval,
                start_date=timezone.utcnow(),
                external_trigger=False,
                state=state,
                run_type=run_type,
            )
    return dag_runs.values()
@provide_session
def set_state(
    *,
    tasks: Collection[Operator | tuple[Operator, int]],
    run_id: str | None = None,
    execution_date: datetime | None = None,
    upstream: bool = False,
    downstream: bool = False,
    future: bool = False,
    past: bool = False,
    state: TaskInstanceState = TaskInstanceState.SUCCESS,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> list[TaskInstance]:
    """
    Set the state of a task instance and if needed its relatives.

    Can set state for future tasks (calculated from run_id) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will, as for subdag, dag runs if needed).

    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
        ``task.dag`` needs to be set
    :param run_id: the run_id of the dagrun to start looking from
    :param execution_date: the execution date from which to start looking (deprecated)
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session
    :return: list of tasks that have been created and updated
    """
    if not tasks:
        return []

    # Callers must identify the dag run by exactly one of the two mechanisms.
    if not exactly_one(execution_date, run_id):
        raise ValueError("Exactly one of dag_run_id and execution_date must be set")

    if execution_date and not timezone.is_localized(execution_date):
        raise ValueError(f"Received non-localized date {execution_date}")

    # All supplied tasks must belong to one and the same DAG.
    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
    if len(task_dags) > 1:
        raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")

    # Resolve the deprecated execution_date form into a run_id.
    if execution_date:
        run_id = dag.get_dagrun(execution_date=execution_date, session=session).run_id
    if not run_id:
        raise ValueError("Received tasks with no run_id")

    # Expand the run selection (past/future) and the task selection
    # (upstream/downstream) into explicit id lists.
    dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
    task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
    task_ids = [task_id if isinstance(task_id, str) else task_id[0] for task_id in task_id_map_index_list]

    # Verify integrity of the selected runs (creates missing task instances)
    # and collect their logical dates for the subdag queries below.
    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids, session=session))
    confirmed_dates = [info.logical_date for info in confirmed_infos]

    # Subdag runs are only touched for non-SKIPPED target states.
    sub_dag_run_ids = (
        list(
            _iter_subdag_run_ids(dag, session, DagRunState(state), task_ids, commit, confirmed_infos),
        )
        if not state == TaskInstanceState.SKIPPED
        else []
    )

    # now look for the task instances that are affected
    qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)

    if commit:
        # Lock the rows (SELECT ... FOR UPDATE) before mutating them.
        tis_altered = session.scalars(qry_dag.with_for_update()).all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
            tis_altered += session.scalars(qry_sub_dag.with_for_update()).all()
        for task_instance in tis_altered:
            task_instance.set_state(state, session=session)
        session.flush()
    else:
        # Dry run: report what would change without locking or mutating.
        tis_altered = session.scalars(qry_dag).all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
            tis_altered += session.scalars(qry_sub_dag).all()
    return tis_altered
def all_subdag_tasks_query(
    sub_dag_run_ids: list[str],
    session: SASession,
    state: TaskInstanceState,
    confirmed_dates: Iterable[datetime],
):
    """Get *all* tasks of the sub dags."""
    # Only select instances that are not already in the target state.
    needs_change = or_(TaskInstance.state.is_(None), TaskInstance.state != state)
    query = (
        select(TaskInstance)
        .where(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
        .where(needs_change)
    )
    return query
def get_all_dag_task_query(
    dag: DAG,
    session: SASession,
    state: TaskInstanceState,
    task_ids: list[str | tuple[str, int]],
    run_ids: Iterable[str],
):
    """Get all tasks of the main dag that will be affected by a state change."""
    # Skip instances already in the target state; lazy-load the dag_run
    # relationship since callers only touch the task instances themselves.
    needs_change = or_(TaskInstance.state.is_(None), TaskInstance.state != state)
    query = (
        select(TaskInstance)
        .where(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.run_id.in_(run_ids),
            TaskInstance.ti_selector_condition(task_ids),
        )
        .where(needs_change)
        .options(lazyload(TaskInstance.dag_run))
    )
    return query
def _iter_subdag_run_ids(
    dag: DAG,
    session: SASession,
    state: DagRunState,
    task_ids: list[str],
    commit: bool,
    confirmed_infos: Iterable[_DagRunInfo],
) -> Iterator[str]:
    """
    Go through subdag operators and create dag runs.

    We only work within the scope of the subdag. A subdag does not propagate to
    its parent DAG, but parent propagates to subdags.

    :param dag: DAG whose tasks are inspected for subdag operators
    :param session: database session, forwarded to ``verify_dagruns``
    :param state: dag run state applied when ``commit`` is True
    :param task_ids: ids of the tasks being marked; only these are inspected
    :param commit: whether dag run state changes should be persisted
    :param confirmed_infos: logical dates / data intervals to create runs for
    :yield: dag_id of every subdag encountered (including nested ones)
    """
    # Iterative depth-first walk: nested subdags found along the way are
    # pushed onto the stack and processed in turn.
    dags = [dag]
    while dags:
        current_dag = dags.pop()
        for task_id in task_ids:
            if not current_dag.has_task(task_id):
                continue
            current_task = current_dag.get_task(task_id)
            # The string comparison also catches serialized/mock operators
            # whose class is not literally SubDagOperator.
            if isinstance(current_task, SubDagOperator) or current_task.task_type == "SubDagOperator":
                # this works as a kind of integrity check
                # it creates missing dag runs for subdag operators,
                # maybe this should be moved to dagrun.verify_integrity
                if TYPE_CHECKING:
                    assert current_task.subdag
                dag_runs = _create_dagruns(
                    current_task.subdag,
                    infos=confirmed_infos,
                    state=DagRunState.RUNNING,
                    run_type=DagRunType.BACKFILL_JOB,
                )
                verify_dagruns(dag_runs, commit, state, session, current_task)
                dags.append(current_task.subdag)
                yield current_task.subdag.dag_id
def verify_dagruns(
    dag_runs: Iterable[DagRun],
    commit: bool,
    state: DagRunState,
    session: SASession,
    current_task: Operator,
):
    """
    Verify integrity of dag_runs.

    :param dag_runs: dag runs to verify
    :param commit: whether dag runs state should be updated
    :param state: state of the dag_run to set if commit is True
    :param session: session to use
    :param current_task: current task
    """
    for run in dag_runs:
        # Attach the subdag so verify_integrity can create missing TIs.
        run.dag = current_task.subdag
        run.verify_integrity()
        if not commit:
            continue
        run.state = state
        session.merge(run)
def _iter_existing_dag_run_infos(dag: DAG, run_ids: list[str], session: SASession) -> Iterator[_DagRunInfo]:
    """Yield a ``_DagRunInfo`` for each existing dag run in *run_ids*, verifying integrity first."""
    existing_runs = DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session)
    for existing in existing_runs:
        existing.dag = dag
        # Creates any task instances that are missing from the run.
        existing.verify_integrity(session=session)
        yield _DagRunInfo(existing.logical_date, dag.get_run_data_interval(existing))
def find_task_relatives(tasks, downstream, upstream):
    """Yield task ids and optionally ancestor and descendant ids."""
    for item in tasks:
        if isinstance(item, tuple):
            # A (task, map_index) pair targets one mapped task instance;
            # relatives are not expanded for it.
            task, map_index = item
            yield task.task_id, map_index
            continue
        task = item
        yield task.task_id
        if downstream:
            yield from (rel.task_id for rel in task.get_flat_relatives(upstream=False))
        if upstream:
            yield from (rel.task_id for rel in task.get_flat_relatives(upstream=True))
@provide_session
def get_execution_dates(
    dag: DAG, execution_date: datetime, future: bool, past: bool, *, session: SASession = NEW_SESSION
) -> list[datetime]:
    """
    Return DAG execution dates.

    :param dag: the DAG whose runs to consider
    :param execution_date: anchor execution date for the range
    :param future: extend the range forward to the latest execution date
    :param past: extend the range backward to the DAG's start_date
    :param session: database session
    :return: sorted logical dates covered by the requested range
    :raises ValueError: if the DAG has no runs at all
    """
    latest_execution_date = dag.get_latest_execution_date(session=session)
    if latest_execution_date is None:
        # Was previously a copy-pasted "non-localized date" message, which
        # misled users: the real problem is that no run exists to anchor on.
        raise ValueError(f"DAG {dag.dag_id} has no dag runs; cannot determine execution dates")
    execution_date = timezone.coerce_datetime(execution_date)

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date
    # Going into the past starts from the DAG's start_date when it has one;
    # otherwise (and when past is not requested) the anchor date is used.
    start_date = dag.start_date if past and dag.start_date else execution_date

    if not dag.timetable.can_be_scheduled:
        # If the DAG never schedules, need to look at existing DagRun if the user wants future or
        # past runs.
        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date, session=session)
        dates = sorted({d.execution_date for d in dag_runs})
    elif not dag.timetable.periodic:
        dates = [start_date]
    else:
        dates = [
            info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
        ]
    return dates
@provide_session
def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
    """
    Return DAG executions' run_ids.

    :param dag: the DAG whose runs to consider
    :param run_id: anchor run_id for the range
    :param future: extend the range forward to the last dag run
    :param past: extend the range backward to the first dag run
    :param session: database session
    :return: run_ids of the dag runs in the requested range
    :raises ValueError: if the DAG has no runs, or *run_id* does not exist
    """
    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session)
    current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
    first_dagrun = session.scalar(
        select(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date.asc()).limit(1)
    )

    if last_dagrun is None:
        raise ValueError(f"DagRun for {dag.dag_id} not found")
    if current_dagrun is None:
        # Without this guard an unknown run_id surfaces as an opaque
        # AttributeError on ``current_dagrun.logical_date`` below.
        raise ValueError(f"DagRun with run_id {run_id} not found for {dag.dag_id}")

    # determine run_id range of dag runs and tasks to consider
    end_date = last_dagrun.logical_date if future else current_dagrun.logical_date
    start_date = current_dagrun.logical_date if not past else first_dagrun.logical_date
    if not dag.timetable.can_be_scheduled:
        # If the DAG never schedules, need to look at existing DagRun if the user wants future or
        # past runs.
        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date, session=session)
        run_ids = sorted({d.run_id for d in dag_runs})
    elif not dag.timetable.periodic:
        run_ids = [run_id]
    else:
        dates = [
            info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
        ]
        run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates, session=session)]
    return run_ids
def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession):
    """
    Set dag run state in the DB.

    :param dag_id: dag_id of target dag run
    :param run_id: run id of target dag run
    :param state: target state
    :param session: database session
    """
    query = select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
    # scalar_one() raises if the run is missing — callers expect it to exist.
    target_run = session.execute(query).scalar_one()
    target_run.state = state
    session.merge(target_run)
@provide_session
def set_dag_run_state_to_success(
    *,
    dag: DAG,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> list[TaskInstance]:
    """
    Set the dag run's state to success.

    Set for a specific execution date and its task instances to success.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking(deprecated)
    :param run_id: the run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    :raises: ValueError if dag or execution_date is invalid
    """
    if not exactly_one(execution_date, run_id):
        return []
    if not dag:
        return []

    # Resolve the deprecated execution_date form into a run_id.
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f"DagRun with execution_date: {execution_date} not found")
        run_id = dag_run.run_id
    if not run_id:
        raise ValueError(f"Invalid dag_run_id: {run_id}")

    # Mark all task instances of the dag run to success - except for teardown as they need to complete work.
    markable_tasks = [task for task in dag.tasks if not task.is_teardown]

    # Mark the dag run to success.
    if commit and len(markable_tasks) == len(dag.tasks):
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)

    for task in markable_tasks:
        task.dag = dag
    return set_state(
        tasks=markable_tasks,
        run_id=run_id,
        state=TaskInstanceState.SUCCESS,
        commit=commit,
        session=session,
    )
@provide_session
def set_dag_run_state_to_failed(
    *,
    dag: DAG,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> list[TaskInstance]:
    """
    Set the dag run's state to failed.

    Set for a specific execution date and its task instances to failed.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking(deprecated)
    :param run_id: the DAG run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    :raises: AssertionError if dag or execution_date is invalid
    """
    if not exactly_one(execution_date, run_id):
        return []
    if not dag:
        return []

    # Resolve the deprecated execution_date form into a run_id.
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f"DagRun with execution_date: {execution_date} not found")
        run_id = dag_run.run_id

    if not run_id:
        raise ValueError(f"Invalid dag_run_id: {run_id}")

    # States considered "currently executing" for the purpose of failing.
    running_states = (
        TaskInstanceState.RUNNING,
        TaskInstanceState.DEFERRED,
        TaskInstanceState.UP_FOR_RESCHEDULE,
    )

    # Mark only RUNNING task instances.
    task_ids = [task.task_id for task in dag.tasks]
    running_tis: list[TaskInstance] = session.scalars(
        select(TaskInstance).where(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.run_id == run_id,
            TaskInstance.task_id.in_(task_ids),
            TaskInstance.state.in_(running_states),
        )
    ).all()

    # Do not kill teardown tasks
    task_ids_of_running_tis = [ti.task_id for ti in running_tis if not dag.task_dict[ti.task_id].is_teardown]
    running_tasks = []
    for task in dag.tasks:
        if task.task_id in task_ids_of_running_tis:
            task.dag = dag
            running_tasks.append(task)

    # Mark non-finished tasks as SKIPPED.
    pending_tis: list[TaskInstance] = session.scalars(
        select(TaskInstance).filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.run_id == run_id,
            or_(
                # Never scheduled at all...
                TaskInstance.state.is_(None),
                # ...or in some intermediate, non-running, non-finished state.
                and_(
                    TaskInstance.state.not_in(State.finished),
                    TaskInstance.state.not_in(running_states),
                ),
            ),
        )
    ).all()

    # Do not skip teardown tasks
    pending_normal_tis = [ti for ti in pending_tis if not dag.task_dict[ti.task_id].is_teardown]

    if commit:
        for ti in pending_normal_tis:
            ti.set_state(TaskInstanceState.SKIPPED)

        # Mark the dag run to failed if there is no pending teardown (else this would not be scheduled later).
        if not any(dag.task_dict[ti.task_id].is_teardown for ti in (running_tis + pending_tis)):
            _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

    # Running tasks are failed via set_state (which also handles commit=False
    # as a dry run); the skipped pending tasks are prepended to the result.
    return pending_normal_tis + set_state(
        tasks=running_tasks,
        run_id=run_id,
        state=TaskInstanceState.FAILED,
        commit=commit,
        session=session,
    )
def __set_dag_run_state_to_running_or_queued(
    *,
    new_state: DagRunState,
    dag: DAG,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    commit: bool = False,
    session: SASession,
) -> list[TaskInstance]:
    """
    Set the dag run for a specific execution date to running.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking
    :param run_id: the id of the DagRun
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    """
    # Always empty: only the dag run itself changes state here; returning a
    # list keeps the signature consistent with the sibling set_* functions.
    altered: list[TaskInstance] = []

    if not exactly_one(execution_date, run_id):
        return altered
    if not dag:
        return altered

    # Resolve the deprecated execution_date form into a run_id.
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f"DagRun with execution_date: {execution_date} not found")
        run_id = dag_run.run_id

    if not run_id:
        raise ValueError(f"DagRun with run_id: {run_id} not found")

    # Mark the dag run to running.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, new_state, session)

    return altered
@provide_session
def set_dag_run_state_to_running(
    *,
    dag: DAG,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> list[TaskInstance]:
    """
    Set the dag run's state to running.

    Set for a specific execution date and its task instances to running.
    """
    # Thin wrapper: delegate with RUNNING as the target state.
    common_kwargs = dict(
        dag=dag,
        execution_date=execution_date,
        run_id=run_id,
        commit=commit,
        session=session,
    )
    return __set_dag_run_state_to_running_or_queued(new_state=DagRunState.RUNNING, **common_kwargs)
@provide_session
def set_dag_run_state_to_queued(
    *,
    dag: DAG,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> list[TaskInstance]:
    """
    Set the dag run's state to queued.

    Set for a specific execution date and its task instances to queued.
    """
    # Thin wrapper: delegate with QUEUED as the target state.
    common_kwargs = dict(
        dag=dag,
        execution_date=execution_date,
        run_id=run_id,
        commit=commit,
        session=session,
    )
    return __set_dag_run_state_to_running_or_queued(new_state=DagRunState.QUEUED, **common_kwargs)
|