mark_tasks.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Marks tasks APIs."""
  19. from __future__ import annotations
  20. from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
  21. from sqlalchemy import and_, or_, select
  22. from sqlalchemy.orm import lazyload
  23. from airflow.models.dagrun import DagRun
  24. from airflow.models.taskinstance import TaskInstance
  25. from airflow.operators.subdag import SubDagOperator
  26. from airflow.utils import timezone
  27. from airflow.utils.helpers import exactly_one
  28. from airflow.utils.session import NEW_SESSION, provide_session
  29. from airflow.utils.state import DagRunState, State, TaskInstanceState
  30. from airflow.utils.types import DagRunType
  31. if TYPE_CHECKING:
  32. from datetime import datetime
  33. from sqlalchemy.orm import Session as SASession
  34. from airflow.models.dag import DAG
  35. from airflow.models.operator import Operator
class _DagRunInfo(NamedTuple):
    """Logical date of a DAG run paired with that run's data interval."""

    # Logical date identifying the run.
    logical_date: datetime
    # Data interval of the run, stored as a pair of datetimes.
    data_interval: tuple[datetime, datetime]
  39. def _create_dagruns(
  40. dag: DAG,
  41. infos: Iterable[_DagRunInfo],
  42. state: DagRunState,
  43. run_type: DagRunType,
  44. ) -> Iterable[DagRun]:
  45. """
  46. Infers from data intervals which DAG runs need to be created and does so.
  47. :param dag: The DAG to create runs for.
  48. :param infos: List of logical dates and data intervals to evaluate.
  49. :param state: The state to set the dag run to
  50. :param run_type: The prefix will be used to construct dag run id: ``{run_id_prefix}__{execution_date}``.
  51. :return: Newly created and existing dag runs for the execution dates supplied.
  52. """
  53. # Find out existing DAG runs that we don't need to create.
  54. dag_runs = {
  55. run.logical_date: run
  56. for run in DagRun.find(dag_id=dag.dag_id, execution_date=[info.logical_date for info in infos])
  57. }
  58. for info in infos:
  59. if info.logical_date not in dag_runs:
  60. dag_runs[info.logical_date] = dag.create_dagrun(
  61. execution_date=info.logical_date,
  62. data_interval=info.data_interval,
  63. start_date=timezone.utcnow(),
  64. external_trigger=False,
  65. state=state,
  66. run_type=run_type,
  67. )
  68. return dag_runs.values()
  69. @provide_session
  70. def set_state(
  71. *,
  72. tasks: Collection[Operator | tuple[Operator, int]],
  73. run_id: str | None = None,
  74. execution_date: datetime | None = None,
  75. upstream: bool = False,
  76. downstream: bool = False,
  77. future: bool = False,
  78. past: bool = False,
  79. state: TaskInstanceState = TaskInstanceState.SUCCESS,
  80. commit: bool = False,
  81. session: SASession = NEW_SESSION,
  82. ) -> list[TaskInstance]:
  83. """
  84. Set the state of a task instance and if needed its relatives.
  85. Can set state for future tasks (calculated from run_id) and retroactively
  86. for past tasks. Will verify integrity of past dag runs in order to create
  87. tasks that did not exist. It will not create dag runs that are missing
  88. on the schedule (but it will, as for subdag, dag runs if needed).
  89. :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
  90. ``task.dag`` needs to be set
  91. :param run_id: the run_id of the dagrun to start looking from
  92. :param execution_date: the execution date from which to start looking (deprecated)
  93. :param upstream: Mark all parents (upstream tasks)
  94. :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
  95. :param future: Mark all future tasks on the interval of the dag up until
  96. last execution date.
  97. :param past: Retroactively mark all tasks starting from start_date of the DAG
  98. :param state: State to which the tasks need to be set
  99. :param commit: Commit tasks to be altered to the database
  100. :param session: database session
  101. :return: list of tasks that have been created and updated
  102. """
  103. if not tasks:
  104. return []
  105. if not exactly_one(execution_date, run_id):
  106. raise ValueError("Exactly one of dag_run_id and execution_date must be set")
  107. if execution_date and not timezone.is_localized(execution_date):
  108. raise ValueError(f"Received non-localized date {execution_date}")
  109. task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
  110. if len(task_dags) > 1:
  111. raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
  112. dag = next(iter(task_dags))
  113. if dag is None:
  114. raise ValueError("Received tasks with no DAG")
  115. if execution_date:
  116. run_id = dag.get_dagrun(execution_date=execution_date, session=session).run_id
  117. if not run_id:
  118. raise ValueError("Received tasks with no run_id")
  119. dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
  120. task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
  121. task_ids = [task_id if isinstance(task_id, str) else task_id[0] for task_id in task_id_map_index_list]
  122. confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids, session=session))
  123. confirmed_dates = [info.logical_date for info in confirmed_infos]
  124. sub_dag_run_ids = (
  125. list(
  126. _iter_subdag_run_ids(dag, session, DagRunState(state), task_ids, commit, confirmed_infos),
  127. )
  128. if not state == TaskInstanceState.SKIPPED
  129. else []
  130. )
  131. # now look for the task instances that are affected
  132. qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)
  133. if commit:
  134. tis_altered = session.scalars(qry_dag.with_for_update()).all()
  135. if sub_dag_run_ids:
  136. qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
  137. tis_altered += session.scalars(qry_sub_dag.with_for_update()).all()
  138. for task_instance in tis_altered:
  139. task_instance.set_state(state, session=session)
  140. session.flush()
  141. else:
  142. tis_altered = session.scalars(qry_dag).all()
  143. if sub_dag_run_ids:
  144. qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
  145. tis_altered += session.scalars(qry_sub_dag).all()
  146. return tis_altered
  147. def all_subdag_tasks_query(
  148. sub_dag_run_ids: list[str],
  149. session: SASession,
  150. state: TaskInstanceState,
  151. confirmed_dates: Iterable[datetime],
  152. ):
  153. """Get *all* tasks of the sub dags."""
  154. qry_sub_dag = (
  155. select(TaskInstance)
  156. .where(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
  157. .where(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
  158. )
  159. return qry_sub_dag
  160. def get_all_dag_task_query(
  161. dag: DAG,
  162. session: SASession,
  163. state: TaskInstanceState,
  164. task_ids: list[str | tuple[str, int]],
  165. run_ids: Iterable[str],
  166. ):
  167. """Get all tasks of the main dag that will be affected by a state change."""
  168. qry_dag = select(TaskInstance).where(
  169. TaskInstance.dag_id == dag.dag_id,
  170. TaskInstance.run_id.in_(run_ids),
  171. TaskInstance.ti_selector_condition(task_ids),
  172. )
  173. qry_dag = qry_dag.where(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
  174. lazyload(TaskInstance.dag_run)
  175. )
  176. return qry_dag
  177. def _iter_subdag_run_ids(
  178. dag: DAG,
  179. session: SASession,
  180. state: DagRunState,
  181. task_ids: list[str],
  182. commit: bool,
  183. confirmed_infos: Iterable[_DagRunInfo],
  184. ) -> Iterator[str]:
  185. """
  186. Go through subdag operators and create dag runs.
  187. We only work within the scope of the subdag. A subdag does not propagate to
  188. its parent DAG, but parent propagates to subdags.
  189. """
  190. dags = [dag]
  191. while dags:
  192. current_dag = dags.pop()
  193. for task_id in task_ids:
  194. if not current_dag.has_task(task_id):
  195. continue
  196. current_task = current_dag.get_task(task_id)
  197. if isinstance(current_task, SubDagOperator) or current_task.task_type == "SubDagOperator":
  198. # this works as a kind of integrity check
  199. # it creates missing dag runs for subdag operators,
  200. # maybe this should be moved to dagrun.verify_integrity
  201. if TYPE_CHECKING:
  202. assert current_task.subdag
  203. dag_runs = _create_dagruns(
  204. current_task.subdag,
  205. infos=confirmed_infos,
  206. state=DagRunState.RUNNING,
  207. run_type=DagRunType.BACKFILL_JOB,
  208. )
  209. verify_dagruns(dag_runs, commit, state, session, current_task)
  210. dags.append(current_task.subdag)
  211. yield current_task.subdag.dag_id
  212. def verify_dagruns(
  213. dag_runs: Iterable[DagRun],
  214. commit: bool,
  215. state: DagRunState,
  216. session: SASession,
  217. current_task: Operator,
  218. ):
  219. """
  220. Verify integrity of dag_runs.
  221. :param dag_runs: dag runs to verify
  222. :param commit: whether dag runs state should be updated
  223. :param state: state of the dag_run to set if commit is True
  224. :param session: session to use
  225. :param current_task: current task
  226. """
  227. for dag_run in dag_runs:
  228. dag_run.dag = current_task.subdag
  229. dag_run.verify_integrity()
  230. if commit:
  231. dag_run.state = state
  232. session.merge(dag_run)
  233. def _iter_existing_dag_run_infos(dag: DAG, run_ids: list[str], session: SASession) -> Iterator[_DagRunInfo]:
  234. for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session):
  235. dag_run.dag = dag
  236. dag_run.verify_integrity(session=session)
  237. yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))
  238. def find_task_relatives(tasks, downstream, upstream):
  239. """Yield task ids and optionally ancestor and descendant ids."""
  240. for item in tasks:
  241. if isinstance(item, tuple):
  242. task, map_index = item
  243. yield task.task_id, map_index
  244. else:
  245. task = item
  246. yield task.task_id
  247. if downstream:
  248. for relative in task.get_flat_relatives(upstream=False):
  249. yield relative.task_id
  250. if upstream:
  251. for relative in task.get_flat_relatives(upstream=True):
  252. yield relative.task_id
  253. @provide_session
  254. def get_execution_dates(
  255. dag: DAG, execution_date: datetime, future: bool, past: bool, *, session: SASession = NEW_SESSION
  256. ) -> list[datetime]:
  257. """Return DAG execution dates."""
  258. latest_execution_date = dag.get_latest_execution_date(session=session)
  259. if latest_execution_date is None:
  260. raise ValueError(f"Received non-localized date {execution_date}")
  261. execution_date = timezone.coerce_datetime(execution_date)
  262. # determine date range of dag runs and tasks to consider
  263. end_date = latest_execution_date if future else execution_date
  264. if dag.start_date:
  265. start_date = dag.start_date
  266. else:
  267. start_date = execution_date
  268. start_date = execution_date if not past else start_date
  269. if not dag.timetable.can_be_scheduled:
  270. # If the DAG never schedules, need to look at existing DagRun if the user wants future or
  271. # past runs.
  272. dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date)
  273. dates = sorted({d.execution_date for d in dag_runs})
  274. elif not dag.timetable.periodic:
  275. dates = [start_date]
  276. else:
  277. dates = [
  278. info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
  279. ]
  280. return dates
  281. @provide_session
  282. def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
  283. """Return DAG executions' run_ids."""
  284. last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session)
  285. current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
  286. first_dagrun = session.scalar(
  287. select(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date.asc()).limit(1)
  288. )
  289. if last_dagrun is None:
  290. raise ValueError(f"DagRun for {dag.dag_id} not found")
  291. # determine run_id range of dag runs and tasks to consider
  292. end_date = last_dagrun.logical_date if future else current_dagrun.logical_date
  293. start_date = current_dagrun.logical_date if not past else first_dagrun.logical_date
  294. if not dag.timetable.can_be_scheduled:
  295. # If the DAG never schedules, need to look at existing DagRun if the user wants future or
  296. # past runs.
  297. dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date, session=session)
  298. run_ids = sorted({d.run_id for d in dag_runs})
  299. elif not dag.timetable.periodic:
  300. run_ids = [run_id]
  301. else:
  302. dates = [
  303. info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
  304. ]
  305. run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates, session=session)]
  306. return run_ids
  307. def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession):
  308. """
  309. Set dag run state in the DB.
  310. :param dag_id: dag_id of target dag run
  311. :param run_id: run id of target dag run
  312. :param state: target state
  313. :param session: database session
  314. """
  315. dag_run = session.execute(
  316. select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
  317. ).scalar_one()
  318. dag_run.state = state
  319. session.merge(dag_run)
  320. @provide_session
  321. def set_dag_run_state_to_success(
  322. *,
  323. dag: DAG,
  324. execution_date: datetime | None = None,
  325. run_id: str | None = None,
  326. commit: bool = False,
  327. session: SASession = NEW_SESSION,
  328. ) -> list[TaskInstance]:
  329. """
  330. Set the dag run's state to success.
  331. Set for a specific execution date and its task instances to success.
  332. :param dag: the DAG of which to alter state
  333. :param execution_date: the execution date from which to start looking(deprecated)
  334. :param run_id: the run_id to start looking from
  335. :param commit: commit DAG and tasks to be altered to the database
  336. :param session: database session
  337. :return: If commit is true, list of tasks that have been updated,
  338. otherwise list of tasks that will be updated
  339. :raises: ValueError if dag or execution_date is invalid
  340. """
  341. if not exactly_one(execution_date, run_id):
  342. return []
  343. if not dag:
  344. return []
  345. if execution_date:
  346. if not timezone.is_localized(execution_date):
  347. raise ValueError(f"Received non-localized date {execution_date}")
  348. dag_run = dag.get_dagrun(execution_date=execution_date)
  349. if not dag_run:
  350. raise ValueError(f"DagRun with execution_date: {execution_date} not found")
  351. run_id = dag_run.run_id
  352. if not run_id:
  353. raise ValueError(f"Invalid dag_run_id: {run_id}")
  354. # Mark all task instances of the dag run to success - except for teardown as they need to complete work.
  355. normal_tasks = [task for task in dag.tasks if not task.is_teardown]
  356. # Mark the dag run to success.
  357. if commit and len(normal_tasks) == len(dag.tasks):
  358. _set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)
  359. for task in normal_tasks:
  360. task.dag = dag
  361. return set_state(
  362. tasks=normal_tasks,
  363. run_id=run_id,
  364. state=TaskInstanceState.SUCCESS,
  365. commit=commit,
  366. session=session,
  367. )
  368. @provide_session
  369. def set_dag_run_state_to_failed(
  370. *,
  371. dag: DAG,
  372. execution_date: datetime | None = None,
  373. run_id: str | None = None,
  374. commit: bool = False,
  375. session: SASession = NEW_SESSION,
  376. ) -> list[TaskInstance]:
  377. """
  378. Set the dag run's state to failed.
  379. Set for a specific execution date and its task instances to failed.
  380. :param dag: the DAG of which to alter state
  381. :param execution_date: the execution date from which to start looking(deprecated)
  382. :param run_id: the DAG run_id to start looking from
  383. :param commit: commit DAG and tasks to be altered to the database
  384. :param session: database session
  385. :return: If commit is true, list of tasks that have been updated,
  386. otherwise list of tasks that will be updated
  387. :raises: AssertionError if dag or execution_date is invalid
  388. """
  389. if not exactly_one(execution_date, run_id):
  390. return []
  391. if not dag:
  392. return []
  393. if execution_date:
  394. if not timezone.is_localized(execution_date):
  395. raise ValueError(f"Received non-localized date {execution_date}")
  396. dag_run = dag.get_dagrun(execution_date=execution_date)
  397. if not dag_run:
  398. raise ValueError(f"DagRun with execution_date: {execution_date} not found")
  399. run_id = dag_run.run_id
  400. if not run_id:
  401. raise ValueError(f"Invalid dag_run_id: {run_id}")
  402. running_states = (
  403. TaskInstanceState.RUNNING,
  404. TaskInstanceState.DEFERRED,
  405. TaskInstanceState.UP_FOR_RESCHEDULE,
  406. )
  407. # Mark only RUNNING task instances.
  408. task_ids = [task.task_id for task in dag.tasks]
  409. running_tis: list[TaskInstance] = session.scalars(
  410. select(TaskInstance).where(
  411. TaskInstance.dag_id == dag.dag_id,
  412. TaskInstance.run_id == run_id,
  413. TaskInstance.task_id.in_(task_ids),
  414. TaskInstance.state.in_(running_states),
  415. )
  416. ).all()
  417. # Do not kill teardown tasks
  418. task_ids_of_running_tis = [ti.task_id for ti in running_tis if not dag.task_dict[ti.task_id].is_teardown]
  419. running_tasks = []
  420. for task in dag.tasks:
  421. if task.task_id in task_ids_of_running_tis:
  422. task.dag = dag
  423. running_tasks.append(task)
  424. # Mark non-finished tasks as SKIPPED.
  425. pending_tis: list[TaskInstance] = session.scalars(
  426. select(TaskInstance).filter(
  427. TaskInstance.dag_id == dag.dag_id,
  428. TaskInstance.run_id == run_id,
  429. or_(
  430. TaskInstance.state.is_(None),
  431. and_(
  432. TaskInstance.state.not_in(State.finished),
  433. TaskInstance.state.not_in(running_states),
  434. ),
  435. ),
  436. )
  437. ).all()
  438. # Do not skip teardown tasks
  439. pending_normal_tis = [ti for ti in pending_tis if not dag.task_dict[ti.task_id].is_teardown]
  440. if commit:
  441. for ti in pending_normal_tis:
  442. ti.set_state(TaskInstanceState.SKIPPED)
  443. # Mark the dag run to failed if there is no pending teardown (else this would not be scheduled later).
  444. if not any(dag.task_dict[ti.task_id].is_teardown for ti in (running_tis + pending_tis)):
  445. _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)
  446. return pending_normal_tis + set_state(
  447. tasks=running_tasks,
  448. run_id=run_id,
  449. state=TaskInstanceState.FAILED,
  450. commit=commit,
  451. session=session,
  452. )
  453. def __set_dag_run_state_to_running_or_queued(
  454. *,
  455. new_state: DagRunState,
  456. dag: DAG,
  457. execution_date: datetime | None = None,
  458. run_id: str | None = None,
  459. commit: bool = False,
  460. session: SASession,
  461. ) -> list[TaskInstance]:
  462. """
  463. Set the dag run for a specific execution date to running.
  464. :param dag: the DAG of which to alter state
  465. :param execution_date: the execution date from which to start looking
  466. :param run_id: the id of the DagRun
  467. :param commit: commit DAG and tasks to be altered to the database
  468. :param session: database session
  469. :return: If commit is true, list of tasks that have been updated,
  470. otherwise list of tasks that will be updated
  471. """
  472. res: list[TaskInstance] = []
  473. if not exactly_one(execution_date, run_id):
  474. return res
  475. if not dag:
  476. return res
  477. if execution_date:
  478. if not timezone.is_localized(execution_date):
  479. raise ValueError(f"Received non-localized date {execution_date}")
  480. dag_run = dag.get_dagrun(execution_date=execution_date)
  481. if not dag_run:
  482. raise ValueError(f"DagRun with execution_date: {execution_date} not found")
  483. run_id = dag_run.run_id
  484. if not run_id:
  485. raise ValueError(f"DagRun with run_id: {run_id} not found")
  486. # Mark the dag run to running.
  487. if commit:
  488. _set_dag_run_state(dag.dag_id, run_id, new_state, session)
  489. # To keep the return type consistent with the other similar functions.
  490. return res
  491. @provide_session
  492. def set_dag_run_state_to_running(
  493. *,
  494. dag: DAG,
  495. execution_date: datetime | None = None,
  496. run_id: str | None = None,
  497. commit: bool = False,
  498. session: SASession = NEW_SESSION,
  499. ) -> list[TaskInstance]:
  500. """
  501. Set the dag run's state to running.
  502. Set for a specific execution date and its task instances to running.
  503. """
  504. return __set_dag_run_state_to_running_or_queued(
  505. new_state=DagRunState.RUNNING,
  506. dag=dag,
  507. execution_date=execution_date,
  508. run_id=run_id,
  509. commit=commit,
  510. session=session,
  511. )
  512. @provide_session
  513. def set_dag_run_state_to_queued(
  514. *,
  515. dag: DAG,
  516. execution_date: datetime | None = None,
  517. run_id: str | None = None,
  518. commit: bool = False,
  519. session: SASession = NEW_SESSION,
  520. ) -> list[TaskInstance]:
  521. """
  522. Set the dag run's state to queued.
  523. Set for a specific execution date and its task instances to queued.
  524. """
  525. return __set_dag_run_state_to_running_or_queued(
  526. new_state=DagRunState.QUEUED,
  527. dag=dag,
  528. execution_date=execution_date,
  529. run_id=run_id,
  530. commit=commit,
  531. session=session,
  532. )