dagrun.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import itertools
  20. import os
  21. import warnings
  22. from collections import defaultdict
  23. from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload
  24. import re2
  25. from sqlalchemy import (
  26. Boolean,
  27. Column,
  28. ForeignKey,
  29. ForeignKeyConstraint,
  30. Index,
  31. Integer,
  32. PickleType,
  33. PrimaryKeyConstraint,
  34. String,
  35. Text,
  36. UniqueConstraint,
  37. and_,
  38. func,
  39. or_,
  40. text,
  41. update,
  42. )
  43. from sqlalchemy.exc import IntegrityError
  44. from sqlalchemy.ext.associationproxy import association_proxy
  45. from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
  46. from sqlalchemy.sql.expression import case, false, select, true
  47. from airflow import settings
  48. from airflow.api_internal.internal_api_call import internal_api_call
  49. from airflow.callbacks.callback_requests import DagCallbackRequest
  50. from airflow.configuration import conf as airflow_conf
  51. from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
  52. from airflow.listeners.listener import get_listener_manager
  53. from airflow.models import Log
  54. from airflow.models.abstractoperator import NotMapped
  55. from airflow.models.base import Base, StringID
  56. from airflow.models.expandinput import NotFullyPopulated
  57. from airflow.models.taskinstance import TaskInstance as TI
  58. from airflow.models.tasklog import LogTemplate
  59. from airflow.stats import Stats
  60. from airflow.ti_deps.dep_context import DepContext
  61. from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
  62. from airflow.traces.tracer import Trace
  63. from airflow.utils import timezone
  64. from airflow.utils.dates import datetime_to_nano
  65. from airflow.utils.helpers import chunks, is_container, prune_dict
  66. from airflow.utils.log.logging_mixin import LoggingMixin
  67. from airflow.utils.session import NEW_SESSION, provide_session
  68. from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
  69. from airflow.utils.state import DagRunState, State, TaskInstanceState
  70. from airflow.utils.types import NOTSET, DagRunType
  71. if TYPE_CHECKING:
  72. from datetime import datetime
  73. from sqlalchemy.orm import Query, Session
  74. from airflow.models.dag import DAG
  75. from airflow.models.operator import Operator
  76. from airflow.serialization.pydantic.dag_run import DagRunPydantic
  77. from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
  78. from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
  79. from airflow.typing_compat import Literal
  80. from airflow.utils.types import ArgNotSet
  81. CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
  82. RUN_ID_REGEX = r"^(?:manual|scheduled|dataset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
  83. class TISchedulingDecision(NamedTuple):
  84. """Type of return for DagRun.task_instance_scheduling_decisions."""
  85. tis: list[TI]
  86. schedulable_tis: list[TI]
  87. changed_tis: bool
  88. unfinished_tis: list[TI]
  89. finished_tis: list[TI]
  90. def _creator_note(val):
  91. """Creator the ``note`` association proxy."""
  92. if isinstance(val, str):
  93. return DagRunNote(content=val)
  94. elif isinstance(val, dict):
  95. return DagRunNote(**val)
  96. else:
  97. return DagRunNote(*val)
  98. class DagRun(Base, LoggingMixin):
    """
    Invocation instance of a DAG.

    A DAG run can be created by the scheduler (i.e. scheduled runs), or by an
    external trigger (i.e. manual runs).
    """

    __tablename__ = "dag_run"

    # Surrogate primary key; rows are additionally unique on (dag_id, run_id) and
    # (dag_id, execution_date) -- see __table_args__.
    id = Column(Integer, primary_key=True)
    dag_id = Column(StringID(), nullable=False)
    # Stamped with the current time by set_state() whenever the run is put in QUEUED.
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    # Maintained by set_state(): stamped when the run enters RUNNING, cleared on re-queue.
    start_date = Column(UtcDateTime)
    # Maintained by set_state(): stamped when the run reaches a finished state, cleared
    # when the run is re-queued or set back to RUNNING.
    end_date = Column(UtcDateTime)
    # Persisted in the "state" column; code reads/writes it via the ``state`` synonym,
    # whose setter is set_state().
    _state = Column("state", String(50), default=DagRunState.QUEUED)
    # Checked against the allowed patterns by validate_run_id() on assignment.
    run_id = Column(StringID(), nullable=False)
    # NOTE(review): presumably the id of the job that created this run -- confirm.
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    # Pickled run configuration; __init__ stores ``{}`` when no conf is given.
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun
    # (next_dagruns_to_examine() orders NULLs, i.e. never-examined runs, first).
    last_scheduling_decision = Column(UtcDateTime)
    # NOTE(review): presumably a hash of the DAG version the run was created from -- confirm.
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
        default=select(func.max(LogTemplate.__table__.c.id)),
    )
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
    # Number of times this DagRun has been cleared. It is incremented only when
    # clearing the DagRun re-queues it; new runs start at 0.
    clear_number = Column(Integer, default=0, nullable=False, server_default="0")

    # In-memory DAG object (a plain attribute, not a column); get_dag() raises if unset.
    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        dag: DAG | None = None

    __table_args__ = (
        Index("dag_id_state", dag_id, _state),
        UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
        UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
        Index("idx_dag_run_dag_id", dag_id),
        # Partial (PostgreSQL/SQLite) index over running runs; next_dagruns_to_examine()
        # also names it in a MySQL index hint.
        Index(
            "idx_dag_run_running_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # since mysql lacks filtered/partial indices, this creates a
        # duplicate index on mysql. Not the end of the world
        Index(
            "idx_dag_run_queued_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    # Task instances of this run; the cascade deletes them together with the run.
    task_instances = relationship(
        TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
    )
    # Read-only (viewonly) link to the DagModel row with the same dag_id.
    dag_model = relationship(
        "DagModel",
        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
        uselist=False,
        viewonly=True,
    )
    dag_run_note = relationship(
        "DagRunNote",
        back_populates="dag_run",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    # Proxies DagRunNote.content; when a note object has to be created, _creator_note builds it.
    note = association_proxy("dag_run_note", "content", creator=_creator_note)

    # Default row limit for next_dagruns_to_examine(); read from the config once, when
    # the class body executes.
    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        "scheduler",
        "max_dagruns_per_loop_to_schedule",
        fallback=20,
    )
  184. def __init__(
  185. self,
  186. dag_id: str | None = None,
  187. run_id: str | None = None,
  188. queued_at: datetime | None | ArgNotSet = NOTSET,
  189. execution_date: datetime | None = None,
  190. start_date: datetime | None = None,
  191. external_trigger: bool | None = None,
  192. conf: Any | None = None,
  193. state: DagRunState | None = None,
  194. run_type: str | None = None,
  195. dag_hash: str | None = None,
  196. creating_job_id: int | None = None,
  197. data_interval: tuple[datetime, datetime] | None = None,
  198. ):
  199. if data_interval is None:
  200. # Legacy: Only happen for runs created prior to Airflow 2.2.
  201. self.data_interval_start = self.data_interval_end = None
  202. else:
  203. self.data_interval_start, self.data_interval_end = data_interval
  204. self.dag_id = dag_id
  205. self.run_id = run_id
  206. self.execution_date = execution_date
  207. self.start_date = start_date
  208. self.external_trigger = external_trigger
  209. self.conf = conf or {}
  210. if state is not None:
  211. self.state = state
  212. if queued_at is NOTSET:
  213. self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
  214. else:
  215. self.queued_at = queued_at
  216. self.run_type = run_type
  217. self.dag_hash = dag_hash
  218. self.creating_job_id = creating_job_id
  219. self.clear_number = 0
  220. super().__init__()
  221. def __repr__(self):
  222. return (
  223. f"<DagRun {self.dag_id} @ {self.execution_date}: {self.run_id}, state:{self.state}, "
  224. f"queued_at: {self.queued_at}. externally triggered: {self.external_trigger}>"
  225. )
  226. @validates("run_id")
  227. def validate_run_id(self, key: str, run_id: str) -> str | None:
  228. if not run_id:
  229. return None
  230. regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
  231. if not re2.match(regex, run_id) and not re2.match(RUN_ID_REGEX, run_id):
  232. raise ValueError(
  233. f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
  234. )
  235. return run_id
  236. @property
  237. def stats_tags(self) -> dict[str, str]:
  238. return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
  239. @property
  240. def logical_date(self) -> datetime:
  241. return self.execution_date
  242. def get_state(self):
  243. return self._state
  244. def set_state(self, state: DagRunState) -> None:
  245. """
  246. Change the state of the DagRan.
  247. Changes to attributes are implemented in accordance with the following table
  248. (rows represent old states, columns represent new states):
  249. .. list-table:: State transition matrix
  250. :header-rows: 1
  251. :stub-columns: 1
  252. * -
  253. - QUEUED
  254. - RUNNING
  255. - SUCCESS
  256. - FAILED
  257. * - None
  258. - queued_at = timezone.utcnow()
  259. - if empty: start_date = timezone.utcnow()
  260. end_date = None
  261. - end_date = timezone.utcnow()
  262. - end_date = timezone.utcnow()
  263. * - QUEUED
  264. - queued_at = timezone.utcnow()
  265. - if empty: start_date = timezone.utcnow()
  266. end_date = None
  267. - end_date = timezone.utcnow()
  268. - end_date = timezone.utcnow()
  269. * - RUNNING
  270. - queued_at = timezone.utcnow()
  271. start_date = None
  272. end_date = None
  273. -
  274. - end_date = timezone.utcnow()
  275. - end_date = timezone.utcnow()
  276. * - SUCCESS
  277. - queued_at = timezone.utcnow()
  278. start_date = None
  279. end_date = None
  280. - start_date = timezone.utcnow()
  281. end_date = None
  282. -
  283. -
  284. * - FAILED
  285. - queued_at = timezone.utcnow()
  286. start_date = None
  287. end_date = None
  288. - start_date = timezone.utcnow()
  289. end_date = None
  290. -
  291. -
  292. """
  293. if state not in State.dag_states:
  294. raise ValueError(f"invalid DagRun state: {state}")
  295. if self._state != state:
  296. if state == DagRunState.QUEUED:
  297. self.queued_at = timezone.utcnow()
  298. self.start_date = None
  299. self.end_date = None
  300. if state == DagRunState.RUNNING:
  301. if self._state in State.finished_dr_states:
  302. self.start_date = timezone.utcnow()
  303. else:
  304. self.start_date = self.start_date or timezone.utcnow()
  305. self.end_date = None
  306. if self._state in State.unfinished_dr_states or self._state is None:
  307. if state in State.finished_dr_states:
  308. self.end_date = timezone.utcnow()
  309. self._state = state
  310. else:
  311. if state == DagRunState.QUEUED:
  312. self.queued_at = timezone.utcnow()
  313. @declared_attr
  314. def state(self):
  315. return synonym("_state", descriptor=property(self.get_state, self.set_state))
  316. @provide_session
  317. def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
  318. """
  319. Reload the current dagrun from the database.
  320. :param session: database session
  321. """
  322. dr = session.scalars(
  323. select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id)
  324. ).one()
  325. self.id = dr.id
  326. self.state = dr.state
  327. @classmethod
  328. @provide_session
  329. def active_runs_of_dags(
  330. cls,
  331. dag_ids: Iterable[str] | None = None,
  332. only_running: bool = False,
  333. session: Session = NEW_SESSION,
  334. ) -> dict[str, int]:
  335. """Get the number of active dag runs for each dag."""
  336. query = select(cls.dag_id, func.count("*"))
  337. if dag_ids is not None:
  338. # 'set' called to avoid duplicate dag_ids, but converted back to 'list'
  339. # because SQLAlchemy doesn't accept a set here.
  340. query = query.where(cls.dag_id.in_(set(dag_ids)))
  341. if only_running:
  342. query = query.where(cls.state == DagRunState.RUNNING)
  343. else:
  344. query = query.where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED)))
  345. query = query.group_by(cls.dag_id)
  346. return dict(iter(session.execute(query)))
  347. @classmethod
  348. def next_dagruns_to_examine(
  349. cls,
  350. state: DagRunState,
  351. session: Session,
  352. max_number: int | None = None,
  353. ) -> Query:
  354. """
  355. Return the next DagRuns that the scheduler should attempt to schedule.
  356. This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
  357. query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
  358. the transaction is committed it will be unlocked.
  359. """
  360. from airflow.models.dag import DagModel
  361. if max_number is None:
  362. max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE
  363. # TODO: Bake this query, it is run _A lot_
  364. query = (
  365. select(cls)
  366. .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
  367. .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
  368. .join(DagModel, DagModel.dag_id == cls.dag_id)
  369. .where(DagModel.is_paused == false(), DagModel.is_active == true())
  370. )
  371. if state == DagRunState.QUEUED:
  372. # For dag runs in the queued state, we check if they have reached the max_active_runs limit
  373. # and if so we drop them
  374. running_drs = (
  375. select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
  376. .where(DagRun.state == DagRunState.RUNNING)
  377. .group_by(DagRun.dag_id)
  378. .subquery()
  379. )
  380. query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
  381. func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
  382. )
  383. query = query.order_by(
  384. nulls_first(cls.last_scheduling_decision, session=session),
  385. cls.execution_date,
  386. )
  387. if not settings.ALLOW_FUTURE_EXEC_DATES:
  388. query = query.where(DagRun.execution_date <= func.now())
  389. return session.scalars(
  390. with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
  391. )
  392. @classmethod
  393. @provide_session
  394. def find(
  395. cls,
  396. dag_id: str | list[str] | None = None,
  397. run_id: Iterable[str] | None = None,
  398. execution_date: datetime | Iterable[datetime] | None = None,
  399. state: DagRunState | None = None,
  400. external_trigger: bool | None = None,
  401. no_backfills: bool = False,
  402. run_type: DagRunType | None = None,
  403. session: Session = NEW_SESSION,
  404. execution_start_date: datetime | None = None,
  405. execution_end_date: datetime | None = None,
  406. ) -> list[DagRun]:
  407. """
  408. Return a set of dag runs for the given search criteria.
  409. :param dag_id: the dag_id or list of dag_id to find dag runs for
  410. :param run_id: defines the run id for this dag run
  411. :param run_type: type of DagRun
  412. :param execution_date: the execution date
  413. :param state: the state of the dag run
  414. :param external_trigger: whether this dag run is externally triggered
  415. :param no_backfills: return no backfills (True), return all (False).
  416. Defaults to False
  417. :param session: database session
  418. :param execution_start_date: dag run that was executed from this date
  419. :param execution_end_date: dag run that was executed until this date
  420. """
  421. qry = select(cls)
  422. dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
  423. if dag_ids:
  424. qry = qry.where(cls.dag_id.in_(dag_ids))
  425. if is_container(run_id):
  426. qry = qry.where(cls.run_id.in_(run_id))
  427. elif run_id is not None:
  428. qry = qry.where(cls.run_id == run_id)
  429. if is_container(execution_date):
  430. qry = qry.where(cls.execution_date.in_(execution_date))
  431. elif execution_date is not None:
  432. qry = qry.where(cls.execution_date == execution_date)
  433. if execution_start_date and execution_end_date:
  434. qry = qry.where(cls.execution_date.between(execution_start_date, execution_end_date))
  435. elif execution_start_date:
  436. qry = qry.where(cls.execution_date >= execution_start_date)
  437. elif execution_end_date:
  438. qry = qry.where(cls.execution_date <= execution_end_date)
  439. if state:
  440. qry = qry.where(cls.state == state)
  441. if external_trigger is not None:
  442. qry = qry.where(cls.external_trigger == external_trigger)
  443. if run_type:
  444. qry = qry.where(cls.run_type == run_type)
  445. if no_backfills:
  446. qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)
  447. return session.scalars(qry.order_by(cls.execution_date)).all()
  448. @classmethod
  449. @provide_session
  450. def find_duplicate(
  451. cls,
  452. dag_id: str,
  453. run_id: str,
  454. execution_date: datetime,
  455. session: Session = NEW_SESSION,
  456. ) -> DagRun | None:
  457. """
  458. Return an existing run for the DAG with a specific run_id or execution_date.
  459. *None* is returned if no such DAG run is found.
  460. :param dag_id: the dag_id to find duplicates for
  461. :param run_id: defines the run id for this dag run
  462. :param execution_date: the execution date
  463. :param session: database session
  464. """
  465. return session.scalars(
  466. select(cls).where(
  467. cls.dag_id == dag_id,
  468. or_(cls.run_id == run_id, cls.execution_date == execution_date),
  469. )
  470. ).one_or_none()
  471. @staticmethod
  472. def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
  473. """Generate Run ID based on Run Type and Execution Date."""
  474. # _Ensure_ run_type is a DagRunType, not just a string from user code
  475. return DagRunType(run_type).generate_run_id(execution_date)
  476. @staticmethod
  477. @internal_api_call
  478. @provide_session
  479. def fetch_task_instances(
  480. dag_id: str | None = None,
  481. run_id: str | None = None,
  482. task_ids: list[str] | None = None,
  483. state: Iterable[TaskInstanceState | None] | None = None,
  484. session: Session = NEW_SESSION,
  485. ) -> list[TI]:
  486. """Return the task instances for this dag run."""
  487. tis = (
  488. select(TI)
  489. .options(joinedload(TI.dag_run))
  490. .where(
  491. TI.dag_id == dag_id,
  492. TI.run_id == run_id,
  493. )
  494. )
  495. if state:
  496. if isinstance(state, str):
  497. tis = tis.where(TI.state == state)
  498. else:
  499. # this is required to deal with NULL values
  500. if None in state:
  501. if all(x is None for x in state):
  502. tis = tis.where(TI.state.is_(None))
  503. else:
  504. not_none_state = (s for s in state if s)
  505. tis = tis.where(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
  506. else:
  507. tis = tis.where(TI.state.in_(state))
  508. if task_ids is not None:
  509. tis = tis.where(TI.task_id.in_(task_ids))
  510. return session.scalars(tis).all()
  511. @internal_api_call
  512. def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
  513. """Check if last N dags failed."""
  514. dag_runs = (
  515. session.query(DagRun)
  516. .filter(DagRun.dag_id == dag_id)
  517. .order_by(DagRun.execution_date.desc())
  518. .limit(max_consecutive_failed_dag_runs)
  519. .all()
  520. )
  521. """ Marking dag as paused, if needed"""
  522. to_be_paused = len(dag_runs) >= max_consecutive_failed_dag_runs and all(
  523. dag_run.state == DagRunState.FAILED for dag_run in dag_runs
  524. )
  525. if to_be_paused:
  526. from airflow.models.dag import DagModel
  527. self.log.info(
  528. "Pausing DAG %s because last %s DAG runs failed.",
  529. self.dag_id,
  530. max_consecutive_failed_dag_runs,
  531. )
  532. filter_query = [
  533. DagModel.dag_id == self.dag_id,
  534. DagModel.root_dag_id == self.dag_id, # for sub-dags
  535. ]
  536. session.execute(
  537. update(DagModel)
  538. .where(or_(*filter_query))
  539. .values(is_paused=True)
  540. .execution_options(synchronize_session="fetch")
  541. )
  542. session.add(
  543. Log(
  544. event="paused",
  545. dag_id=self.dag_id,
  546. owner="scheduler",
  547. owner_display_name="Scheduler",
  548. extra=f"[('dag_id', '{self.dag_id}'), ('is_paused', True)]",
  549. )
  550. )
  551. else:
  552. self.log.debug(
  553. "Limit of consecutive DAG failed dag runs is not reached, DAG %s will not be paused.",
  554. self.dag_id,
  555. )
  556. @provide_session
  557. def get_task_instances(
  558. self,
  559. state: Iterable[TaskInstanceState | None] | None = None,
  560. session: Session = NEW_SESSION,
  561. ) -> list[TI]:
  562. """
  563. Return the task instances for this dag run.
  564. Redirect to DagRun.fetch_task_instances method.
  565. Keep this method because it is widely used across the code.
  566. """
  567. task_ids = DagRun._get_partial_task_ids(self.dag)
  568. return DagRun.fetch_task_instances(
  569. dag_id=self.dag_id, run_id=self.run_id, task_ids=task_ids, state=state, session=session
  570. )
  571. @provide_session
  572. def get_task_instance(
  573. self,
  574. task_id: str,
  575. session: Session = NEW_SESSION,
  576. *,
  577. map_index: int = -1,
  578. ) -> TI | TaskInstancePydantic | None:
  579. """
  580. Return the task instance specified by task_id for this dag run.
  581. :param task_id: the task id
  582. :param session: Sqlalchemy ORM Session
  583. """
  584. return DagRun.fetch_task_instance(
  585. dag_id=self.dag_id,
  586. dag_run_id=self.run_id,
  587. task_id=task_id,
  588. session=session,
  589. map_index=map_index,
  590. )
  591. @staticmethod
  592. @internal_api_call
  593. @provide_session
  594. def fetch_task_instance(
  595. dag_id: str,
  596. dag_run_id: str,
  597. task_id: str,
  598. session: Session = NEW_SESSION,
  599. map_index: int = -1,
  600. ) -> TI | TaskInstancePydantic | None:
  601. """
  602. Return the task instance specified by task_id for this dag run.
  603. :param dag_id: the DAG id
  604. :param dag_run_id: the DAG run id
  605. :param task_id: the task id
  606. :param session: Sqlalchemy ORM Session
  607. """
  608. return session.scalars(
  609. select(TI).filter_by(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index)
  610. ).one_or_none()
  611. def get_dag(self) -> DAG:
  612. """
  613. Return the Dag associated with this DagRun.
  614. :return: DAG
  615. """
  616. if not self.dag:
  617. raise AirflowException(f"The DAG (.dag) for {self} needs to be set")
  618. return self.dag
  619. @staticmethod
  620. @internal_api_call
  621. @provide_session
  622. def get_previous_dagrun(
  623. dag_run: DagRun | DagRunPydantic, state: DagRunState | None = None, session: Session = NEW_SESSION
  624. ) -> DagRun | None:
  625. """
  626. Return the previous DagRun, if there is one.
  627. :param dag_run: the dag run
  628. :param session: SQLAlchemy ORM Session
  629. :param state: the dag run state
  630. """
  631. filters = [
  632. DagRun.dag_id == dag_run.dag_id,
  633. DagRun.execution_date < dag_run.execution_date,
  634. ]
  635. if state is not None:
  636. filters.append(DagRun.state == state)
  637. return session.scalar(select(DagRun).where(*filters).order_by(DagRun.execution_date.desc()).limit(1))
  638. @staticmethod
  639. @internal_api_call
  640. @provide_session
  641. def get_previous_scheduled_dagrun(
  642. dag_run_id: int,
  643. session: Session = NEW_SESSION,
  644. ) -> DagRun | None:
  645. """
  646. Return the previous SCHEDULED DagRun, if there is one.
  647. :param dag_run_id: the DAG run ID
  648. :param session: SQLAlchemy ORM Session
  649. """
  650. dag_run = session.get(DagRun, dag_run_id)
  651. return session.scalar(
  652. select(DagRun)
  653. .where(
  654. DagRun.dag_id == dag_run.dag_id,
  655. DagRun.execution_date < dag_run.execution_date,
  656. DagRun.run_type != DagRunType.MANUAL,
  657. )
  658. .order_by(DagRun.execution_date.desc())
  659. .limit(1)
  660. )
  def _tis_for_dagrun_state(self, *, dag, tis):
      """
      Select the task instances whose states decide the terminal state of the dag run.

      Teardown tasks are ignored for this purpose by default, unless they set
      ``on_failure_fail_dagrun``. A task is considered when it is not ignorable itself and
      every task directly downstream of it is ignorable. If that selects nothing (e.g. the
      DAG consists only of teardown tasks), the DAG's plain leaves (tasks without downstream
      tasks) are used instead. Task instances in the REMOVED state are never returned.

      :param dag: the DAG this run belongs to
      :param tis: candidate task instances
      :return: set of task instances to evaluate
      """

      def _ignorable(task):
          # A teardown task is ignorable unless it opted in to failing the dag run.
          return task.is_teardown and not task.on_failure_fail_dagrun

      def is_effective_leaf(task):
          downstream_tasks = (dag.get_task(task_id) for task_id in task.downstream_task_ids)
          if any(not _ignorable(down) for down in downstream_tasks):
              # Something that matters runs after this task, so it is not a leaf.
              return False
          return not _ignorable(task)

      leaf_task_ids = {task.task_id for task in dag.tasks if is_effective_leaf(task)}
      if not leaf_task_ids:
          # can happen if dag is exclusively teardown tasks
          leaf_task_ids = {task.task_id for task in dag.tasks if not task.downstream_list}
      return {
          ti
          for ti in tis
          if ti.task_id in leaf_task_ids and ti.state != TaskInstanceState.REMOVED
      }
  @provide_session
  def update_state(
      self, session: Session = NEW_SESSION, execute_callbacks: bool = True
  ) -> tuple[list[TI], DagCallbackRequest | None]:
      """
      Determine the overall state of the DagRun based on the state of its TaskInstances.

      :param session: Sqlalchemy ORM Session
      :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
          directly (default: true) or recorded as a pending request in the ``returned_callback`` property
      :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
          needs to be executed
      """
      # Callback to execute in case of Task Failures
      callback: DagCallbackRequest | None = None

      class _UnfinishedStates(NamedTuple):
          tis: Sequence[TI]

          @classmethod
          def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
              return cls(tis=unfinished_tis)

          @property
          def should_schedule(self) -> bool:
              return (
                  bool(self.tis)
                  and all(not t.task.depends_on_past for t in self.tis)  # type: ignore[union-attr]
                  and all(t.task.max_active_tis_per_dag is None for t in self.tis)  # type: ignore[union-attr]
                  and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)  # type: ignore[union-attr]
                  and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
              )

          def recalculate(self) -> _UnfinishedStates:
              return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

      start_dttm = timezone.utcnow()
      self.last_scheduling_decision = start_dttm
      with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"), Stats.timer(
          "dagrun.dependency-check", tags=self.stats_tags
      ):
          dag = self.get_dag()
          info = self.task_instance_scheduling_decisions(session)

          tis = info.tis
          schedulable_tis = info.schedulable_tis
          changed_tis = info.changed_tis
          finished_tis = info.finished_tis
          unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

          if unfinished.should_schedule:
              are_runnable_tasks = schedulable_tis or changed_tis
              # small speed up
              if not are_runnable_tasks:
                  are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                      unfinished.tis, finished_tis, session
                  )
                  if changed_by_upstream:  # Something changed, we need to recalculate!
                      unfinished = unfinished.recalculate()

      def _finish_run(state: DagRunState, *, success: bool, reason: str) -> DagCallbackRequest | None:
          """
          Move the run to the terminal ``state``, notify listeners and handle the DAG callback.

          The callback is executed directly when ``execute_callbacks`` is set; otherwise a
          DagCallbackRequest is returned if the DAG defines a matching callback, else None.
          """
          self.set_state(state)
          self.notify_dagrun_state_changed(msg=reason)
          if execute_callbacks:
              dag.handle_callback(self, success=success, reason=reason, session=session)
              return None
          has_callback = dag.has_on_success_callback if success else dag.has_on_failure_callback
          if not has_callback:
              return None
          from airflow.models.dag import DagModel

          dag_model = DagModel.get_dagmodel(dag.dag_id, session)
          return DagCallbackRequest(
              full_filepath=dag.fileloc,
              dag_id=self.dag_id,
              run_id=self.run_id,
              is_failure_callback=not success,
              processor_subdir=None if dag_model is None else dag_model.processor_subdir,
              msg=reason,
          )

      tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)

      # if all tasks finished and at least one failed, the run failed
      if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
          self.log.error("Marking run %s failed", self)
          callback = _finish_run(DagRunState.FAILED, success=False, reason="task_failure")

          # Check if the max_consecutive_failed_dag_runs has been provided and not 0
          # and last consecutive failures are more
          if dag.max_consecutive_failed_dag_runs > 0:
              self.log.debug(
                  "Checking consecutive failed DAG runs for DAG %s, limit is %s",
                  self.dag_id,
                  dag.max_consecutive_failed_dag_runs,
              )
              self._check_last_n_dagruns_failed(dag.dag_id, dag.max_consecutive_failed_dag_runs, session)

      # if all leaves succeeded and no unfinished tasks, the run succeeded
      elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
          self.log.info("Marking run %s successful", self)
          callback = _finish_run(DagRunState.SUCCESS, success=True, reason="success")

      # if *all tasks* are deadlocked, the run failed
      # (``are_runnable_tasks`` is only bound when should_schedule was True; the ``and``
      # short-circuits before reading it otherwise)
      elif unfinished.should_schedule and not are_runnable_tasks:
          self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
          callback = _finish_run(DagRunState.FAILED, success=False, reason="all_tasks_deadlocked")

      # finally, if the leaves aren't done, the dag is still running
      else:
          self.set_state(DagRunState.RUNNING)

      if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
          # Shared by the log line (None when unknown) and the span attributes (0 when unknown).
          run_duration = (
              (self.end_date - self.start_date).total_seconds()
              if self.start_date and self.end_date
              else None
          )
          msg = (
              "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
              "run_start_date=%s, run_end_date=%s, run_duration=%s, "
              "state=%s, external_trigger=%s, run_type=%s, "
              "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
          )
          self.log.info(
              msg,
              self.dag_id,
              self.execution_date,
              self.run_id,
              self.start_date,
              self.end_date,
              run_duration,
              self._state,
              self.external_trigger,
              self.run_type,
              self.data_interval_start,
              self.data_interval_end,
              self.dag_hash,
          )

          with Trace.start_span_from_dagrun(dagrun=self) as span:
              # Compare by value, not identity: ``_state`` may hold a plain string equal to the
              # enum value, in which case ``is`` would silently skip flagging the span as an error.
              if self._state == DagRunState.FAILED:
                  span.set_attribute("error", True)
              attributes = {
                  "category": "DAG runs",
                  "dag_id": str(self.dag_id),
                  "execution_date": str(self.execution_date),
                  "run_id": str(self.run_id),
                  "queued_at": str(self.queued_at),
                  "run_start_date": str(self.start_date),
                  "run_end_date": str(self.end_date),
                  "run_duration": str(run_duration if run_duration is not None else 0),
                  "state": str(self._state),
                  "external_trigger": str(self.external_trigger),
                  "run_type": str(self.run_type),
                  "data_interval_start": str(self.data_interval_start),
                  "data_interval_end": str(self.data_interval_end),
                  "dag_hash": str(self.dag_hash),
                  "conf": str(self.conf),
              }
              if span.is_recording():
                  span.add_event(name="queued", timestamp=datetime_to_nano(self.queued_at))
                  span.add_event(name="started", timestamp=datetime_to_nano(self.start_date))
                  span.add_event(name="ended", timestamp=datetime_to_nano(self.end_date))
              span.set_attributes(attributes)

          session.flush()

          self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
          self._emit_duration_stats_for_finished_state()

      session.merge(self)
      # We do not flush here for performance reasons(It increases queries count by +20)

      return schedulable_tis, callback
  @provide_session
  def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
      """
      Sort this run's task instances into the groups the scheduler acts on.

      Task instances whose task is missing from the DAG are dropped. If they were not
      already REMOVED, they are marked REMOVED and the session is flushed. The rest are
      split into finished and unfinished. Unfinished ones in a schedulable state go
      through ``_get_ready_tis``. When that expanded mapped tasks, the unfinished and
      finished lists are recomputed and ``changed_tis`` is forced to True.

      :param session: Sqlalchemy ORM Session
      :return: TISchedulingDecision with all/schedulable/unfinished/finished TIs and a changed flag
      """
      candidate_tis = self.get_task_instances(session=session, state=State.task_states)
      self.log.debug("number of tis tasks for %s: %s task(s)", self, len(candidate_tis))

      def _attach_tasks(dag: DAG, tis: list[TI]) -> Iterable[TI]:
          """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
          for ti in tis:
              try:
                  ti.task = dag.get_task(ti.task_id)
              except TaskNotFound:
                  if ti.state == TaskInstanceState.REMOVED:
                      continue
                  self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
                  ti.state = TaskInstanceState.REMOVED
                  session.flush()
                  continue
              yield ti

      tis = list(_attach_tasks(self.get_dag(), candidate_tis))

      unfinished_tis = [ti for ti in tis if ti.state in State.unfinished]
      finished_tis = [ti for ti in tis if ti.state in State.finished]

      if not unfinished_tis:
          # Nothing left to run: nothing to schedule and nothing changed.
          return TISchedulingDecision(
              tis=tis,
              schedulable_tis=[],
              changed_tis=False,
              unfinished_tis=unfinished_tis,
              finished_tis=finished_tis,
          )

      candidates = [ti for ti in unfinished_tis if ti.state in SCHEDULEABLE_STATES]
      self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(candidates))
      schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
          candidates,
          finished_tis,
          session=session,
      )

      if expansion_happened:
          # Expansion may have moved some tis into non-schedulable states; re-bucket them.
          changed_tis = True
          still_unfinished = [ti for ti in unfinished_tis if ti.state in State.unfinished]
          finished_tis.extend(ti for ti in unfinished_tis if ti.state in State.finished)
          unfinished_tis = still_unfinished

      return TISchedulingDecision(
          tis=tis,
          schedulable_tis=schedulable_tis,
          changed_tis=changed_tis,
          unfinished_tis=unfinished_tis,
          finished_tis=finished_tis,
      )
  def notify_dagrun_state_changed(self, msg: str = ""):
      """
      Forward the current dag run state to the matching listener hook.

      RUNNING, SUCCESS and FAILED map to ``on_dag_run_running``, ``on_dag_run_success`` and
      ``on_dag_run_failed``. Any other state (notably QUEUED) is deliberately not reported:
      not every state change passes through SchedulerJob, BackfillJob or LocalTaskJob, so we
      don't want to "falsely advertise" that we notify about it.

      :param msg: free-form message passed through to the hook
      """
      state_hooks = (
          (DagRunState.RUNNING, "on_dag_run_running"),
          (DagRunState.SUCCESS, "on_dag_run_success"),
          (DagRunState.FAILED, "on_dag_run_failed"),
      )
      current_state = self.state
      for hook_state, hook_name in state_hooks:
          if current_state == hook_state:
              getattr(get_listener_manager().hook, hook_name)(dag_run=self, msg=msg)
              return
  def _get_ready_tis(
      self,
      schedulable_tis: list[TI],
      finished_tis: list[TI],
      session: Session,
  ) -> tuple[list[TI], bool, bool]:
      """
      Select which of ``schedulable_tis`` can be scheduled now, expanding mapped tasks on the way.

      Dependencies of each candidate are checked with a DepContext that sets
      ``flag_upstream_failed=True`` and ignores the unmapped-task dep. When a candidate's
      dependencies are met:

      * an unexpanded (``map_index < 0``) mapped task is expanded. The resulting TIs are
        appended to the iteration and checked in turn.
      * otherwise, if it is still in a schedulable state, it is collected as ready.
        ``_revise_map_indexes_if_mapped`` also runs once per task id, and any TIs it yields
        are collected too.

      :param schedulable_tis: candidate task instances, each with ``task`` already populated
      :param finished_tis: finished task instances of this run, handed to the DepContext
      :param session: Sqlalchemy ORM Session
      :return: tuple of (ready task instances; whether any candidate whose dependencies were
          not met now has a different state in the database than before its check; whether
          any mapped-task expansion happened)
      """
      # TI key -> state recorded just before a dependency check that did not pass.
      old_states = {}
      ready_tis: list[TI] = []
      changed_tis = False

      if not schedulable_tis:
          return ready_tis, changed_tis, False

      # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
      # `schedulable_tis` in place and have the `for` loop pick them up
      additional_tis: list[TI] = []
      dep_context = DepContext(
          flag_upstream_failed=True,
          ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
          finished_tis=finished_tis,
      )

      def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
          """
          Try to expand the ti, if needed.

          If the ti needs expansion, newly created task instances are
          returned as well as the original ti.
          The original ti is also modified in-place and assigned the
          ``map_index`` of 0.

          If the ti does not need expansion, either because the task is not
          mapped, or has already been expanded, *None* is returned.
          """
          if TYPE_CHECKING:
              assert ti.task

          if ti.map_index >= 0:  # Already expanded, we're good.
              return None

          from airflow.models.mappedoperator import MappedOperator

          if isinstance(ti.task, MappedOperator):
              # If we get here, it could be that we are moving from non-mapped to mapped
              # after task instance clearing or this ti is not yet expanded. Safe to clear
              # the db references.
              ti.clear_db_references(session=session)
          try:
              expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
          except NotMapped:  # Not a mapped task, nothing needed.
              return None
          if expanded_tis:
              return expanded_tis
          # Mapped task that produced no TIs: an empty (but not None) result still counts as expansion.
          return ()

      # Check dependencies.
      expansion_happened = False
      # Set of task ids for which was already done _revise_map_indexes_if_mapped
      revised_map_index_task_ids = set()
      # itertools.chain iterates additional_tis lazily, so TIs appended during the loop are visited too.
      for schedulable in itertools.chain(schedulable_tis, additional_tis):
          if TYPE_CHECKING:
              assert schedulable.task
          old_state = schedulable.state
          if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
              # Remember the pre-check state so we can detect a state change made during the check.
              old_states[schedulable.key] = old_state
              continue
          # If schedulable is not yet expanded, try doing it now. This is
          # called in two places: First and ideally in the mini scheduler at
          # the end of LocalTaskJob, and then as an "expansion of last resort"
          # in the scheduler to ensure that the mapped task is correctly
          # expanded before executed. Also see _revise_map_indexes_if_mapped
          # docstring for additional information.
          new_tis = None
          if schedulable.map_index < 0:
              new_tis = _expand_mapped_task_if_needed(schedulable)
              if new_tis is not None:
                  additional_tis.extend(new_tis)
                  expansion_happened = True
          if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
              # It's enough to revise map index once per task id,
              # checking the map index for each mapped task significantly slows down scheduling
              if schedulable.task.task_id not in revised_map_index_task_ids:
                  ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                  revised_map_index_task_ids.add(schedulable.task.task_id)
              ready_tis.append(schedulable)

      # Check if any ti changed state
      tis_filter = TI.filter_for_tis(old_states)
      if tis_filter is not None:
          fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
          changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

      return ready_tis, changed_tis, expansion_happened
  def _are_premature_tis(
      self,
      unfinished_tis: Sequence[TI],
      finished_tis: list[TI],
      session: Session,
  ) -> tuple[bool, bool]:
      """
      Check whether any unfinished TI would be runnable if waiting periods were ignored.

      :param unfinished_tis: unfinished task instances of this run
      :param finished_tis: finished task instances of this run, handed to the DepContext
      :param session: Sqlalchemy ORM Session
      :return: tuple of (whether any TI has its dependencies met, whether the dependency
          evaluation changed any TI state)
      """
      # Retry delays and reschedule periods are ignored on purpose: tasks that are up for
      # retry/reschedule but not ready yet must still be counted as runnable.
      premature_dep_context = DepContext(
          flag_upstream_failed=True,
          ignore_in_retry_period=True,
          ignore_in_reschedule_period=True,
          finished_tis=finished_tis,
      )
      any_runnable = any(
          ti.are_dependencies_met(dep_context=premature_dep_context, session=session) for ti in unfinished_tis
      )
      return any_runnable, premature_dep_context.have_changed_ti_states
  def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
      """
      Emit the true scheduling delay stats.

      The true scheduling delay stats is defined as the time when the first
      task in DAG starts minus the expected DAG run datetime.

      This helper method is used in ``update_state`` when the state of the
      DAG run is updated to a completed status (either success or failure).
      It finds the first started task within the DAG, calculates the run's
      expected start time based on the logical date and timetable, and gets
      the delay from the difference of these two values.

      The emitted data may contain outliers (e.g. when the first task was
      cleared, so the second task's start date will be used), but we can get
      rid of the outliers on the stats side through dashboards tooling.

      Note that the stat will only be emitted for scheduler-triggered DAG runs
      (i.e. when ``external_trigger`` is *False* and ``clear_number`` is equal to 0).
      Any exception while computing the metric is logged as a warning, not raised.

      :param finished_tis: finished task instances of this run
      """
      # This is a dag run state, so compare against DagRunState (as
      # _emit_duration_stats_for_finished_state does), not TaskInstanceState.
      if self.state == DagRunState.RUNNING:
          return
      if self.external_trigger:
          return
      if self.clear_number > 0:
          return
      if not finished_tis:
          return

      try:
          dag = self.get_dag()

          if not dag.timetable.periodic:
              # We can't emit this metric if there is no following schedule to calculate from!
              return

          start_dates = [ti.start_date for ti in finished_tis if ti.start_date]
          if not start_dates:  # No start dates at all.
              return
          first_start_date = min(start_dates)

          # TODO: Logically, this should be DagRunInfo.run_after, but the
          # information is not stored on a DagRun, only before the actual
          # execution on DagModel.next_dagrun_create_after. We should add
          # a field on DagRun for this instead of relying on the run
          # always happening immediately after the data interval.
          data_interval_end = dag.get_run_data_interval(self).end
          true_delay = first_start_date - data_interval_end
          if true_delay.total_seconds() > 0:
              Stats.timing(
                  f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay, tags=self.stats_tags
              )
              Stats.timing("dagrun.first_task_scheduling_delay", true_delay, tags=self.stats_tags)
      except Exception:
          self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)
  def _emit_duration_stats_for_finished_state(self):
      """
      Emit the run duration (``end_date - start_date``) as timing metrics.

      Nothing is emitted while the run is RUNNING. A warning is logged and nothing is
      emitted when either date is missing. Otherwise two timers are recorded, one keyed
      by state and dag_id and one keyed by state only.
      """
      if self.state == DagRunState.RUNNING:
          return
      started, ended = self.start_date, self.end_date
      if started is None:
          self.log.warning("Failed to record duration of %s: start_date is not set.", self)
          return
      if ended is None:
          self.log.warning("Failed to record duration of %s: end_date is not set.", self)
          return

      elapsed = ended - started
      metric_names = (f"dagrun.duration.{self.state}.{self.dag_id}", f"dagrun.duration.{self.state}")
      for metric_name in metric_names:
          Stats.timing(metric_name, dt=elapsed, tags=self.stats_tags)
  @provide_session
  def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
      """
      Verify the DagRun by checking for removed tasks or tasks that are not in the database yet.

      It will set state to removed or add the task if required. Missing task instances
      are created for every DAG task that has no TI in this run yet. Backfill runs create
      all of them. Other runs only create tasks whose ``start_date``/``end_date`` window
      (when set) contains this run's ``execution_date``.

      :param session: Sqlalchemy ORM Session
      """
      from airflow.settings import task_instance_mutation_hook

      # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
      # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
      hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)

      dag = self.get_dag()
      task_ids = self._check_for_removed_or_restored_tasks(
          dag, task_instance_mutation_hook, session=session
      )

      def task_filter(task: Operator) -> bool:
          # Tasks that already have a TI in this run are never re-created.
          if task.task_id in task_ids:
              return False
          # Backfills create every missing task regardless of the task's date window.
          if self.is_backfill:
              return True
          starts_in_time = task.start_date is None or task.start_date <= self.execution_date
          ends_in_time = task.end_date is None or self.execution_date <= task.end_date
          return starts_in_time and ends_in_time

      created_counts: dict[str, int] = defaultdict(int)
      task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

      # Create the missing tasks, including mapped tasks
      tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
      tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
      self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)
  def _check_for_removed_or_restored_tasks(
      self, dag: DAG, ti_mutation_hook, *, session: Session
  ) -> set[str]:
      """
      Check for removed tasks/restored/missing tasks.

      Each existing TI of this run is first passed through ``ti_mutation_hook`` and then
      reconciled with ``dag``:

      * a REMOVED TI whose task can be found in the DAG again is restored (state reset to None);
      * a TI whose task cannot be found is marked REMOVED, unless it already is, the run is
        RUNNING, or the DAG is partial;
      * for tasks with a mapping length known at parse time, TIs with ``map_index`` beyond
        that length and the unmapped (``map_index < 0``) TI are marked REMOVED;
      * for other mapped tasks, expanded TIs are marked REMOVED while the mapping cannot be
        resolved yet, and TIs beyond the resolved length are marked REMOVED once it can.

      Changes are made to the TI objects only; this method does not flush.

      :param dag: DAG object corresponding to the dagrun
      :param ti_mutation_hook: task_instance_mutation_hook function
      :param session: Sqlalchemy ORM Session
      :return: Task IDs in the DAG run (of every existing TI, including removed ones)
      """
      tis = self.get_task_instances(session=session)

      # check for removed or restored tasks
      task_ids = set()
      for ti in tis:
          ti_mutation_hook(ti)
          task_ids.add(ti.task_id)
          try:
              task = dag.get_task(ti.task_id)

              should_restore_task = (task is not None) and ti.state == TaskInstanceState.REMOVED
              if should_restore_task:
                  self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                  Stats.incr(f"task_restored_to_dag.{dag.dag_id}", tags=self.stats_tags)
                  # Same metric with tagging
                  Stats.incr("task_restored_to_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                  ti.state = None
          except AirflowException:
              # The task could not be retrieved from the DAG.
              if ti.state == TaskInstanceState.REMOVED:
                  pass  # ti has already been removed, just ignore it
              elif self.state != DagRunState.RUNNING and not dag.partial:
                  self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                  Stats.incr(f"task_removed_from_dag.{dag.dag_id}", tags=self.stats_tags)
                  # Same metric with tagging
                  Stats.incr("task_removed_from_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                  ti.state = TaskInstanceState.REMOVED
              # No task object to inspect for mapping, move on to the next TI.
              continue

          try:
              num_mapped_tis = task.get_parse_time_mapped_ti_count()
          except NotMapped:
              # Plain (unmapped) task: nothing more to reconcile.
              continue
          except NotFullyPopulated:
              # What if it is _now_ dynamically mapped, but wasn't before?
              try:
                  total_length = task.get_mapped_ti_count(self.run_id, session=session)
              except NotFullyPopulated:
                  # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                  if ti.map_index >= 0:
                      self.log.debug(
                          "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                      )
                      ti.state = TaskInstanceState.REMOVED
                  continue
              # Upstreams finished, check there aren't any extras
              if ti.map_index >= total_length:
                  self.log.debug(
                      "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                      ti,
                      total_length,
                  )
                  ti.state = TaskInstanceState.REMOVED
          else:
              # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
              if ti.map_index >= num_mapped_tis:
                  self.log.debug(
                      "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                      ti,
                      num_mapped_tis,
                  )
                  ti.state = TaskInstanceState.REMOVED
              elif ti.map_index < 0:
                  self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                  ti.state = TaskInstanceState.REMOVED

      return task_ids
  @overload
  def _get_task_creator(
      self,
      created_counts: dict[str, int],
      ti_mutation_hook: Callable,
      hook_is_noop: Literal[True],
  ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ...

  @overload
  def _get_task_creator(
      self,
      created_counts: dict[str, int],
      ti_mutation_hook: Callable,
      hook_is_noop: Literal[False],
  ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ...

  def _get_task_creator(
      self,
      created_counts: dict[str, int],
      ti_mutation_hook: Callable,
      hook_is_noop: Literal[True, False],
  ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
      """
      Get the task creator function.

      This function also updates the created_counts dictionary with the number of tasks created.
      Both creator variants count one per task instance yielded, so mapped tasks contribute
      one count per map index.

      :param created_counts: Dictionary of task_type -> count of created TIs
      :param ti_mutation_hook: task_instance_mutation_hook function
      :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
      :return: a generator function yielding insert mappings (noop hook) or TI objects
          (hook applied to each TI)
      """
      if hook_is_noop:

          def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
              for map_index in indexes:
                  # Count per TI row (previously once per task), consistent with create_ti below.
                  created_counts[task.task_type] += 1
                  yield TI.insert_mapping(self.run_id, task, map_index=map_index)

          creator = create_ti_mapping

      else:

          def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
              for map_index in indexes:
                  ti = TI(task, run_id=self.run_id, map_index=map_index)
                  ti_mutation_hook(ti)
                  created_counts[ti.operator] += 1
                  yield ti

          creator = create_ti
      return creator
  def _create_tasks(
      self,
      tasks: Iterable[Operator],
      task_creator: Callable[[Operator, Iterable[int]], CreatedTasks],
      *,
      session: Session,
  ) -> CreatedTasks:
      """
      Create missing tasks -- and expand any MappedOperator that _only_ have literals as input.

      Tasks whose mapped TI count can be determined now get one TI per map index.
      Unmapped tasks, tasks whose mapping cannot be resolved yet, and tasks with a
      zero-length mapping get a single unmapped TI (``map_index == -1``).

      :param tasks: Tasks to create jobs for in the DAG run
      :param task_creator: Function to create task instances
      :param session: Sqlalchemy ORM Session
      """
      for task in tasks:
          try:
              expected = task.get_mapped_ti_count(self.run_id, session=session)
          except (NotMapped, NotFullyPopulated):
              expected = 0
          # Always create at least one ti; a zero-length mapping is marked as REMOVED later at runtime.
          map_indexes: Iterable[int] = range(expected) if expected else (-1,)
          yield from task_creator(task, map_indexes)
  def _create_task_instances(
      self,
      dag_id: str,
      tasks: Iterator[dict[str, Any]] | Iterator[TI],
      created_counts: dict[str, int],
      hook_is_noop: bool,
      *,
      session: Session,
  ) -> None:
      """
      Create the necessary task instances from the given tasks.

      On IntegrityError the error is logged and the session is rolled back; no
      ``task_instance_created`` metrics are emitted in that case.

      :param dag_id: DAG ID associated with the dagrun
      :param tasks: the tasks to create the task instances from
      :param created_counts: a dictionary of task_type -> total ti created by the task creator
          (filled in while ``tasks`` is consumed)
      :param hook_is_noop: whether the task_instance_mutation_hook is noop
      :param session: the session to use
      """
      # Fetch the information we need before handling the exception to avoid
      # PendingRollbackError due to the session being invalidated on exception
      # see https://github.com/apache/superset/pull/530
      run_id = self.run_id
      try:
          if hook_is_noop:
              session.bulk_insert_mappings(TI, tasks)
          else:
              session.bulk_save_objects(tasks)
          session.flush()
      except IntegrityError:
          self.log.info(
              "Hit IntegrityError while creating the TIs for %s- %s",
              dag_id,
              run_id,
              exc_info=True,
          )
          self.log.info("Doing session rollback.")
          # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
          session.rollback()
          return

      # Report creations only once the insert has gone through; previously these were sent
      # before the flush and so were also emitted for TIs that were then rolled back.
      for task_type, count in created_counts.items():
          Stats.incr(f"task_instance_created_{task_type}", count, tags=self.stats_tags)
          # Same metric with tagging
          Stats.incr("task_instance_created", count, tags={**self.stats_tags, "task_type": task_type})
  def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
      """
      Check if task increased or reduced in length and handle appropriately.

      Task instances that do not already exist are created and returned if
      possible. Expansion only happens if all upstreams are ready; otherwise
      we delay expansion to the "last resort". See comments at the call site
      for more details.

      Existing TIs whose ``map_index`` is outside ``range(length)`` are bulk-updated to
      REMOVED. Each missing index gets a new TI, passed through the mutation hook, merged
      into the session, flushed and yielded. Nothing happens until the generator is consumed.
      """
      from airflow.settings import task_instance_mutation_hook

      try:
          expected_length = task.get_mapped_ti_count(self.run_id, session=session)
      except NotMapped:
          return  # Not a mapped task, don't need to do anything.
      except NotFullyPopulated:
          return  # Upstreams not ready, don't need to revise this yet.

      current_indexes = set(
          session.scalars(
              select(TI.map_index).where(
                  TI.dag_id == self.dag_id,
                  TI.task_id == task.task_id,
                  TI.run_id == self.run_id,
              )
          )
      )
      wanted_indexes = range(expected_length)

      stale_indexes = current_indexes.difference(wanted_indexes)
      if stale_indexes:
          mark_removed = (
              update(TI)
              .where(
                  TI.dag_id == self.dag_id,
                  TI.task_id == task.task_id,
                  TI.run_id == self.run_id,
                  TI.map_index.in_(stale_indexes),
              )
              .values(state=TaskInstanceState.REMOVED)
          )
          session.execute(mark_removed)
          session.flush()

      missing_indexes = (idx for idx in wanted_indexes if idx not in current_indexes)
      for map_index in missing_indexes:
          new_ti = TI(task, run_id=self.run_id, map_index=map_index, state=None)
          self.log.debug("Expanding TIs upserted %s", new_ti)
          task_instance_mutation_hook(new_ti)
          new_ti = session.merge(new_ti)
          new_ti.refresh_from_task(task)
          session.flush()
          yield new_ti
  @staticmethod
  def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
      """
      Look up the non-externally-triggered DAG run for a DAG at one execution date.

      Deprecated: emits a RemovedInAirflow3Warning on every call; query with SQLAlchemy directly instead.

      :meta private:
      :param session: Sqlalchemy ORM Session
      :param dag_id: DAG ID
      :param execution_date: execution date
      :return: the matching DagRun, or None when no such run exists
      """
      warnings.warn(
          "This method is deprecated. Please use SQLAlchemy directly",
          RemovedInAirflow3Warning,
          stacklevel=2,
      )
      run_query = select(DagRun).where(
          DagRun.dag_id == dag_id,
          DagRun.external_trigger == False,  # noqa: E712
          DagRun.execution_date == execution_date,
      )
      return session.scalar(run_query)
  @property
  def is_backfill(self) -> bool:
      """Whether this run's ``run_type`` is ``DagRunType.BACKFILL_JOB``."""
      backfill_run_type = DagRunType.BACKFILL_JOB
      return self.run_type == backfill_run_type
  @classmethod
  @provide_session
  def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
      """
      Return the latest DagRun for each DAG.

      "Latest" is the run with the maximum ``execution_date`` per ``dag_id``; ties on that
      date all come back.

      :param session: Sqlalchemy ORM Session
      """
      newest_per_dag = (
          select(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
          .group_by(cls.dag_id)
          .subquery()
      )
      on_newest = and_(
          cls.dag_id == newest_per_dag.c.dag_id,
          cls.execution_date == newest_per_dag.c.execution_date,
      )
      latest_runs_stmt = select(cls).join(newest_per_dag, on_newest)
      return session.scalars(latest_runs_stmt).all()
  1372. @provide_session
  1373. def schedule_tis(
  1374. self,
  1375. schedulable_tis: Iterable[TI],
  1376. session: Session = NEW_SESSION,
  1377. max_tis_per_query: int | None = None,
  1378. ) -> int:
  1379. """
  1380. Set the given task instances in to the scheduled state.
  1381. Each element of ``schedulable_tis`` should have its ``task`` attribute already set.
  1382. Any EmptyOperator without callbacks or outlets is instead set straight to the success state.
  1383. All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it
  1384. is the caller's responsibility to call this function only with TIs from a single dag run.
  1385. """
  1386. # Get list of TI IDs that do not need to executed, these are
  1387. # tasks using EmptyOperator and without on_execute_callback / on_success_callback
  1388. dummy_ti_ids = []
  1389. schedulable_ti_ids = []
  1390. for ti in schedulable_tis:
  1391. if TYPE_CHECKING:
  1392. assert ti.task
  1393. if (
  1394. ti.task.inherits_from_empty_operator
  1395. and not ti.task.on_execute_callback
  1396. and not ti.task.on_success_callback
  1397. and not ti.task.outlets
  1398. ):
  1399. dummy_ti_ids.append((ti.task_id, ti.map_index))
  1400. # check "start_trigger_args" to see whether the operator supports start execution from triggerer
  1401. # if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
  1402. # this task.
  1403. # if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker
  1404. elif ti.task.start_trigger_args is not None:
  1405. context = ti.get_template_context()
  1406. start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)
  1407. if start_from_trigger:
  1408. ti.start_date = timezone.utcnow()
  1409. if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
  1410. ti.try_number += 1
  1411. ti.defer_task(exception=None, session=session)
  1412. else:
  1413. schedulable_ti_ids.append((ti.task_id, ti.map_index))
  1414. else:
  1415. schedulable_ti_ids.append((ti.task_id, ti.map_index))
  1416. count = 0
  1417. if schedulable_ti_ids:
  1418. schedulable_ti_ids_chunks = chunks(
  1419. schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
  1420. )
  1421. for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
  1422. count += session.execute(
  1423. update(TI)
  1424. .where(
  1425. TI.dag_id == self.dag_id,
  1426. TI.run_id == self.run_id,
  1427. tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
  1428. )
  1429. .values(
  1430. state=TaskInstanceState.SCHEDULED,
  1431. try_number=case(
  1432. (
  1433. or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
  1434. TI.try_number + 1,
  1435. ),
  1436. else_=TI.try_number,
  1437. ),
  1438. )
  1439. .execution_options(synchronize_session=False)
  1440. ).rowcount
  1441. # Tasks using EmptyOperator should not be executed, mark them as success
  1442. if dummy_ti_ids:
  1443. dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
  1444. for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
  1445. count += session.execute(
  1446. update(TI)
  1447. .where(
  1448. TI.dag_id == self.dag_id,
  1449. TI.run_id == self.run_id,
  1450. tuple_in_condition((TI.task_id, TI.map_index), dummy_ti_ids_chunk),
  1451. )
  1452. .values(
  1453. state=TaskInstanceState.SUCCESS,
  1454. start_date=timezone.utcnow(),
  1455. end_date=timezone.utcnow(),
  1456. duration=0,
  1457. try_number=TI.try_number + 1,
  1458. )
  1459. .execution_options(
  1460. synchronize_session=False,
  1461. )
  1462. ).rowcount
  1463. return count
  1464. @provide_session
  1465. def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate | LogTemplatePydantic:
  1466. return DagRun._get_log_template(log_template_id=self.log_template_id, session=session)
  1467. @staticmethod
  1468. @internal_api_call
  1469. @provide_session
  1470. def _get_log_template(
  1471. log_template_id: int | None, session: Session = NEW_SESSION
  1472. ) -> LogTemplate | LogTemplatePydantic:
  1473. template: LogTemplate | None
  1474. if log_template_id is None: # DagRun created before LogTemplate introduction.
  1475. template = session.scalar(select(LogTemplate).order_by(LogTemplate.id).limit(1))
  1476. else:
  1477. template = session.get(LogTemplate, log_template_id)
  1478. if template is None:
  1479. raise AirflowException(
  1480. f"No log_template entry found for ID {log_template_id!r}. "
  1481. f"Please make sure you set up the metadatabase correctly."
  1482. )
  1483. return template
  1484. @provide_session
  1485. def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
  1486. warnings.warn(
  1487. "This method is deprecated. Please use get_log_template instead.",
  1488. RemovedInAirflow3Warning,
  1489. stacklevel=2,
  1490. )
  1491. return self.get_log_template(session=session).filename
  1492. @staticmethod
  1493. def _get_partial_task_ids(dag: DAG | None) -> list[str] | None:
  1494. return dag.task_ids if dag and dag.partial else None
  1495. class DagRunNote(Base):
  1496. """For storage of arbitrary notes concerning the dagrun instance."""
  1497. __tablename__ = "dag_run_note"
  1498. user_id = Column(
  1499. Integer,
  1500. ForeignKey("ab_user.id", name="dag_run_note_user_fkey"),
  1501. nullable=True,
  1502. )
  1503. dag_run_id = Column(Integer, primary_key=True, nullable=False)
  1504. content = Column(String(1000).with_variant(Text(1000), "mysql"))
  1505. created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
  1506. updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
  1507. dag_run = relationship("DagRun", back_populates="dag_run_note")
  1508. __table_args__ = (
  1509. PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"),
  1510. ForeignKeyConstraint(
  1511. (dag_run_id,),
  1512. ["dag_run.id"],
  1513. name="dag_run_note_dr_fkey",
  1514. ondelete="CASCADE",
  1515. ),
  1516. )
  1517. def __init__(self, content, user_id=None):
  1518. self.content = content
  1519. self.user_id = user_id
  1520. def __repr__(self):
  1521. prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.dagrun_id} {self.run_id}"
  1522. if self.map_index != -1:
  1523. prefix += f" map_index={self.map_index}"
  1524. return prefix + ">"