# taskinstance.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections.abc
import contextlib
import hashlib
import itertools
import logging
import math
import operator
import os
import signal
import traceback
import warnings
from collections import defaultdict
from contextlib import nullcontext
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
from urllib.parse import quote

import dill
import jinja2
import lazy_object_proxy
import pendulum
from deprecated import deprecated
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import (
    Column,
    DateTime,
    Float,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Text,
    and_,
    delete,
    false,
    func,
    inspect,
    or_,
    text,
    update,
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import lazyload, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.sql.expression import case, select

from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset, DatasetAlias
from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
    AirflowException,
    AirflowFailException,
    AirflowRescheduleException,
    AirflowSensorTimeout,
    AirflowSkipException,
    AirflowTaskTerminated,
    AirflowTaskTimeout,
    DagRunNotFound,
    RemovedInAirflow3Warning,
    TaskDeferralError,
    TaskDeferred,
    UnmappableXComLengthPushed,
    UnmappableXComTypePushed,
    XComForMappingNotPushed,
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from airflow.models.log import Log
from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.templates import SandboxedEnvironment
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.traces.tracer import Trace
from airflow.utils import timezone
from airflow.utils.context import (
    ConnectionAccessor,
    Context,
    InletEventsAccessors,
    OutletEventAccessors,
    VariableAccessor,
    context_get_outlet_events,
    context_merge,
)
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import (
    ExecutorConfigType,
    ExtendedJSON,
    UtcDateTime,
    tuple_in_condition,
    with_row_locks,
)
from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.timeout import timeout
from airflow.utils.types import AttributeRemoved
from airflow.utils.xcom import XCOM_RETURN_KEY

TR = TaskReschedule

_CURRENT_CONTEXT: list[Context] = []
log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from datetime import datetime
    from pathlib import PurePath
    from types import TracebackType

    from sqlalchemy.orm.session import Session
    from sqlalchemy.sql.elements import BooleanClauseList
    from sqlalchemy.sql.expression import ColumnOperators

    from airflow.models.abstractoperator import TaskStateChangeCallback
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.dag import DAG, DagModel
    from airflow.models.dagrun import DagRun
    from airflow.models.dataset import DatasetEvent
    from airflow.models.operator import Operator
    from airflow.serialization.pydantic.dag import DagModelPydantic
    from airflow.serialization.pydantic.dataset import DatasetEventPydantic
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
    from airflow.timetables.base import DataInterval
    from airflow.typing_compat import Literal, TypeGuard
    from airflow.utils.task_group import TaskGroup


PAST_DEPENDS_MET = "past_depends_met"


class TaskReturnCode(Enum):
    """
    Enum to signal manner of exit for task run command.

    :meta private:
    """

    DEFERRED = 100
    """When task exits with deferral to trigger."""


@internal_api_call
@provide_session
def _merge_ti(ti, session: Session = NEW_SESSION):
    session.merge(ti)
    session.commit()


@internal_api_call
@provide_session
def _add_log(
    event,
    task_instance=None,
    owner=None,
    owner_display_name=None,
    extra=None,
    session: Session = NEW_SESSION,
    **kwargs,
):
    session.add(
        Log(
            event,
            task_instance,
            owner,
            owner_display_name,
            extra,
            **kwargs,
        )
    )


def _run_raw_task(
    ti: TaskInstance | TaskInstancePydantic,
    mark_success: bool = False,
    test_mode: bool = False,
    job_id: str | None = None,
    pool: str | None = None,
    raise_on_defer: bool = False,
    session: Session | None = None,
) -> TaskReturnCode | None:
    """
    Run a task, update the state upon completion, and run any appropriate callbacks.

    Immediately runs the task (without checking or changing db state
    before execution) and then sets the appropriate final state after
    completion and runs any post-execute callbacks. Meant to be called
    only after another function changes the state to running.

    :param mark_success: Don't run the task, mark its state as success
    :param test_mode: Doesn't record success or failure in the DB
    :param pool: specifies the pool to use to run the task instance
    :param session: SQLAlchemy ORM Session
    """
    if TYPE_CHECKING:
        assert ti.task

    ti.test_mode = test_mode
    ti.refresh_from_task(ti.task, pool_override=pool)
    ti.refresh_from_db(session=session)
    ti.job_id = job_id
    ti.hostname = get_hostname()
    ti.pid = os.getpid()
    if not test_mode:
        TaskInstance.save_to_db(ti=ti, session=session, refresh_dag=False)
    actual_start_date = timezone.utcnow()
    Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", tags=ti.stats_tags)
    # Same metric with tagging
    Stats.incr("ti.start", tags=ti.stats_tags)
    # Initialize final state counters at zero
    for state in State.task_states:
        Stats.incr(
            f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}",
            count=0,
            tags=ti.stats_tags,
        )
        # Same metric with tagging
        Stats.incr(
            "ti.finish",
            count=0,
            tags={**ti.stats_tags, "state": str(state)},
        )
    with set_current_task_instance_session(session=session):
        ti.task = ti.task.prepare_for_execution()
        context = ti.get_template_context(ignore_param_exceptions=False, session=session)

        try:
            if not mark_success:
                TaskInstance._execute_task_with_callbacks(
                    self=ti,  # type: ignore[arg-type]
                    context=context,
                    test_mode=test_mode,
                    session=session,
                )
            if not test_mode:
                ti.refresh_from_db(lock_for_update=True, session=session)
            ti.state = TaskInstanceState.SUCCESS
        except TaskDeferred as defer:
            # The task has signalled it wants to defer execution based on
            # a trigger.
            if raise_on_defer:
                raise
            ti.defer_task(exception=defer, session=session)
            ti.log.info(
                "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s",
                ti.dag_id,
                ti.task_id,
                ti.run_id,
                _date_or_empty(task_instance=ti, attr="execution_date"),
                _date_or_empty(task_instance=ti, attr="start_date"),
            )
            return TaskReturnCode.DEFERRED
        except AirflowSkipException as e:
            # Recording SKIP
            # log only if exception has any arguments to prevent log flooding
            if e.args:
                ti.log.info(e)
            if not test_mode:
                ti.refresh_from_db(lock_for_update=True, session=session)
            ti.state = TaskInstanceState.SKIPPED
            _run_finished_callback(callbacks=ti.task.on_skipped_callback, context=context)
            TaskInstance.save_to_db(ti=ti, session=session)
        except AirflowRescheduleException as reschedule_exception:
            ti._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
            ti.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
            return None
        except (AirflowFailException, AirflowSensorTimeout) as e:
            # If AirflowFailException is raised, task should not retry.
            # If a sensor in reschedule mode reaches timeout, task should not retry.
            ti.handle_failure(e, test_mode, context, force_fail=True, session=session)  # already saves to db
            raise
        except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e:
            if not test_mode:
                ti.refresh_from_db(lock_for_update=True, session=session)
            # for case when task is marked as success/failed externally
            # or dagrun timed out and task is marked as skipped
            # current behavior doesn't hit the callbacks
            if ti.state in State.finished:
                ti.clear_next_method_args()
                TaskInstance.save_to_db(ti=ti, session=session)
                return None
            else:
                ti.handle_failure(e, test_mode, context, session=session)
                raise
        except SystemExit as e:
            # We have already handled SystemExit with success codes (0 and None) in `_execute_task`.
            # Therefore, here we must handle only error codes.
            msg = f"Task failed due to SystemExit({e.code})"
            ti.handle_failure(msg, test_mode, context, session=session)
            raise AirflowException(msg)
        except BaseException as e:
            ti.handle_failure(e, test_mode, context, session=session)
            raise
        finally:
            # Print a marker post execution for internals of post task processing
            log.info("::group::Post task execution logs")
            Stats.incr(
                f"ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}",
                tags=ti.stats_tags,
            )
            # Same metric with tagging
            Stats.incr("ti.finish", tags={**ti.stats_tags, "state": str(ti.state)})

        # Recording SKIPPED or SUCCESS
        ti.clear_next_method_args()
        ti.end_date = timezone.utcnow()
        _log_state(task_instance=ti)
        ti.set_duration()

        # run on_success_callback before db committing
        # otherwise, the LocalTaskJob sees the state is changed to `success`,
        # but the task_runner is still running, LocalTaskJob then treats the state as set externally!
        _run_finished_callback(callbacks=ti.task.on_success_callback, context=context)

        if not test_mode:
            _add_log(event=ti.state, task_instance=ti, session=session)
        if ti.state == TaskInstanceState.SUCCESS:
            ti._register_dataset_changes(events=context["outlet_events"], session=session)

        TaskInstance.save_to_db(ti=ti, session=session)
        if ti.state == TaskInstanceState.SUCCESS:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
            )

        return None


@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Set the current execution context to the provided context object.

    This method should be called once per Task execution, before calling operator.execute.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        expected_state = _CURRENT_CONTEXT.pop()
        if expected_state != context:
            log.warning(
                "Current context is not equal to the state at context stack. Expected=%s, got=%s",
                context,
                expected_state,
            )
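

# Illustrative usage sketch (not part of the upstream module): the task runner
# pushes the template context before invoking the operator so that helpers such
# as ``airflow.operators.python.get_current_context()`` can find it. Assuming a
# ``context`` mapping already built by ``get_template_context()``:
#
#     with set_current_context(context):
#         result = task.execute(context=context)
#
# Contexts form a stack, so nested executions are safe; the warning above only
# fires when pushes and pops come back out of order.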


def _stop_remaining_tasks(*, task_instance: TaskInstance | TaskInstancePydantic, session: Session):
    """
    Stop non-teardown tasks in dag.

    :meta private:
    """
    if not task_instance.dag_run:
        raise ValueError("``task_instance`` must have ``dag_run`` set")
    tis = task_instance.dag_run.get_task_instances(session=session)
    if TYPE_CHECKING:
        assert task_instance.task
        assert isinstance(task_instance.task.dag, DAG)

    for ti in tis:
        if ti.task_id == task_instance.task_id or ti.state in (
            TaskInstanceState.SUCCESS,
            TaskInstanceState.FAILED,
        ):
            continue
        task = task_instance.task.dag.task_dict[ti.task_id]
        if not task.is_teardown:
            if ti.state == TaskInstanceState.RUNNING:
                log.info("Forcing task %s to fail due to dag's `fail_stop` setting", ti.task_id)
                ti.error(session)
            else:
                log.info("Setting task %s to SKIPPED due to dag's `fail_stop` setting.", ti.task_id)
                ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
        else:
            log.info("Not skipping teardown task '%s'", ti.task_id)


def clear_task_instances(
    tis: list[TaskInstance],
    session: Session,
    activate_dag_runs: None = None,
    dag: DAG | None = None,
    dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
) -> None:
    """
    Clear a set of task instances, but make sure the running ones get killed.

    Also sets Dagrun's `state` to QUEUED and `start_date` to the time of
    execution. But only for finished DRs (SUCCESS and FAILED).
    Doesn't clear DR's `state` and `start_date` for running
    DRs (QUEUED and RUNNING) because clearing the state for an already
    running DR is redundant and clearing `start_date` affects DR's duration.

    :param tis: a list of task instances
    :param session: current session
    :param dag_run_state: state to set finished DagRuns to.
        If set to False, DagRuns state will not be changed.
    :param dag: DAG object
    :param activate_dag_runs: Deprecated parameter, do not pass
    """
    job_ids = []
    # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
    task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
    )
    dag_bag = DagBag(read_dags_from_db=True)
    from airflow.models.taskinstancehistory import TaskInstanceHistory

    for ti in tis:
        TaskInstanceHistory.record_ti(ti, session)
        if ti.state == TaskInstanceState.RUNNING:
            if ti.job_id:
                # If a task is cleared when running, set its state to RESTARTING so that
                # the task is terminated and becomes eligible for retry.
                ti.state = TaskInstanceState.RESTARTING
                job_ids.append(ti.job_id)
        else:
            ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
            task_id = ti.task_id
            if ti_dag and ti_dag.has_task(task_id):
                task = ti_dag.get_task(task_id)
                ti.refresh_from_task(task)
                if TYPE_CHECKING:
                    assert ti.task
                ti.max_tries = ti.try_number + task.retries
            else:
                # Ignore errors when updating max_tries if the DAG or
                # task are not found since database records could be
                # outdated. We make max_tries the maximum value of its
                # original max_tries or the last attempted try number.
                ti.max_tries = max(ti.max_tries, ti.try_number)
            ti.state = None
            ti.external_executor_id = None
            ti.clear_next_method_args()
            session.merge(ti)

        task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)

    if task_id_by_key:
        # Clear all reschedules related to the ti to clear
        # This is an optimization for the common case where all tis are for a small number
        # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
        # run_id, try_number, map_index, and task_id to construct the where clause in a
        # hierarchical manner. This speeds up the delete statement by more than 40x for
        # large number of tis (50k+).
        conditions = or_(
            and_(
                TR.dag_id == dag_id,
                or_(
                    and_(
                        TR.run_id == run_id,
                        or_(
                            and_(
                                TR.map_index == map_index,
                                or_(
                                    and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
                                    for try_number, task_ids in task_tries.items()
                                ),
                            )
                            for map_index, task_tries in map_indexes.items()
                        ),
                    )
                    for run_id, map_indexes in run_ids.items()
                ),
            )
            for dag_id, run_ids in task_id_by_key.items()
        )

        delete_qry = TR.__table__.delete().where(conditions)
        session.execute(delete_qry)

    if job_ids:
        from airflow.jobs.job import Job

        session.execute(update(Job).where(Job.id.in_(job_ids)).values(state=JobState.RESTARTING))

    if activate_dag_runs is not None:
        warnings.warn(
            "`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
            "Please use `dag_run_state`",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        if not activate_dag_runs:
            dag_run_state = False

    if dag_run_state is not False and tis:
        from airflow.models.dagrun import DagRun  # Avoid circular import

        run_ids_by_dag_id = defaultdict(set)
        for instance in tis:
            run_ids_by_dag_id[instance.dag_id].add(instance.run_id)

        drs = (
            session.query(DagRun)
            .filter(
                or_(
                    and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
                    for dag_id, run_ids in run_ids_by_dag_id.items()
                )
            )
            .all()
        )
        dag_run_state = DagRunState(dag_run_state)  # Validate the state value.
        for dr in drs:
            if dr.state in State.finished_dr_states:
                dr.state = dag_run_state
                dr.start_date = timezone.utcnow()
                if dag_run_state == DagRunState.QUEUED:
                    dr.last_scheduling_decision = None
                    dr.start_date = None
                    dr.clear_number += 1
    session.flush()
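

# Illustrative usage sketch (hypothetical ``dag``/``dag_run``/``session``
# objects): clearing a run's task instances so the scheduler picks them up
# again. RUNNING instances are moved to RESTARTING rather than reset outright,
# and a finished DagRun is re-queued because ``dag_run_state`` defaults to
# ``DagRunState.QUEUED``:
#
#     tis = dag_run.get_task_instances(session=session)
#     clear_task_instances(tis, session, dag=dag)
#     session.commit()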


@internal_api_call
@provide_session
def _xcom_pull(
    *,
    ti,
    task_ids: str | Iterable[str] | None = None,
    dag_id: str | None = None,
    key: str = XCOM_RETURN_KEY,
    include_prior_dates: bool = False,
    session: Session = NEW_SESSION,
    map_indexes: int | Iterable[int] | None = None,
    default: Any = None,
) -> Any:
    """
    Pull XComs that optionally meet certain criteria.

    :param key: A key for the XCom. If provided, only XComs with matching
        keys will be returned. The default key is ``'return_value'``, also
        available as constant ``XCOM_RETURN_KEY``. This key is automatically
        given to XComs returned by tasks (as opposed to being pushed
        manually). To remove the filter, pass *None*.
    :param task_ids: Only XComs from tasks with matching ids will be
        pulled. Pass *None* to remove the filter.
    :param dag_id: If provided, only pulls XComs from this DAG. If *None*
        (default), the DAG of the calling task is used.
    :param map_indexes: If provided, only pull XComs with matching indexes.
        If *None* (default), this is inferred from the task(s) being pulled
        (see below for details).
    :param include_prior_dates: If False, only XComs from the current
        execution_date are returned. If *True*, XComs from previous dates
        are returned as well.

    When pulling a single task (``task_id`` is *None* or a str) without
    specifying ``map_indexes``, the return value is inferred from whether
    the specified task is mapped. If not, the value from the single task
    instance is returned. If the task to pull is mapped, an iterator (not a
    list) yielding XComs from mapped task instances is returned. In either
    case, ``default`` (*None* if not specified) is returned if no matching
    XComs are found.

    When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
    a non-str iterable), a list of matching XComs is returned. Elements in
    the list are ordered by item ordering in ``task_id`` and ``map_index``.
    """
    if dag_id is None:
        dag_id = ti.dag_id

    query = XCom.get_many(
        key=key,
        run_id=ti.run_id,
        dag_ids=dag_id,
        task_ids=task_ids,
        map_indexes=map_indexes,
        include_prior_dates=include_prior_dates,
        session=session,
    )

    # NOTE: Since we're only fetching the value field and not the whole
    # class, the @recreate annotation does not kick in. Therefore we need to
    # call XCom.deserialize_value() manually.

    # We are only pulling a single task.
    if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
        first = query.with_entities(
            XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
        ).first()
        if first is None:  # No matching XCom at all.
            return default
        if map_indexes is not None or first.map_index < 0:
            return XCom.deserialize_value(first)
        return LazyXComSelectSequence.from_select(
            query.with_entities(XCom.value).order_by(None).statement,
            order_by=[XCom.map_index],
            session=session,
        )

    # At this point either task_ids or map_indexes is explicitly multi-value.
    # Order return values to match task_ids and map_indexes ordering.
    ordering = []
    if task_ids is None or isinstance(task_ids, str):
        ordering.append(XCom.task_id)
    elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
        ordering.append(case(task_id_whens, value=XCom.task_id))
    else:
        ordering.append(XCom.task_id)
    if map_indexes is None or isinstance(map_indexes, int):
        ordering.append(XCom.map_index)
    elif isinstance(map_indexes, range):
        order = XCom.map_index
        if map_indexes.step < 0:
            order = order.desc()
        ordering.append(order)
    elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
        ordering.append(case(map_index_whens, value=XCom.map_index))
    else:
        ordering.append(XCom.map_index)
    return LazyXComSelectSequence.from_select(
        query.with_entities(XCom.value).order_by(None).statement,
        order_by=ordering,
        session=session,
    )
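

# Illustrative behavior sketch (not part of the upstream module, hypothetical
# task ids): inside a running task, ``ti.xcom_pull`` funnels into this helper.
# A single unmapped upstream yields the bare deserialized value; a mapped
# upstream yields a lazy sequence ordered by map_index; multiple task_ids yield
# a lazy sequence ordered to match the ``task_ids`` argument:
#
#     ti.xcom_pull(task_ids="extract")                # -> deserialized value
#     ti.xcom_pull(task_ids="mapped_task")            # -> LazyXComSelectSequence
#     ti.xcom_pull(task_ids=["a", "b"], key="stats")  # -> values ordered as [a, b]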


def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
    """
    Whether a value can be used for task mapping.

    We only allow collections with guaranteed ordering, but exclude character
    sequences since that's usually not what users would expect to be mappable.
    """
    if not isinstance(value, (collections.abc.Sequence, dict)):
        return False
    if isinstance(value, (bytearray, bytes, str)):
        return False
    return True
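

# Expected behavior (illustrative): ordered containers pass, character
# sequences and unordered collections do not.
#
#     _is_mappable_value([1, 2, 3])   # True
#     _is_mappable_value({"a": 1})    # True (dicts preserve insertion order)
#     _is_mappable_value("abc")       # False (str is a Sequence but excluded)
#     _is_mappable_value({1, 2})      # False (sets have no guaranteed order)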


def _creator_note(val):
    """Creator for the ``note`` association proxy."""
    if isinstance(val, str):
        return TaskInstanceNote(content=val)
    elif isinstance(val, dict):
        return TaskInstanceNote(**val)
    else:
        return TaskInstanceNote(*val)
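

# Illustrative sketch (hypothetical values): the association proxy accepts
# several shapes when assigning a note to a task instance, each routed through
# the creator above.
#
#     ti.note = "manually re-run"                # -> TaskInstanceNote(content=...)
#     ti.note = {"content": "x", "user_id": 42}  # -> TaskInstanceNote(**val)
#     ti.note = ("x", 42)                        # -> TaskInstanceNote(*val)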


def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
    """
    Execute the task (optionally with a timeout) and push XCom results.

    :param task_instance: the task instance
    :param context: Jinja2 context
    :param task_orig: origin task

    :meta private:
    """
    from airflow.models.mappedoperator import MappedOperator

    task_to_execute = task_instance.task

    if TYPE_CHECKING:
        assert task_to_execute

    if isinstance(task_to_execute, MappedOperator):
        raise AirflowException("MappedOperator cannot be executed.")

    # If the task has been deferred and is being executed due to a trigger,
    # then we need to pick the right method to come back to, otherwise
    # we go for the default execute
    execute_callable_kwargs: dict[str, Any] = {}
    execute_callable: Callable
    if task_instance.next_method:
        if task_instance.next_method == "execute":
            if not task_instance.next_kwargs:
                task_instance.next_kwargs = {}
            task_instance.next_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel
        execute_callable = task_to_execute.resume_execution
        execute_callable_kwargs["next_method"] = task_instance.next_method
        execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
    else:
        execute_callable = task_to_execute.execute
        if execute_callable.__name__ == "execute":
            execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel

    def _execute_callable(context: Context, **execute_callable_kwargs):
        try:
            # Print a marker for log grouping of details before task execution
            log.info("::endgroup::")
            return ExecutionCallableRunner(
                execute_callable,
                context_get_outlet_events(context),
                logger=log,
            ).run(context=context, **execute_callable_kwargs)
        except SystemExit as e:
            # Handle only successful cases here. Failure cases will be handled
            # higher up in the exception chain.
            if e.code is not None and e.code != 0:
                raise
            return None

    # If a timeout is specified for the task, make it fail
    # if it goes beyond
    if task_to_execute.execution_timeout:
        # If we are coming in with a next_method (i.e. from a deferral),
        # calculate the timeout from our start_date.
        if task_instance.next_method and task_instance.start_date:
            timeout_seconds = (
                task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date)
            ).total_seconds()
        else:
            timeout_seconds = task_to_execute.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = _execute_callable(context=context, **execute_callable_kwargs)
        except AirflowTaskTimeout:
            task_to_execute.on_kill()
            raise
    else:
        result = _execute_callable(context=context, **execute_callable_kwargs)
    cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session()
    with cm as session_or_null:
        if task_to_execute.do_xcom_push:
            xcom_value = result
        else:
            xcom_value = None
        if xcom_value is not None:  # If the task returns a result, push an XCom containing it.
            if task_to_execute.multiple_outputs:
                if not isinstance(xcom_value, Mapping):
                    raise AirflowException(
                        f"Returned output was type {type(xcom_value)} "
                        "expected dictionary for multiple_outputs"
                    )
                for key in xcom_value.keys():
                    if not isinstance(key, str):
                        raise AirflowException(
                            "Returned dictionary keys must be strings when using "
                            f"multiple_outputs, found {key} ({type(key)}) instead"
                        )
                for key, value in xcom_value.items():
                    task_instance.xcom_push(key=key, value=value, session=session_or_null)
            task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
        if TYPE_CHECKING:
            assert task_orig.dag
        _record_task_map_for_downstreams(
            task_instance=task_instance,
            task=task_orig,
            dag=task_orig.dag,
            value=xcom_value,
            session=session_or_null,
        )
    return result
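

# Worked example of the resume-from-deferral timeout above (hypothetical
# numbers): with ``execution_timeout=timedelta(minutes=10)`` and a task that
# deferred 7 minutes ago (so ``start_date`` is 7 minutes in the past), the
# wrapper runs with
#
#     timeout_seconds = (timedelta(minutes=10) - timedelta(minutes=7)).total_seconds()  # 180.0
#
# i.e. time spent waiting on the trigger counts against the budget. If the
# deferral already consumed the full timeout, ``timeout_seconds`` <= 0 and
# AirflowTaskTimeout is raised before the callable runs at all.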


def _set_ti_attrs(target, source, include_dag_run=False):
    # Fields ordered per model definition
    target.start_date = source.start_date
    target.end_date = source.end_date
    target.duration = source.duration
    target.state = source.state
    target.try_number = source.try_number
    target.max_tries = source.max_tries
    target.hostname = source.hostname
    target.unixname = source.unixname
    target.job_id = source.job_id
    target.pool = source.pool
    target.pool_slots = source.pool_slots or 1
    target.queue = source.queue
    target.priority_weight = source.priority_weight
    target.operator = source.operator
    target.custom_operator_name = source.custom_operator_name
    target.queued_dttm = source.queued_dttm
    target.queued_by_job_id = source.queued_by_job_id
    target.pid = source.pid
    target.executor = source.executor
    target.executor_config = source.executor_config
    target.external_executor_id = source.external_executor_id
    target.trigger_id = source.trigger_id
    target.next_method = source.next_method
    target.next_kwargs = source.next_kwargs

    if include_dag_run:
        target.execution_date = source.execution_date
        target.dag_run.id = source.dag_run.id
        target.dag_run.dag_id = source.dag_run.dag_id
        target.dag_run.queued_at = source.dag_run.queued_at
        target.dag_run.execution_date = source.dag_run.execution_date
        target.dag_run.start_date = source.dag_run.start_date
        target.dag_run.end_date = source.dag_run.end_date
        target.dag_run.state = source.dag_run.state
        target.dag_run.run_id = source.dag_run.run_id
        target.dag_run.creating_job_id = source.dag_run.creating_job_id
        target.dag_run.external_trigger = source.dag_run.external_trigger
        target.dag_run.run_type = source.dag_run.run_type
        target.dag_run.conf = source.dag_run.conf
        target.dag_run.data_interval_start = source.dag_run.data_interval_start
        target.dag_run.data_interval_end = source.dag_run.data_interval_end
        target.dag_run.last_scheduling_decision = source.dag_run.last_scheduling_decision
        target.dag_run.dag_hash = source.dag_run.dag_hash
        target.dag_run.updated_at = source.dag_run.updated_at
        target.dag_run.log_template_id = source.dag_run.log_template_id


def _refresh_from_db(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    session: Session | None = None,
    lock_for_update: bool = False,
) -> None:
    """
    Refresh the task instance from the database based on the primary key.

    :param task_instance: the task instance
    :param session: SQLAlchemy ORM Session
    :param lock_for_update: if True, indicates that the database should
        lock the TaskInstance (issuing a FOR UPDATE clause) until the
        session is committed.

    :meta private:
    """
    if not InternalApiConfig.get_use_internal_api():
        if session and task_instance in session:
            session.refresh(task_instance, TaskInstance.__mapper__.column_attrs.keys())

    ti = TaskInstance.get_task_instance(
        dag_id=task_instance.dag_id,
        task_id=task_instance.task_id,
        run_id=task_instance.run_id,
        map_index=task_instance.map_index,
        lock_for_update=lock_for_update,
        session=session,
    )

    if ti:
        from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

        include_dag_run = isinstance(ti, TaskInstancePydantic)
        # in case of internal API, what we get is a TaskInstancePydantic value, and we are supposed
        # to also update dag_run information as it might not be available. We cannot always do it in
        # case ti is TaskInstance, because it might be detached / not loaded yet and dag_run might
        # not be available.
        _set_ti_attrs(task_instance, ti, include_dag_run=include_dag_run)
    else:
        task_instance.state = None


def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
    """
    Set task instance duration.

    :param task_instance: the task instance
    :meta private:
    """
    if task_instance.end_date and task_instance.start_date:
        task_instance.duration = (task_instance.end_date - task_instance.start_date).total_seconds()
    else:
        task_instance.duration = None
    log.debug("Task Duration set to %s", task_instance.duration)


def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]:
    """
    Return task instance tags.

    :param task_instance: the task instance
    :meta private:
    """
    return prune_dict({"dag_id": task_instance.dag_id, "task_id": task_instance.task_id})


def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
    """
    Unset next_method and next_kwargs so that any retries don't reuse them.

    :param task_instance: the task instance
    :meta private:
    """
    log.debug("Clearing next_method and next_kwargs.")

    task_instance.next_method = None
    task_instance.next_kwargs = None
  833. @internal_api_call
  834. def _get_template_context(
  835. *,
  836. task_instance: TaskInstance | TaskInstancePydantic,
  837. dag: DAG,
  838. session: Session | None = None,
  839. ignore_param_exceptions: bool = True,
  840. ) -> Context:
  841. """
  842. Return TI Context.
  843. :param task_instance: the task instance for the task
  844. :param dag: dag for the task
  845. :param session: SQLAlchemy ORM Session
  846. :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
  847. :meta private:
  848. """
  849. # Do not use provide_session here -- it expunges everything on exit!
  850. if not session:
  851. session = settings.Session()
  852. from airflow import macros
  853. from airflow.models.abstractoperator import NotMapped
  854. integrate_macros_plugins()
  855. task = task_instance.task
  856. if TYPE_CHECKING:
  857. assert task_instance.task
  858. assert task
  859. assert task.dag
  860. if task.dag.__class__ is AttributeRemoved:
  861. task.dag = dag # required after deserialization
  862. dag_run = task_instance.get_dagrun(session)
  863. data_interval = dag.get_run_data_interval(dag_run)
  864. validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
  865. logical_date: DateTime = timezone.coerce_datetime(task_instance.execution_date)
  866. ds = logical_date.strftime("%Y-%m-%d")
  867. ds_nodash = ds.replace("-", "")
  868. ts = logical_date.isoformat()
  869. ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
  870. ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
  871. @cache # Prevent multiple database access.
  872. def _get_previous_dagrun_success() -> DagRun | None:
  873. return task_instance.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
  874. def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
  875. dagrun = _get_previous_dagrun_success()
  876. if dagrun is None:
  877. return None
  878. return dag.get_run_data_interval(dagrun)
  879. def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
  880. data_interval = _get_previous_dagrun_data_interval_success()
  881. if data_interval is None:
  882. return None
  883. return data_interval.start
  884. def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
  885. data_interval = _get_previous_dagrun_data_interval_success()
  886. if data_interval is None:
  887. return None
  888. return data_interval.end
  889. def get_prev_start_date_success() -> pendulum.DateTime | None:
  890. dagrun = _get_previous_dagrun_success()
  891. if dagrun is None:
  892. return None
  893. return timezone.coerce_datetime(dagrun.start_date)
  894. def get_prev_end_date_success() -> pendulum.DateTime | None:
  895. dagrun = _get_previous_dagrun_success()
  896. if dagrun is None:
  897. return None
  898. return timezone.coerce_datetime(dagrun.end_date)
  899. @cache
  900. def get_yesterday_ds() -> str:
  901. return (logical_date - timedelta(1)).strftime("%Y-%m-%d")
  902. def get_yesterday_ds_nodash() -> str:
  903. return get_yesterday_ds().replace("-", "")
  904. @cache
  905. def get_tomorrow_ds() -> str:
  906. return (logical_date + timedelta(1)).strftime("%Y-%m-%d")
  907. def get_tomorrow_ds_nodash() -> str:
  908. return get_tomorrow_ds().replace("-", "")

    @cache
    def get_next_execution_date() -> pendulum.DateTime | None:
        # For manually triggered dagruns that aren't run on a schedule,
        # the "next" execution date doesn't make sense, and should be set
        # to execution date for consistency with how execution_date is set
        # for manually triggered tasks, i.e. triggered_date == execution_date.
        if dag_run.external_trigger:
            return logical_date
        if dag is None:
            return None
        next_info = dag.next_dagrun_info(data_interval, restricted=False)
        if next_info is None:
            return None
        return timezone.coerce_datetime(next_info.logical_date)

    def get_next_ds() -> str | None:
        execution_date = get_next_execution_date()
        if execution_date is None:
            return None
        return execution_date.strftime("%Y-%m-%d")

    def get_next_ds_nodash() -> str | None:
        ds = get_next_ds()
        if ds is None:
            return ds
        return ds.replace("-", "")

    @cache
    def get_prev_execution_date():
        # For manually triggered dagruns that aren't run on a schedule,
        # the "previous" execution date doesn't make sense, and should be set
        # to execution date for consistency with how execution_date is set
        # for manually triggered tasks, i.e. triggered_date == execution_date.
        if dag_run.external_trigger:
            return logical_date
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RemovedInAirflow3Warning)
            return dag.previous_schedule(logical_date)

    @cache
    def get_prev_ds() -> str | None:
        execution_date = get_prev_execution_date()
        if execution_date is None:
            return None
        return execution_date.strftime("%Y-%m-%d")

    def get_prev_ds_nodash() -> str | None:
        prev_ds = get_prev_ds()
        if prev_ds is None:
            return None
        return prev_ds.replace("-", "")

    def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]:
        if TYPE_CHECKING:
            assert session is not None

        # The dag_run may not be attached to the session anymore since the
        # code base is over-zealous with use of session.expunge_all().
        # Re-attach it if we get called.
        nonlocal dag_run
        if dag_run not in session:
            dag_run = session.merge(dag_run, load=False)
        dataset_events = dag_run.consumed_dataset_events
        triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list)
        for event in dataset_events:
            if event.dataset:
                triggering_events[event.dataset.uri].append(event)
        return triggering_events

    try:
        expanded_ti_count: int | None = task.get_mapped_ti_count(task_instance.run_id, session=session)
    except NotMapped:
        expanded_ti_count = None

    # NOTE: If you add to this dict, make sure to also update the following:
    # * Context in airflow/utils/context.pyi
    # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
    # * Table in docs/apache-airflow/templates-ref.rst
    context: dict[str, Any] = {
        "conf": conf,
        "dag": dag,
        "dag_run": dag_run,
        "data_interval_end": timezone.coerce_datetime(data_interval.end),
        "data_interval_start": timezone.coerce_datetime(data_interval.start),
        "outlet_events": OutletEventAccessors(),
        "ds": ds,
        "ds_nodash": ds_nodash,
        "execution_date": logical_date,
        "expanded_ti_count": expanded_ti_count,
        "inlets": task.inlets,
        "inlet_events": InletEventsAccessors(task.inlets, session=session),
        "logical_date": logical_date,
        "macros": macros,
        "map_index_template": task.map_index_template,
        "next_ds": get_next_ds(),
        "next_ds_nodash": get_next_ds_nodash(),
        "next_execution_date": get_next_execution_date(),
        "outlets": task.outlets,
        "params": validated_params,
        "prev_data_interval_start_success": get_prev_data_interval_start_success(),
        "prev_data_interval_end_success": get_prev_data_interval_end_success(),
        "prev_ds": get_prev_ds(),
        "prev_ds_nodash": get_prev_ds_nodash(),
        "prev_execution_date": get_prev_execution_date(),
        "prev_execution_date_success": task_instance.get_previous_execution_date(
            state=DagRunState.SUCCESS,
            session=session,
        ),
        "prev_start_date_success": get_prev_start_date_success(),
        "prev_end_date_success": get_prev_end_date_success(),
        "run_id": task_instance.run_id,
        "task": task,
        "task_instance": task_instance,
        "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
        "test_mode": task_instance.test_mode,
        "ti": task_instance,
        "tomorrow_ds": get_tomorrow_ds(),
        "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
        "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
        "ts": ts,
        "ts_nodash": ts_nodash,
        "ts_nodash_with_tz": ts_nodash_with_tz,
        "var": {
            "json": VariableAccessor(deserialize_json=True),
            "value": VariableAccessor(deserialize_json=False),
        },
        "conn": ConnectionAccessor(),
        "yesterday_ds": get_yesterday_ds(),
        "yesterday_ds_nodash": get_yesterday_ds_nodash(),
    }
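    # These keys become the Jinja template variables, e.g. "{{ ds }}" or
    # "{{ ti.try_number }}" in a templated operator field.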
    # Mypy doesn't like turning existing dicts into a TypedDict -- and we "lie" in the type stub to say it
    # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
    return Context(context)  # type: ignore


def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic):
    """
    Whether the task instance is eligible for retry.

    :param task_instance: the task instance

    :meta private:
    """
    if task_instance.state == TaskInstanceState.RESTARTING:
        # If a task is cleared when running, it goes into RESTARTING state and is always
        # eligible for retry
        return True
    if not getattr(task_instance, "task", None):
        # Couldn't load the task, don't know number of retries, guess:
        return task_instance.try_number <= task_instance.max_tries

    if TYPE_CHECKING:
        assert task_instance.task

    return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries
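# Illustrative outcome (hypothetical numbers): a task defined with retries=3 starts
# with max_tries=3, so try_number values up to 3 are eligible for retry; with
# retries=0 the `task.retries and ...` expression short-circuits to falsy.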


@provide_session
@internal_api_call
def _handle_failure(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    error: None | str | BaseException,
    session: Session,
    test_mode: bool | None = None,
    context: Context | None = None,
    force_fail: bool = False,
    fail_stop: bool = False,
) -> None:
    """
    Handle Failure for a task instance.

    :param task_instance: the task instance
    :param error: if specified, log the specific exception if thrown
    :param session: SQLAlchemy ORM Session
    :param test_mode: doesn't record success or failure in the DB if True
    :param context: Jinja2 context
    :param force_fail: if True, task does not retry

    :meta private:
    """
    if test_mode is None:
        test_mode = task_instance.test_mode
    task_instance = _coalesce_to_orm_ti(ti=task_instance, session=session)
    failure_context = TaskInstance.fetch_handle_failure_context(
        ti=task_instance,  # type: ignore[arg-type]
        error=error,
        test_mode=test_mode,
        context=context,
        force_fail=force_fail,
        session=session,
        fail_stop=fail_stop,
    )

    _log_state(task_instance=task_instance, lead_msg="Immediate failure requested. " if force_fail else "")
    if (
        failure_context["task"]
        and failure_context["email_for_state"](failure_context["task"])
        and failure_context["task"].email
    ):
        try:
            task_instance.email_alert(error, failure_context["task"])
        except Exception:
            log.exception("Failed to send email to: %s", failure_context["task"].email)

    if failure_context["callbacks"] and failure_context["context"]:
        _run_finished_callback(
            callbacks=failure_context["callbacks"],
            context=failure_context["context"],
        )

    if not test_mode:
        TaskInstance.save_to_db(task_instance, session)

    with Trace.start_span_from_taskinstance(ti=task_instance) as span:
        # ---- error info ----
        span.set_attribute("error", "true")
        span.set_attribute("error_msg", str(error))
        span.set_attribute("context", context)
        span.set_attribute("force_fail", force_fail)
        # ---- common info ----
        span.set_attribute("category", "DAG runs")
        span.set_attribute("task_id", task_instance.task_id)
        span.set_attribute("dag_id", task_instance.dag_id)
        span.set_attribute("state", task_instance.state)
        span.set_attribute("start_date", str(task_instance.start_date))
        span.set_attribute("end_date", str(task_instance.end_date))
        span.set_attribute("duration", task_instance.duration)
        span.set_attribute("executor_config", str(task_instance.executor_config))
        span.set_attribute("execution_date", str(task_instance.execution_date))
        span.set_attribute("hostname", task_instance.hostname)
        if isinstance(task_instance, TaskInstance):
            span.set_attribute("log_url", task_instance.log_url)
        span.set_attribute("operator", str(task_instance.operator))


def _refresh_from_task(
    *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, pool_override: str | None = None
) -> None:
    """
    Copy common attributes from the given task.

    :param task_instance: the task instance
    :param task: The task object to copy from
    :param pool_override: Use the pool_override instead of task's pool

    :meta private:
    """
    task_instance.task = task
    task_instance.queue = task.queue
    task_instance.pool = pool_override or task.pool
    task_instance.pool_slots = task.pool_slots
    with contextlib.suppress(Exception):
        # This method is called from different places, and sometimes the TI is not fully initialized
        task_instance.priority_weight = task_instance.task.weight_rule.get_weight(
            task_instance  # type: ignore[arg-type]
        )
    task_instance.run_as_user = task.run_as_user
    # Do not set max_tries to task.retries here because max_tries is a cumulative
    # value that needs to be stored in the db.
    task_instance.executor = task.executor
    task_instance.executor_config = task.executor_config
    task_instance.operator = task.task_type
    task_instance.custom_operator_name = getattr(task, "custom_operator_name", None)
    # Re-apply cluster policy here so that task defaults do not overload previous data
    task_instance_mutation_hook(task_instance)


@internal_api_call
@provide_session
def _record_task_map_for_downstreams(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    task: Operator,
    dag: DAG,
    value: Any,
    session: Session,
) -> None:
    """
    Record the task map for downstream tasks.

    :param task_instance: the task instance
    :param task: The task object
    :param dag: the dag associated with the task
    :param value: The value
    :param session: SQLAlchemy ORM Session

    :meta private:
    """
    from airflow.models.mappedoperator import MappedOperator

    if task.dag.__class__ is AttributeRemoved:
        task.dag = dag  # required after deserialization

    if next(task.iter_mapped_dependants(), None) is None:  # No mapped dependants, no need to validate.
        return

    # TODO: We don't push TaskMap for mapped task instances because it's not
    # currently possible for a downstream to depend on one individual mapped
    # task instance. This will change when we implement task mapping inside
    # a mapped task group, and we'll need to further analyze the case.
    if isinstance(task, MappedOperator):
        return
    if value is None:
        raise XComForMappingNotPushed()
    if not _is_mappable_value(value):
        raise UnmappableXComTypePushed(value)
    task_map = TaskMap.from_task_instance_xcom(task_instance, value)
    max_map_length = conf.getint("core", "max_map_length", fallback=1024)
    if task_map.length > max_map_length:
        raise UnmappableXComLengthPushed(value, max_map_length)
    session.merge(task_map)
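    # Illustrative behavior: pushing ["a", "b", "c"] records a TaskMap of length 3
    # for downstream expansion, pushing None raises XComForMappingNotPushed, and a
    # non-mappable value such as 42 raises UnmappableXComTypePushed. The length
    # ceiling comes from the [core] max_map_length option (1024 unless overridden).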


def _get_previous_dagrun(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    state: DagRunState | None = None,
    session: Session | None = None,
) -> DagRun | None:
    """
    Return the DagRun that ran prior to this task instance's DagRun.

    :param task_instance: the task instance
    :param state: If passed, it only takes into account instances of a specific state.
    :param session: SQLAlchemy ORM Session.

    :meta private:
    """
    if TYPE_CHECKING:
        assert task_instance.task

    dag = task_instance.task.dag
    if dag is None:
        return None

    dr = task_instance.get_dagrun(session=session)
    dr.dag = dag

    from airflow.models.dagrun import DagRun  # Avoid circular import

    # We always ignore schedule in dagrun lookup when `state` is given
    # or the DAG is never scheduled. For legacy reasons, when
    # `catchup=True`, we use `get_previous_scheduled_dagrun` unless
    # `ignore_schedule` is `True`.
    ignore_schedule = state is not None or not dag.timetable.can_be_scheduled
    if dag.catchup is True and not ignore_schedule:
        last_dagrun = DagRun.get_previous_scheduled_dagrun(dr.id, session=session)
    else:
        last_dagrun = DagRun.get_previous_dagrun(dag_run=dr, session=session, state=state)

    if last_dagrun:
        return last_dagrun

    return None


def _get_previous_execution_date(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    state: DagRunState | None,
    session: Session,
) -> pendulum.DateTime | None:
    """
    Get execution date from property previous_ti_success.

    :param task_instance: the task instance
    :param session: SQLAlchemy ORM Session
    :param state: If passed, it only takes into account instances of a specific state.

    :meta private:
    """
    log.debug("previous_execution_date was called")
    prev_ti = task_instance.get_previous_ti(state=state, session=session)
    return pendulum.instance(prev_ti.execution_date) if prev_ti and prev_ti.execution_date else None


def _get_previous_start_date(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    state: DagRunState | None,
    session: Session,
) -> pendulum.DateTime | None:
    """
    Return the start date from property previous_ti_success.

    :param task_instance: the task instance
    :param state: If passed, it only takes into account instances of a specific state.
    :param session: SQLAlchemy ORM Session
    """
    log.debug("previous_start_date was called")
    prev_ti = task_instance.get_previous_ti(state=state, session=session)
    # prev_ti may not exist and prev_ti.start_date may be None.
    return pendulum.instance(prev_ti.start_date) if prev_ti and prev_ti.start_date else None


def _email_alert(
    *, task_instance: TaskInstance | TaskInstancePydantic, exception, task: BaseOperator
) -> None:
    """
    Send alert email with exception information.

    :param task_instance: the task instance
    :param exception: the exception
    :param task: task related to the exception

    :meta private:
    """
    subject, html_content, html_content_err = task_instance.get_email_subject_content(exception, task=task)
    if TYPE_CHECKING:
        assert task.email
    try:
        send_email(task.email, subject, html_content)
    except Exception:
        send_email(task.email, subject, html_content_err)


def _get_email_subject_content(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    exception: BaseException,
    task: BaseOperator | None = None,
) -> tuple[str, str, str]:
    """
    Get the email subject content for exceptions.

    :param task_instance: the task instance
    :param exception: the exception sent in the email
    :param task: the task related to the exception, if available

    :meta private:
    """
    # For a ti from DB (without ti.task), return the default value
    if task is None:
        task = getattr(task_instance, "task")
    use_default = task is None
    exception_html = str(exception).replace("\n", "<br>")

    default_subject = "Airflow alert: {{ti}}"
    # For reporting purposes, we report based on 1-indexed,
    # not 0-indexed lists (i.e. Try 1 instead of
    # Try 0 for the first attempt).
    default_html_content = (
        "Try {{try_number}} out of {{max_tries + 1}}<br>"
        "Exception:<br>{{exception_html}}<br>"
        'Log: <a href="{{ti.log_url}}">Link</a><br>'
        "Host: {{ti.hostname}}<br>"
        'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
    )

    default_html_content_err = (
        "Try {{try_number}} out of {{max_tries + 1}}<br>"
        "Exception:<br>Failed attempt to attach error logs<br>"
        'Log: <a href="{{ti.log_url}}">Link</a><br>'
        "Host: {{ti.hostname}}<br>"
        'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
    )

    additional_context: dict[str, Any] = {
        "exception": exception,
        "exception_html": exception_html,
        "try_number": task_instance.try_number,
        "max_tries": task_instance.max_tries,
    }

    if use_default:
        default_context = {"ti": task_instance, **additional_context}
        jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
        )
        subject = jinja_env.from_string(default_subject).render(**default_context)
        html_content = jinja_env.from_string(default_html_content).render(**default_context)
        html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)
    else:
        if TYPE_CHECKING:
            assert task_instance.task
        # Use the DAG's get_template_env() to set force_sandboxed. Don't add
        # the flag to the function on task object -- that function can be
        # overridden, and adding a flag breaks backward compatibility.
        dag = task_instance.task.get_dag()
        if dag:
            jinja_env = dag.get_template_env(force_sandboxed=True)
        else:
            jinja_env = SandboxedEnvironment(cache_size=0)
        jinja_context = task_instance.get_template_context()
        context_merge(jinja_context, additional_context)

        def render(key: str, content: str) -> str:
            if conf.has_option("email", key):
                path = conf.get_mandatory_value("email", key)
                try:
                    with open(path) as f:
                        content = f.read()
                except FileNotFoundError:
                    log.warning("Could not find email template file '%s'. Using defaults...", path)
                except OSError:
                    log.exception("Error while using email template %s. Using defaults...", path)
            return render_template_to_string(jinja_env.from_string(content), jinja_context)
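        # If the [email] section defines subject_template / html_content_template,
        # `render` swaps the default string for that file's contents before
        # rendering it against the task instance's template context.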
        subject = render("subject_template", default_subject)
        html_content = render("html_content_template", default_html_content)
        html_content_err = render("html_content_template", default_html_content_err)

    return subject, html_content, html_content_err


def _run_finished_callback(
    *,
    callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
    context: Context,
) -> None:
    """
    Run callback after task finishes.

    :param callbacks: callbacks to run
    :param context: callbacks context

    :meta private:
    """
    if callbacks:
        callbacks = callbacks if isinstance(callbacks, list) else [callbacks]

        def get_callback_representation(callback: TaskStateChangeCallback) -> Any:
            with contextlib.suppress(AttributeError):
                return callback.__name__
            with contextlib.suppress(AttributeError):
                return callback.__class__.__name__
            return callback

        for idx, callback in enumerate(callbacks):
            callback_repr = get_callback_representation(callback)
            log.info("Executing callback at index %d: %s", idx, callback_repr)
            try:
                callback(context)
            except Exception:
                log.exception("Error in callback at index %d: %s", idx, callback_repr)


def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg: str = "") -> None:
    """
    Log task state.

    :param task_instance: the task instance
    :param lead_msg: lead message

    :meta private:
    """
    params = [
        lead_msg,
        str(task_instance.state).upper(),
        task_instance.dag_id,
        task_instance.task_id,
        task_instance.run_id,
    ]
    message = "%sMarking task as %s. dag_id=%s, task_id=%s, run_id=%s, "
    if task_instance.map_index >= 0:
        params.append(task_instance.map_index)
        message += "map_index=%d, "
    message += "execution_date=%s, start_date=%s, end_date=%s"
    log.info(
        message,
        *params,
        _date_or_empty(task_instance=task_instance, attr="execution_date"),
        _date_or_empty(task_instance=task_instance, attr="start_date"),
        _date_or_empty(task_instance=task_instance, attr="end_date"),
        stacklevel=2,
    )


def _date_or_empty(*, task_instance: TaskInstance | TaskInstancePydantic, attr: str) -> str:
    """
    Fetch a date attribute, or an empty string if it does not exist.

    :param task_instance: the task instance
    :param attr: the attribute name

    :meta private:
    """
    result: datetime | None = getattr(task_instance, attr, None)
    return result.strftime("%Y%m%dT%H%M%S") if result else ""
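# Illustrative value: a start_date of datetime(2024, 1, 2, 3, 4, 5) is rendered
# as "20240102T030405"; a missing or None attribute yields "".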


def _get_previous_ti(
    *,
    task_instance: TaskInstance | TaskInstancePydantic,
    session: Session,
    state: DagRunState | None = None,
) -> TaskInstance | TaskInstancePydantic | None:
    """
    Get task instance for the task that ran before this task instance.

    :param task_instance: the task instance
    :param state: If passed, it only takes into account instances of a specific state.
    :param session: SQLAlchemy ORM Session

    :meta private:
    """
    dagrun = task_instance.get_previous_dagrun(state, session=session)
    if dagrun is None:
        return None
    return dagrun.get_task_instance(task_instance.task_id, session=session)


@internal_api_call
@provide_session
def _update_rtif(ti, rendered_fields, session: Session = NEW_SESSION):
    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    rtif = RenderedTaskInstanceFields(ti=ti, render_templates=False, rendered_fields=rendered_fields)
    RenderedTaskInstanceFields.write(rtif, session=session)
    session.flush()
    RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)


def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session):
    from airflow.models.dagrun import DagRun
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

    if isinstance(ti, TaskInstancePydantic):
        orm_ti = DagRun.fetch_task_instance(
            dag_id=ti.dag_id,
            dag_run_id=ti.run_id,
            task_id=ti.task_id,
            map_index=ti.map_index,
            session=session,
        )
        if TYPE_CHECKING:
            assert orm_ti
        ti, pydantic_ti = orm_ti, ti
        _set_ti_attrs(ti, pydantic_ti)
        ti.task = pydantic_ti.task
    return ti


@internal_api_call
@provide_session
def _defer_task(
    ti: TaskInstance | TaskInstancePydantic,
    exception: TaskDeferred | None = None,
    session: Session = NEW_SESSION,
) -> TaskInstancePydantic | TaskInstance:
    from airflow.models.trigger import Trigger

    if exception is not None:
        trigger_row = Trigger.from_object(exception.trigger)
        next_method = exception.method_name
        next_kwargs = exception.kwargs
        timeout = exception.timeout
    elif ti.task is not None and ti.task.start_trigger_args is not None:
        context = ti.get_template_context()
        start_trigger_args = ti.task.expand_start_trigger_args(context=context, session=session)
        if start_trigger_args is None:
            raise TaskDeferralError(
                "A non-None start_trigger_args was changed to None during expansion"
            )

        trigger_kwargs = start_trigger_args.trigger_kwargs or {}
        next_kwargs = start_trigger_args.next_kwargs
        next_method = start_trigger_args.next_method
        timeout = start_trigger_args.timeout
        trigger_row = Trigger(
            classpath=ti.task.start_trigger_args.trigger_cls,
            kwargs=trigger_kwargs,
        )
    else:
        raise TaskDeferralError("exception and ti.task.start_trigger_args cannot both be None")

    # First, make the trigger entry
    session.add(trigger_row)
    session.flush()

    ti = _coalesce_to_orm_ti(ti=ti, session=session)  # ensure orm obj in case it's pydantic

    if TYPE_CHECKING:
        assert ti.task

    # Then, update ourselves so it matches the deferral request
    # Keep an eye on the logic in `check_and_change_state_before_execution()`
    # depending on self.next_method semantics
    ti.state = TaskInstanceState.DEFERRED
    ti.trigger_id = trigger_row.id
    ti.next_method = next_method
    ti.next_kwargs = next_kwargs or {}

    # Calculate timeout too if it was passed
    if timeout is not None:
        ti.trigger_timeout = timezone.utcnow() + timeout
    else:
        ti.trigger_timeout = None

    # If an execution_timeout is set, set the timeout to the minimum of
    # it and the trigger timeout
    execution_timeout = ti.task.execution_timeout
    if execution_timeout:
        if TYPE_CHECKING:
            assert ti.start_date
        if ti.trigger_timeout:
            ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
        else:
            ti.trigger_timeout = ti.start_date + execution_timeout
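    # Illustrative timeout resolution (hypothetical values): a TI started at 10:00
    # with execution_timeout=30m and a trigger timeout of 10:45 ends up with an
    # effective trigger_timeout of 10:30, the earlier of the two deadlines.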

    if ti.test_mode:
        _add_log(event=ti.state, task_instance=ti, session=session)

    if exception is not None:
        session.merge(ti)
        session.commit()
    return ti


@internal_api_call
@provide_session
def _handle_reschedule(
    ti,
    actual_start_date: datetime,
    reschedule_exception: AirflowRescheduleException,
    test_mode: bool = False,
    session: Session = NEW_SESSION,
):
    # Don't record reschedule request in test mode
    if test_mode:
        return

    ti = _coalesce_to_orm_ti(ti=ti, session=session)

    from airflow.models.dagrun import DagRun  # Avoid circular import

    ti.refresh_from_db(session)

    if TYPE_CHECKING:
        assert ti.task

    ti.end_date = timezone.utcnow()
    ti.set_duration()

    # Lock DAG run to be sure not to get into a deadlock situation when trying to insert
    # TaskReschedule which apparently also creates lock on corresponding DagRun entity
    with_row_locks(
        session.query(DagRun).filter_by(
            dag_id=ti.dag_id,
            run_id=ti.run_id,
        ),
        session=session,
    ).one()

    # Log reschedule request
    session.add(
        TaskReschedule(
            ti.task_id,
            ti.dag_id,
            ti.run_id,
            ti.try_number,
            actual_start_date,
            ti.end_date,
            reschedule_exception.reschedule_date,
            ti.map_index,
        )
    )

    # set state
    ti.state = TaskInstanceState.UP_FOR_RESCHEDULE

    ti.clear_next_method_args()

    session.merge(ti)
    session.commit()
    return ti


class TaskInstance(Base, LoggingMixin):
    """
    Task instances store the state of a task instance.

    This table is the authority and single source of truth around what tasks
    have run and the state they are in.

    The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
    dag model deliberately to have more control over transactions.

    Database transactions on this table should guard against double triggers and
    any confusion around what task instances are or aren't ready to run,
    even while multiple schedulers may be firing task instances.

    A value of -1 in map_index represents any of: a TI without mapped tasks;
    a TI with mapped tasks that has yet to be expanded (state=pending);
    a TI with mapped tasks that expanded to an empty list (state=skipped).
    """

    __tablename__ = "task_instance"

    task_id = Column(StringID(), primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    run_id = Column(StringID(), primary_key=True, nullable=False)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))

    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    duration = Column(Float)
    state = Column(String(20))
    try_number = Column(Integer, default=0)
    max_tries = Column(Integer, server_default=text("-1"))
    hostname = Column(String(1000))
    unixname = Column(String(1000))
    job_id = Column(Integer)
    pool = Column(String(256), nullable=False)
    pool_slots = Column(Integer, default=1, nullable=False)
    queue = Column(String(256))
    priority_weight = Column(Integer)
    operator = Column(String(1000))
    custom_operator_name = Column(String(1000))
    queued_dttm = Column(UtcDateTime)
    queued_by_job_id = Column(Integer)
    pid = Column(Integer)
    executor = Column(String(1000))
    executor_config = Column(ExecutorConfigType(pickler=dill))
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
    rendered_map_index = Column(String(250))

    external_executor_id = Column(StringID())

    # The trigger to resume on if we are in state DEFERRED
    trigger_id = Column(Integer)

    # Optional timeout datetime for the trigger (past this, we'll fail)
    trigger_timeout = Column(DateTime)
    # The trigger_timeout should be TIMESTAMP (using UtcDateTime), but for ease of
    # migration, we are keeping it as DateTime pending a change where expensive
    # migration is inevitable.

    # The method to call next, and any extra arguments to pass to it.
    # Usually used when resuming from DEFERRED.
    next_method = Column(String(1000))
    next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))

    _task_display_property_value = Column("task_display_name", String(2000), nullable=True)
    # If adding new fields here then remember to add them to
    # refresh_from_db() or they won't display in the UI correctly

    __table_args__ = (
        Index("ti_dag_state", dag_id, state),
        Index("ti_dag_run", dag_id, run_id),
        Index("ti_state", state),
        Index("ti_state_lkp", dag_id, task_id, run_id, state),
        Index("ti_pool", pool, state, priority_weight),
        Index("ti_job_id", job_id),
        Index("ti_trigger_id", trigger_id),
        PrimaryKeyConstraint("dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey"),
        ForeignKeyConstraint(
            [trigger_id],
            ["trigger.id"],
            name="task_instance_trigger_id_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            [dag_id, run_id],
            ["dag_run.dag_id", "dag_run.run_id"],
            name="task_instance_dag_run_fkey",
            ondelete="CASCADE",
        ),
    )

    dag_model: DagModel = relationship(
        "DagModel",
        primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
        foreign_keys=dag_id,
        uselist=False,
        innerjoin=True,
        viewonly=True,
    )

    trigger = relationship("Trigger", uselist=False, back_populates="task_instance")
    triggerer_job = association_proxy("trigger", "triggerer_job")
    dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
    rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
    execution_date = association_proxy("dag_run", "execution_date")
    task_instance_note = relationship(
        "TaskInstanceNote",
        back_populates="task_instance",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    note = association_proxy("task_instance_note", "content", creator=_creator_note)

    task: Operator | None = None
    test_mode: bool = False
    is_trigger_log_context: bool = False
    """Indicate to FileTaskHandler that logging context should be set up for trigger logging.

    :meta private:
    """
    run_as_user: str | None = None
    raw: bool | None = None

    _logger_name = "airflow.task"

    def __init__(
        self,
        task: Operator,
        execution_date: datetime | None = None,
        run_id: str | None = None,
        state: str | None = None,
        map_index: int = -1,
    ):
        super().__init__()
        self.dag_id = task.dag_id
        self.task_id = task.task_id
        self.map_index = map_index
        self.refresh_from_task(task)
        if TYPE_CHECKING:
            assert self.task

        # init_on_load will configure the log
        self.init_on_load()

        if run_id is None and execution_date is not None:
            from airflow.models.dagrun import DagRun  # Avoid circular import

            warnings.warn(
                "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
                RemovedInAirflow3Warning,
                # Stack level is 4 because SQLA adds some wrappers around the constructor
                stacklevel=4,
            )
            # make sure we have a localized execution_date stored in UTC
            if execution_date and not timezone.is_localized(execution_date):
                self.log.warning(
                    "execution date %s has no timezone information. Using default from dag or system",
                    execution_date,
                )
                if self.task.has_dag():
                    if TYPE_CHECKING:
                        assert self.task.dag
                    execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
                else:
                    execution_date = timezone.make_aware(execution_date)
                execution_date = timezone.convert_to_utc(execution_date)
            with create_session() as session:
                run_id = (
                    session.query(DagRun.run_id)
                    .filter_by(dag_id=self.dag_id, execution_date=execution_date)
                    .scalar()
                )
                if run_id is None:
                    raise DagRunNotFound(
                        f"DagRun for {self.dag_id!r} with date {execution_date} not found"
                    ) from None

        self.run_id = run_id

        self.try_number = 0
        self.max_tries = self.task.retries
        self.unixname = getuser()
        if state:
            self.state = state
        self.hostname = ""
        # Is this TaskInstance currently running within `airflow tasks run --raw`?
        # Not persisted to the database so only valid for the current process
        self.raw = False
        # can be changed when calling 'run'
        self.test_mode = False

    def __hash__(self):
        return hash((self.task_id, self.dag_id, self.run_id, self.map_index))

    @property
    @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
    def _try_number(self):
        """
        Do not use. For semblance of backcompat.

        :meta private:
        """
        return self.try_number

    @_try_number.setter
    @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
    def _try_number(self, val):
        """
        Do not use. For semblance of backcompat.

        :meta private:
        """
        self.try_number = val

    @property
    def stats_tags(self) -> dict[str, str]:
        """Returns task instance tags."""
        return _stats_tags(task_instance=self)

    @staticmethod
    def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
        """
        Insert mapping.

        :meta private:
        """
        priority_weight = task.weight_rule.get_weight(
            TaskInstance(task=task, run_id=run_id, map_index=map_index)
        )

        return {
            "dag_id": task.dag_id,
            "task_id": task.task_id,
            "run_id": run_id,
            "try_number": 0,
            "hostname": "",
            "unixname": getuser(),
            "queue": task.queue,
            "pool": task.pool,
            "pool_slots": task.pool_slots,
            "priority_weight": priority_weight,
            "run_as_user": task.run_as_user,
            "max_tries": task.retries,
            "executor": task.executor,
            "executor_config": task.executor_config,
            "operator": task.task_type,
            "custom_operator_name": getattr(task, "custom_operator_name", None),
            "map_index": map_index,
            "_task_display_property_value": task.task_display_name,
        }

    @reconstructor
    def init_on_load(self) -> None:
        """Initialize the attributes that aren't stored in the DB."""
        self.test_mode = False  # can be changed when calling 'run'

    @property
    @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
    def prev_attempted_tries(self) -> int:
        """
        Calculate the total number of attempted tries, defaulting to 0.

        This used to be necessary because try_number did not always tell the truth.

        :meta private:
        """
        return self.try_number

    @property
    def next_try_number(self) -> int:
        # todo (dstandish): deprecate this property; we don't need a property that is just + 1
        return self.try_number + 1

    @property
    def operator_name(self) -> str | None:
        """@property: use a more friendly display name for the operator, if set."""
        return self.custom_operator_name or self.operator

    @hybrid_property
    def task_display_name(self) -> str:
        return self._task_display_property_value or self.task_id

    @staticmethod
    def _command_as_list(
        ti: TaskInstance | TaskInstancePydantic,
        mark_success: bool = False,
        ignore_all_deps: bool = False,
        ignore_task_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_ti_state: bool = False,
        local: bool = False,
        pickle_id: int | None = None,
        raw: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        cfg_path: str | None = None,
    ) -> list[str]:
        dag: DAG | DagModel | DagModelPydantic | None
        # Use the dag if we have it, else fall back to the ORM dag_model, which might not be loaded
        if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
            if TYPE_CHECKING:
                assert ti.task
            dag = ti.task.dag
        else:
            dag = ti.dag_model

        if dag is None:
            raise ValueError("DagModel is empty")

        should_pass_filepath = not pickle_id and dag
        path: PurePath | None = None
        if should_pass_filepath:
            if dag.is_subdag:
                if TYPE_CHECKING:
                    assert dag.parent_dag is not None
                path = dag.parent_dag.relative_fileloc
            else:
                path = dag.relative_fileloc

            if path:
                if not path.is_absolute():
                    path = "DAGS_FOLDER" / path

        return TaskInstance.generate_command(
            ti.dag_id,
            ti.task_id,
            run_id=ti.run_id,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            file_path=path,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path,
            map_index=ti.map_index,
        )

    def command_as_list(
        self,
        mark_success: bool = False,
        ignore_all_deps: bool = False,
        ignore_task_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_ti_state: bool = False,
        local: bool = False,
        pickle_id: int | None = None,
        raw: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        cfg_path: str | None = None,
    ) -> list[str]:
        """
        Return a command that can be executed anywhere Airflow is installed.

        This command is part of the message sent to executors by the orchestrator.
        """
        return TaskInstance._command_as_list(
            ti=self,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path,
        )

    @staticmethod
    def generate_command(
        dag_id: str,
        task_id: str,
        run_id: str,
        mark_success: bool = False,
        ignore_all_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_task_deps: bool = False,
        ignore_ti_state: bool = False,
        local: bool = False,
        pickle_id: int | None = None,
        file_path: PurePath | str | None = None,
        raw: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        cfg_path: str | None = None,
        map_index: int = -1,
    ) -> list[str]:
        """
        Generate the shell command required to execute this task instance.

        :param dag_id: DAG ID
        :param task_id: Task ID
        :param run_id: The run_id of this task's DagRun
        :param mark_success: Whether to mark the task as successful
        :param ignore_all_deps: Ignore all ignorable dependencies.
            Overrides the other ignore_* parameters.
        :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
            (e.g. for Backfills)
        :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
        :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
            and trigger rule
        :param ignore_ti_state: Ignore the task instance's previous failure/success
        :param local: Whether to run the task locally
        :param pickle_id: If the DAG was serialized to the DB, the ID
            associated with the pickled DAG
        :param file_path: path to the file containing the DAG definition
        :param raw: raw mode (needs more details)
        :param job_id: job ID (needs more details)
        :param pool: the Airflow pool that the task should run in
        :param cfg_path: the Path to the configuration file
        :return: shell command that can be used to run the task instance
        """
        cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
        if mark_success:
            cmd.extend(["--mark-success"])
        if pickle_id:
            cmd.extend(["--pickle", str(pickle_id)])
        if job_id:
            cmd.extend(["--job-id", str(job_id)])
        if ignore_all_deps:
            cmd.extend(["--ignore-all-dependencies"])
        if ignore_task_deps:
            cmd.extend(["--ignore-dependencies"])
        if ignore_depends_on_past:
            cmd.extend(["--depends-on-past", "ignore"])
        elif wait_for_past_depends_before_skipping:
            cmd.extend(["--depends-on-past", "wait"])
        if ignore_ti_state:
            cmd.extend(["--force"])
        if local:
            cmd.extend(["--local"])
        if pool:
            cmd.extend(["--pool", pool])
        if raw:
            cmd.extend(["--raw"])
        if file_path:
            cmd.extend(["--subdir", os.fspath(file_path)])
        if cfg_path:
            cmd.extend(["--cfg-path", cfg_path])
        if map_index != -1:
            cmd.extend(["--map-index", str(map_index)])
        return cmd
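        # Illustrative output (hypothetical IDs): generate_command("my_dag", "my_task",
        # "manual__2024-01-01T00:00:00+00:00", local=True, pool="default_pool") returns
        # ["airflow", "tasks", "run", "my_dag", "my_task",
        #  "manual__2024-01-01T00:00:00+00:00", "--local", "--pool", "default_pool"].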

    @property
    def log_url(self) -> str:
        """Log URL for TaskInstance."""
        run_id = quote(self.run_id)
        base_date = quote(self.execution_date.strftime("%Y-%m-%dT%H:%M:%S%z"))
        base_url = conf.get_mandatory_value("webserver", "BASE_URL")
        map_index = f"&map_index={self.map_index}" if self.map_index >= 0 else ""
        return (
            f"{base_url}"
            f"/dags"
            f"/{self.dag_id}"
            f"/grid"
            f"?dag_run_id={run_id}"
            f"&task_id={self.task_id}"
            f"{map_index}"
            f"&base_date={base_date}"
            "&tab=logs"
        )
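        # Illustrative URL (hypothetical host and IDs):
        # http://localhost:8080/dags/my_dag/grid?dag_run_id=manual__2024-01-01
        # &task_id=my_task&base_date=2024-01-01T00%3A00%3A00%2B0000&tab=logs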

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        base_url = conf.get_mandatory_value("webserver", "BASE_URL")
        return (
            f"{base_url}"
            "/confirm"
            f"?task_id={self.task_id}"
            f"&dag_id={self.dag_id}"
            f"&dag_run_id={quote(self.run_id)}"
            "&upstream=false"
            "&downstream=false"
            "&state=success"
        )

    @provide_session
    def current_state(self, session: Session = NEW_SESSION) -> str:
        """
        Get the very latest state from the database.

        If a session is passed, we use it, and looking up the state becomes part of
        the session; otherwise a new session is used.

        sqlalchemy.inspect is used here to get the primary keys, ensuring that if they
        change it will not regress.

        :param session: SQLAlchemy ORM Session
        """
        filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
        return session.query(TaskInstance.state).filter(*filters).scalar()

    @provide_session
    def error(self, session: Session = NEW_SESSION) -> None:
        """
        Force the task instance's state to FAILED in the database.

        :param session: SQLAlchemy ORM Session
        """
        self.log.error("Recording the task instance as FAILED")
        self.state = TaskInstanceState.FAILED
        session.merge(self)
        session.commit()

    @classmethod
    @internal_api_call
    @provide_session
    def get_task_instance(
        cls,
        dag_id: str,
        run_id: str,
        task_id: str,
        map_index: int,
        lock_for_update: bool = False,
        session: Session = NEW_SESSION,
    ) -> TaskInstance | TaskInstancePydantic | None:
        query = (
            session.query(TaskInstance)
            .options(lazyload(TaskInstance.dag_run))  # lazy load dag run to avoid locking it
            .filter_by(
                dag_id=dag_id,
                run_id=run_id,
                task_id=task_id,
                map_index=map_index,
            )
        )

        if lock_for_update:
            for attempt in run_with_db_retries(logger=cls.logger()):
                with attempt:
                    return query.with_for_update().one_or_none()
        else:
            return query.one_or_none()

        return None

    @provide_session
    def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
        """
        Refresh the task instance from the database based on the primary key.

        :param session: SQLAlchemy ORM Session
        :param lock_for_update: if True, indicates that the database should
            lock the TaskInstance (issuing a FOR UPDATE clause) until the
            session is committed.
        """
        _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)

    def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
        """
        Copy common attributes from the given task.

        :param task: The task object to copy from
        :param pool_override: Use the pool_override instead of task's pool
        """
        _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

    @staticmethod
    @internal_api_call
    @provide_session
    def _clear_xcom_data(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION) -> None:
        """
        Clear all XCom data from the database for the task instance.

        If the task is unmapped, all XComs matching this task ID in the same DAG
        run are removed. If the task is mapped, only the one with matching map
        index is removed.

        :param ti: The TI for which we need to clear xcoms.
        :param session: SQLAlchemy ORM Session
        """
        ti.log.debug("Clearing XCom data")
        if ti.map_index < 0:
            map_index: int | None = None
        else:
            map_index = ti.map_index
        XCom.clear(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            map_index=map_index,
            session=session,
        )

    @provide_session
    def clear_xcom_data(self, session: Session = NEW_SESSION):
        self._clear_xcom_data(ti=self, session=session)

    @property
    def key(self) -> TaskInstanceKey:
        """Returns a tuple that identifies the task instance uniquely."""
        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
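        # Illustrative key (hypothetical IDs):
        # TaskInstanceKey("my_dag", "my_task", "manual__2024-01-01T00:00:00+00:00", 1, -1)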

    @staticmethod
    @internal_api_call
    def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool:
        if not isinstance(ti, TaskInstance):
            ti = session.scalars(
                select(TaskInstance).where(
                    TaskInstance.task_id == ti.task_id,
                    TaskInstance.dag_id == ti.dag_id,
                    TaskInstance.run_id == ti.run_id,
                    TaskInstance.map_index == ti.map_index,
                )
            ).one()

        if ti.state == state:
            return False

        current_time = timezone.utcnow()
        ti.log.debug("Setting task state for %s to %s", ti, state)
        ti.state = state
        ti.start_date = ti.start_date or current_time
        if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY:
            ti.end_date = ti.end_date or current_time
            ti.duration = (ti.end_date - ti.start_date).total_seconds()
        session.merge(ti)
        return True

    @provide_session
    def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
        """
        Set TaskInstance state.

        :param state: State to set for the TI
        :param session: SQLAlchemy ORM Session
        :return: Was the state changed
        """
        return self._set_state(ti=self, state=state, session=session)

    @property
    def is_premature(self) -> bool:
        """Returns whether a task is in UP_FOR_RETRY state and its retry interval has not yet elapsed."""
        # is the task still in the retry waiting period?
        return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry()

    @provide_session
    def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
        """
        Check whether the immediate dependents of this task instance have succeeded or have been skipped.

        This is meant to be used by wait_for_downstream.

        This is useful when you do not want to start processing the next
        schedule of a task until the dependents are done. For instance,
        if the task DROPs and recreates a table.

        :param session: SQLAlchemy ORM Session
        """
        task = self.task
        if TYPE_CHECKING:
            assert task

        if not task.downstream_task_ids:
            return True

        ti = session.query(func.count(TaskInstance.task_id)).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id.in_(task.downstream_task_ids),
            TaskInstance.run_id == self.run_id,
            TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
        )
        count = ti[0][0]
        return count == len(task.downstream_task_ids)

    @provide_session
    def get_previous_dagrun(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> DagRun | None:
        """
        Return the DagRun that ran before this task instance's DagRun.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session.
        """
        return _get_previous_dagrun(task_instance=self, state=state, session=session)

    @provide_session
    def get_previous_ti(
        self,
        state: DagRunState | None = None,
        session: Session = NEW_SESSION,
    ) -> TaskInstance | TaskInstancePydantic | None:
        """
        Return the task instance for the task that ran before this task instance.

        :param session: SQLAlchemy ORM Session
        :param state: If passed, it only takes into account instances of a specific state.
        """
        return _get_previous_ti(task_instance=self, state=state, session=session)

    @property
    def previous_ti(self) -> TaskInstance | TaskInstancePydantic | None:
        """
        This attribute is deprecated.

        Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`.
        """
        warnings.warn(
            """
            This attribute is deprecated.
            Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
            """,
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_previous_ti()

    @property
    def previous_ti_success(self) -> TaskInstance | TaskInstancePydantic | None:
        """
        This attribute is deprecated.

        Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`.
        """
        warnings.warn(
            """
            This attribute is deprecated.
            Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
            """,
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_previous_ti(state=DagRunState.SUCCESS)

    @provide_session
    def get_previous_execution_date(
        self,
        state: DagRunState | None = None,
        session: Session = NEW_SESSION,
    ) -> pendulum.DateTime | None:
        """
        Return the execution date from property previous_ti_success.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session
        """
        return _get_previous_execution_date(task_instance=self, state=state, session=session)

    @provide_session
    def get_previous_start_date(
        self, state: DagRunState | None = None, session: Session = NEW_SESSION
    ) -> pendulum.DateTime | None:
        """
        Return the start date from property previous_ti_success.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session
        """
        return _get_previous_start_date(task_instance=self, state=state, session=session)

    @property
    def previous_start_date_success(self) -> pendulum.DateTime | None:
        """
        This attribute is deprecated.

        Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_start_date`.
        """
        warnings.warn(
            """
            This attribute is deprecated.
            Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
            """,
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_previous_start_date(state=DagRunState.SUCCESS)
  2251. @provide_session
  2252. def are_dependencies_met(
  2253. self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
  2254. ) -> bool:
  2255. """
  2256. Are all conditions met for this task instance to be run given the context for the dependencies.
  2257. (e.g. a task instance being force run from the UI will ignore some dependencies).
  2258. :param dep_context: The execution context that determines the dependencies that should be evaluated.
  2259. :param session: database session
  2260. :param verbose: whether log details on failed dependencies on info or debug log level
  2261. """
  2262. dep_context = dep_context or DepContext()
  2263. failed = False
  2264. verbose_aware_logger = self.log.info if verbose else self.log.debug
  2265. for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
  2266. failed = True
  2267. verbose_aware_logger(
  2268. "Dependencies not met for %s, dependency '%s' FAILED: %s",
  2269. self,
  2270. dep_status.dep_name,
  2271. dep_status.reason,
  2272. )
  2273. if failed:
  2274. return False
  2275. verbose_aware_logger("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self)
  2276. return True
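# Usage sketch (illustrative, not part of this module): checking runnability
# while ignoring depends_on_past, assuming a `ti` and an open `session`
# obtained elsewhere:
#
#     from airflow.ti_deps.dep_context import DepContext
#
#     ctx = DepContext(ignore_depends_on_past=True)
#     if ti.are_dependencies_met(dep_context=ctx, session=session, verbose=True):
#         ...  # all deps passed; the caller may proceed to run the task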
  2277. @provide_session
  2278. def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
  2279. """Get failed Dependencies."""
  2280. if TYPE_CHECKING:
  2281. assert self.task
  2282. dep_context = dep_context or DepContext()
  2283. for dep in dep_context.deps | self.task.deps:
  2284. for dep_status in dep.get_dep_statuses(self, session, dep_context):
  2285. self.log.debug(
  2286. "%s dependency '%s' PASSED: %s, %s",
  2287. self,
  2288. dep_status.dep_name,
  2289. dep_status.passed,
  2290. dep_status.reason,
  2291. )
  2292. if not dep_status.passed:
  2293. yield dep_status
  2294. def __repr__(self) -> str:
  2295. prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
  2296. if self.map_index != -1:
  2297. prefix += f"map_index={self.map_index} "
  2298. return prefix + f"[{self.state}]>"
  2299. def next_retry_datetime(self):
  2300. """
  2301. Get datetime of the next retry if the task instance fails.
2302. For exponential backoff, retry_delay is used as the base and is converted to seconds.
  2303. """
  2304. from airflow.models.abstractoperator import MAX_RETRY_DELAY
  2305. delay = self.task.retry_delay
  2306. if self.task.retry_exponential_backoff:
  2307. # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
  2308. # we must round up prior to converting to an int, otherwise a divide by zero error
  2309. # will occur in the modded_hash calculation.
  2310. # this probably gives unexpected results if a task instance has previously been cleared,
  2311. # because try_number can increase without bound
  2312. min_backoff = math.ceil(delay.total_seconds() * (2 ** (self.try_number - 1)))
  2313. # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
  2314. # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
  2315. # the ceiling function unnecessary, but the ceiling function was retained to avoid
  2316. # introducing a breaking change.
  2317. if min_backoff < 1:
  2318. min_backoff = 1
  2319. # deterministic per task instance
  2320. ti_hash = int(
  2321. hashlib.sha1(
  2322. f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode()
  2323. ).hexdigest(),
  2324. 16,
  2325. )
2326. # deterministic jitter: modded_hash lies in [min_backoff, 2 * min_backoff - 1]
  2327. modded_hash = min_backoff + ti_hash % min_backoff
  2328. # timedelta has a maximum representable value. The exponentiation
  2329. # here means this value can be exceeded after a certain number
  2330. # of tries (around 50 if the initial delay is 1s, even fewer if
  2331. # the delay is larger). Cap the value here before creating a
  2332. # timedelta object so the operation doesn't fail with "OverflowError".
  2333. delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY)
  2334. delay = timedelta(seconds=delay_backoff_in_seconds)
  2335. if self.task.max_retry_delay:
  2336. delay = min(self.task.max_retry_delay, delay)
  2337. return self.end_date + delay
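# Worked example (illustrative numbers): with retry_delay=30s, exponential
# backoff enabled, and try_number=3, min_backoff = ceil(30 * 2 ** (3 - 1)) = 120,
# so modded_hash falls in [120, 239] seconds; a max_retry_delay of 180s would
# then cap the final delay at 180 seconds before it is added to end_date.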
  2338. def ready_for_retry(self) -> bool:
  2339. """Check on whether the task instance is in the right state and timeframe to be retried."""
  2340. return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
  2341. @staticmethod
  2342. @internal_api_call
  2343. def _get_dagrun(dag_id, run_id, session) -> DagRun:
  2344. from airflow.models.dagrun import DagRun # Avoid circular import
  2345. dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
  2346. return dr
  2347. @provide_session
  2348. def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
  2349. """
  2350. Return the DagRun for this TaskInstance.
  2351. :param session: SQLAlchemy ORM Session
  2352. :return: DagRun
  2353. """
  2354. info = inspect(self)
  2355. if info.attrs.dag_run.loaded_value is not NO_VALUE:
  2356. if getattr(self, "task", None) is not None:
  2357. if TYPE_CHECKING:
  2358. assert self.task
  2359. self.dag_run.dag = self.task.dag
  2360. return self.dag_run
  2361. dr = self._get_dagrun(self.dag_id, self.run_id, session)
  2362. if getattr(self, "task", None) is not None:
  2363. if TYPE_CHECKING:
  2364. assert self.task
  2365. dr.dag = self.task.dag
  2366. # Record it in the instance for next time. This means that `self.execution_date` will work correctly
  2367. set_committed_value(self, "dag_run", dr)
  2368. return dr
  2369. @classmethod
  2370. @provide_session
  2371. def ensure_dag(
  2372. cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
  2373. ) -> DAG:
  2374. """Ensure that task has a dag object associated, might have been removed by serialization."""
  2375. if TYPE_CHECKING:
  2376. assert task_instance.task
  2377. if task_instance.task.dag is None or task_instance.task.dag.__class__ is AttributeRemoved:
  2378. task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag(
  2379. dag_id=task_instance.dag_id, session=session
  2380. )
  2381. if TYPE_CHECKING:
  2382. assert task_instance.task.dag
  2383. return task_instance.task.dag
  2384. @classmethod
  2385. @internal_api_call
  2386. @provide_session
  2387. def _check_and_change_state_before_execution(
  2388. cls,
  2389. task_instance: TaskInstance | TaskInstancePydantic,
  2390. verbose: bool = True,
  2391. ignore_all_deps: bool = False,
  2392. ignore_depends_on_past: bool = False,
  2393. wait_for_past_depends_before_skipping: bool = False,
  2394. ignore_task_deps: bool = False,
  2395. ignore_ti_state: bool = False,
  2396. mark_success: bool = False,
  2397. test_mode: bool = False,
  2398. hostname: str = "",
  2399. job_id: str | None = None,
  2400. pool: str | None = None,
  2401. external_executor_id: str | None = None,
  2402. session: Session = NEW_SESSION,
  2403. ) -> bool:
  2404. """
2405. Check dependencies and then set state to RUNNING if they are met.
2406. Returns True if and only if state is set to RUNNING, which implies that the task should be
2407. executed, in preparation for _run_raw_task.
  2408. :param verbose: whether to turn on more verbose logging
2409. :param ignore_all_deps: Ignore all of the non-critical dependencies and just run the task
  2410. :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
2411. :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
  2412. :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
  2413. :param ignore_ti_state: Disregards previous task instance state
  2414. :param mark_success: Don't run the task, mark its state as success
  2415. :param test_mode: Doesn't record success or failure in the DB
  2416. :param hostname: The hostname of the worker running the task instance.
  2417. :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
  2418. :param pool: specifies the pool to use to run the task instance
2419. :param external_executor_id: The identifier of the external executor (e.g. the Celery executor)
  2420. :param session: SQLAlchemy ORM Session
  2421. :return: whether the state was changed to running or not
  2422. """
  2423. if TYPE_CHECKING:
  2424. assert task_instance.task
  2425. if isinstance(task_instance, TaskInstance):
  2426. ti: TaskInstance = task_instance
  2427. else: # isinstance(task_instance, TaskInstancePydantic)
  2428. filters = (col == getattr(task_instance, col.name) for col in inspect(TaskInstance).primary_key)
  2429. ti = session.query(TaskInstance).filter(*filters).scalar()
  2430. dag = DagBag(read_dags_from_db=True).get_dag(task_instance.dag_id, session=session)
  2431. task_instance.task = dag.task_dict[ti.task_id]
  2432. ti.task = task_instance.task
  2433. task = task_instance.task
  2434. if TYPE_CHECKING:
  2435. assert task
  2436. ti.refresh_from_task(task, pool_override=pool)
  2437. ti.test_mode = test_mode
  2438. ti.refresh_from_db(session=session, lock_for_update=True)
  2439. ti.job_id = job_id
  2440. ti.hostname = hostname
  2441. ti.pid = None
  2442. if not ignore_all_deps and not ignore_ti_state and ti.state == TaskInstanceState.SUCCESS:
  2443. Stats.incr("previously_succeeded", tags=ti.stats_tags)
  2444. if not mark_success:
  2445. # Firstly find non-runnable and non-requeueable tis.
2446. # (When mark_success is set, these dependency checks are skipped entirely.)
  2447. non_requeueable_dep_context = DepContext(
  2448. deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
  2449. ignore_all_deps=ignore_all_deps,
  2450. ignore_ti_state=ignore_ti_state,
  2451. ignore_depends_on_past=ignore_depends_on_past,
  2452. wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
  2453. ignore_task_deps=ignore_task_deps,
  2454. description="non-requeueable deps",
  2455. )
  2456. if not ti.are_dependencies_met(
  2457. dep_context=non_requeueable_dep_context, session=session, verbose=True
  2458. ):
  2459. session.commit()
  2460. return False
  2461. # For reporting purposes, we report based on 1-indexed,
  2462. # not 0-indexed lists (i.e. Attempt 1 instead of
  2463. # Attempt 0 for the first attempt).
  2464. # Set the task start date. In case it was re-scheduled use the initial
  2465. # start date that is recorded in task_reschedule table
  2466. # If the task continues after being deferred (next_method is set), use the original start_date
  2467. ti.start_date = ti.start_date if ti.next_method else timezone.utcnow()
  2468. if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
  2469. tr_start_date = session.scalar(
  2470. TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1)
  2471. )
  2472. if tr_start_date:
  2473. ti.start_date = tr_start_date
  2474. # Secondly we find non-runnable but requeueable tis. We reset its state.
  2475. # This is because we might have hit concurrency limits,
  2476. # e.g. because of backfilling.
  2477. dep_context = DepContext(
  2478. deps=REQUEUEABLE_DEPS,
  2479. ignore_all_deps=ignore_all_deps,
  2480. ignore_depends_on_past=ignore_depends_on_past,
  2481. wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
  2482. ignore_task_deps=ignore_task_deps,
  2483. ignore_ti_state=ignore_ti_state,
  2484. description="requeueable deps",
  2485. )
  2486. if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
  2487. ti.state = None
  2488. cls.logger().warning(
  2489. "Rescheduling due to concurrency limits reached "
  2490. "at task runtime. Attempt %s of "
  2491. "%s. State set to NONE.",
  2492. ti.try_number,
  2493. ti.max_tries + 1,
  2494. )
  2495. ti.queued_dttm = timezone.utcnow()
  2496. session.merge(ti)
  2497. session.commit()
  2498. return False
  2499. if ti.next_kwargs is not None:
  2500. cls.logger().info("Resuming after deferral")
  2501. else:
  2502. cls.logger().info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1)
  2503. if not test_mode:
  2504. session.add(Log(TaskInstanceState.RUNNING.value, ti))
  2505. ti.state = TaskInstanceState.RUNNING
  2506. ti.emit_state_change_metric(TaskInstanceState.RUNNING)
  2507. if external_executor_id:
  2508. ti.external_executor_id = external_executor_id
  2509. ti.end_date = None
  2510. if not test_mode:
  2511. session.merge(ti).task = task
  2512. session.commit()
  2513. # Closing all pooled connections to prevent
  2514. # "max number of connections reached"
  2515. settings.engine.dispose() # type: ignore
  2516. if verbose:
  2517. if mark_success:
  2518. cls.logger().info("Marking success for %s on %s", ti.task, ti.execution_date)
  2519. else:
  2520. cls.logger().info("Executing %s on %s", ti.task, ti.execution_date)
  2521. return True
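# Note (summary of the gate above): callers such as the task runners first
# invoke this check and only proceed to _run_raw_task when it returns True;
# run() further below composes exactly these two steps.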
  2522. @provide_session
  2523. def check_and_change_state_before_execution(
  2524. self,
  2525. verbose: bool = True,
  2526. ignore_all_deps: bool = False,
  2527. ignore_depends_on_past: bool = False,
  2528. wait_for_past_depends_before_skipping: bool = False,
  2529. ignore_task_deps: bool = False,
  2530. ignore_ti_state: bool = False,
  2531. mark_success: bool = False,
  2532. test_mode: bool = False,
  2533. job_id: str | None = None,
  2534. pool: str | None = None,
  2535. external_executor_id: str | None = None,
  2536. session: Session = NEW_SESSION,
  2537. ) -> bool:
  2538. return TaskInstance._check_and_change_state_before_execution(
  2539. task_instance=self,
  2540. verbose=verbose,
  2541. ignore_all_deps=ignore_all_deps,
  2542. ignore_depends_on_past=ignore_depends_on_past,
  2543. wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
  2544. ignore_task_deps=ignore_task_deps,
  2545. ignore_ti_state=ignore_ti_state,
  2546. mark_success=mark_success,
  2547. test_mode=test_mode,
  2548. hostname=get_hostname(),
  2549. job_id=job_id,
  2550. pool=pool,
  2551. external_executor_id=external_executor_id,
  2552. session=session,
  2553. )
  2554. def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
  2555. """
  2556. Send a time metric representing how much time a given state transition took.
2557. The previous state and metric name are deduced from the state the task was put in.
  2558. :param new_state: The state that has just been set for this task.
  2559. We do not use `self.state`, because sometimes the state is updated directly in the DB and not in
  2560. the local TaskInstance object.
  2561. Supported states: QUEUED and RUNNING
  2562. """
  2563. if self.end_date:
  2564. # if the task has an end date, it means that this is not its first round.
  2565. # we send the state transition time metric only on the first try, otherwise it gets more complex.
  2566. return
  2567. # switch on state and deduce which metric to send
  2568. if new_state == TaskInstanceState.RUNNING:
  2569. metric_name = "queued_duration"
  2570. if self.queued_dttm is None:
  2571. # this should not really happen except in tests or rare cases,
  2572. # but we don't want to create errors just for a metric, so we just skip it
  2573. self.log.warning(
  2574. "cannot record %s for task %s because previous state change time has not been saved",
  2575. metric_name,
  2576. self.task_id,
  2577. )
  2578. return
  2579. timing = timezone.utcnow() - self.queued_dttm
  2580. elif new_state == TaskInstanceState.QUEUED:
  2581. metric_name = "scheduled_duration"
  2582. if self.start_date is None:
  2583. # This check does not work correctly before fields like `scheduled_dttm` are implemented.
  2584. # TODO: Change the level to WARNING once it's viable.
  2585. # see #30612 #34493 and #34771 for more details
  2586. self.log.debug(
  2587. "cannot record %s for task %s because previous state change time has not been saved",
  2588. metric_name,
  2589. self.task_id,
  2590. )
  2591. return
  2592. timing = timezone.utcnow() - self.start_date
  2593. else:
2594. raise NotImplementedError(f"no metric emission setup for state {new_state}")
  2595. # send metric twice, once (legacy) with tags in the name and once with tags as tags
  2596. Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing)
  2597. Stats.timing(f"task.{metric_name}", timing, tags={"task_id": self.task_id, "dag_id": self.dag_id})
  2598. def clear_next_method_args(self) -> None:
  2599. """Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them."""
  2600. _clear_next_method_args(task_instance=self)
  2601. @provide_session
  2602. @Sentry.enrich_errors
  2603. def _run_raw_task(
  2604. self,
  2605. mark_success: bool = False,
  2606. test_mode: bool = False,
  2607. job_id: str | None = None,
  2608. pool: str | None = None,
  2609. raise_on_defer: bool = False,
  2610. session: Session = NEW_SESSION,
  2611. ) -> TaskReturnCode | None:
  2612. """
  2613. Run a task, update the state upon completion, and run any appropriate callbacks.
  2614. Immediately runs the task (without checking or changing db state
  2615. before execution) and then sets the appropriate final state after
  2616. completion and runs any post-execute callbacks. Meant to be called
  2617. only after another function changes the state to running.
  2618. :param mark_success: Don't run the task, mark its state as success
  2619. :param test_mode: Doesn't record success or failure in the DB
  2620. :param pool: specifies the pool to use to run the task instance
  2621. :param session: SQLAlchemy ORM Session
  2622. """
  2623. if TYPE_CHECKING:
  2624. assert self.task
  2625. return _run_raw_task(
  2626. ti=self,
  2627. mark_success=mark_success,
  2628. test_mode=test_mode,
  2629. job_id=job_id,
  2630. pool=pool,
  2631. raise_on_defer=raise_on_defer,
  2632. session=session,
  2633. )
  2634. def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None:
  2635. if TYPE_CHECKING:
  2636. assert self.task
2637. # One task triggers only one dataset event for each dataset with the same extra.
2638. # This mapping from tuple[dataset uri, frozen extra] to a set of alias names is used to detect
2639. # datasets with the same URI but different extras, for which more than one dataset event must be emitted.
  2640. dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
  2641. for obj in self.task.outlets or []:
  2642. self.log.debug("outlet obj %s", obj)
  2643. # Lineage can have other types of objects besides datasets
  2644. if isinstance(obj, Dataset):
  2645. dataset_manager.register_dataset_change(
  2646. task_instance=self,
  2647. dataset=obj,
  2648. extra=events[obj].extra,
  2649. session=session,
  2650. )
  2651. elif isinstance(obj, DatasetAlias):
  2652. for dataset_alias_event in events[obj].dataset_alias_events:
  2653. dataset_alias_name = dataset_alias_event["source_alias_name"]
  2654. dataset_uri = dataset_alias_event["dest_dataset_uri"]
  2655. extra = dataset_alias_event["extra"]
  2656. frozen_extra = frozenset(extra.items())
  2657. dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)
  2658. dataset_objs_cache: dict[str, DatasetModel] = {}
  2659. for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items():
  2660. if uri not in dataset_objs_cache:
  2661. dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
  2662. dataset_objs_cache[uri] = dataset_obj
  2663. else:
  2664. dataset_obj = dataset_objs_cache[uri]
  2665. if not dataset_obj:
  2666. dataset_obj = DatasetModel(uri=uri)
  2667. dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
  2668. self.log.warning("Created a new %r as it did not exist.", dataset_obj)
  2669. dataset_objs_cache[uri] = dataset_obj
  2670. for alias in alias_names:
  2671. alias_obj = session.scalar(
  2672. select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
  2673. )
  2674. dataset_obj.aliases.append(alias_obj)
  2675. extra = {k: v for k, v in extra_items}
  2676. self.log.info(
  2677. 'Creating event for %r through aliases "%s"',
  2678. dataset_obj,
  2679. ", ".join(alias_names),
  2680. )
  2681. dataset_manager.register_dataset_change(
  2682. task_instance=self,
  2683. dataset=dataset_obj,
  2684. extra=extra,
  2685. session=session,
  2686. source_alias_names=alias_names,
  2687. )
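# Illustrative sketch of the de-duplication key built above (values
# hypothetical): alias events sharing a URI and identical extras collapse
# into one entry, while a differing extra produces a second dataset event.
#
#     key_a = ("s3://bucket/x", frozenset({"part": "1"}.items()))
#     key_b = ("s3://bucket/x", frozenset({"part": "2"}.items()))
#     assert key_a != key_b  # same URI, different extra -> two events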
  2688. def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
  2689. """Prepare Task for Execution."""
  2690. if TYPE_CHECKING:
  2691. assert self.task
  2692. parent_pid = os.getpid()
  2693. def signal_handler(signum, frame):
  2694. pid = os.getpid()
  2695. # If a task forks during execution (from DAG code) for whatever
  2696. # reason, we want to make sure that we react to the signal only in
  2697. # the process that we've spawned ourselves (referred to here as the
  2698. # parent process).
  2699. if pid != parent_pid:
  2700. os._exit(1)
  2701. return
  2702. self.log.error("Received SIGTERM. Terminating subprocesses.")
  2703. self.log.error("Stacktrace: \n%s", "".join(traceback.format_stack()))
  2704. self.task.on_kill()
  2705. raise AirflowTaskTerminated("Task received SIGTERM signal")
  2706. signal.signal(signal.SIGTERM, signal_handler)
  2707. # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
  2708. if not self.next_method:
  2709. self.clear_xcom_data()
  2710. with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"), Stats.timer(
  2711. "task.duration", tags=self.stats_tags
  2712. ):
  2713. # Set the validated/merged params on the task object.
  2714. self.task.params = context["params"]
  2715. with set_current_context(context):
  2716. dag = self.task.get_dag()
  2717. if dag is not None:
  2718. jinja_env = dag.get_template_env()
  2719. else:
  2720. jinja_env = None
  2721. task_orig = self.render_templates(context=context, jinja_env=jinja_env)
  2722. # The task is never MappedOperator at this point.
  2723. if TYPE_CHECKING:
  2724. assert isinstance(self.task, BaseOperator)
  2725. if not test_mode:
  2726. rendered_fields = get_serialized_template_fields(task=self.task)
  2727. _update_rtif(ti=self, rendered_fields=rendered_fields)
  2728. # Export context to make it available for operators to use.
  2729. airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
  2730. os.environ.update(airflow_context_vars)
  2731. # Log context only for the default execution method, the assumption
  2732. # being that otherwise we're resuming a deferred task (in which
  2733. # case there's no need to log these again).
  2734. if not self.next_method:
  2735. self.log.info(
  2736. "Exporting env vars: %s",
  2737. " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()),
  2738. )
  2739. # Run pre_execute callback
  2740. self.task.pre_execute(context=context)
  2741. # Run on_execute callback
  2742. self._run_execute_callback(context, self.task)
  2743. # Run on_task_instance_running event
  2744. get_listener_manager().hook.on_task_instance_running(
  2745. previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
  2746. )
  2747. def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
  2748. """Render named map index if the DAG author defined map_index_template at the task level."""
  2749. if jinja_env is None or (template := context.get("map_index_template")) is None:
  2750. return None
  2751. rendered_map_index = jinja_env.from_string(template).render(context)
  2752. log.debug("Map index rendered as %s", rendered_map_index)
  2753. return rendered_map_index
  2754. # Execute the task.
  2755. with set_current_context(context):
  2756. try:
  2757. result = self._execute_task(context, task_orig)
  2758. except Exception:
  2759. # If the task failed, swallow rendering error so it doesn't mask the main error.
  2760. with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
  2761. self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
  2762. raise
  2763. else: # If the task succeeded, render normally to let rendering error bubble up.
  2764. self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
  2765. # Run post_execute callback
  2766. self.task.post_execute(context=context, result=result)
  2767. Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
  2768. # Same metric with tagging
  2769. Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
  2770. Stats.incr("ti_successes", tags=self.stats_tags)
  2771. def _execute_task(self, context: Context, task_orig: Operator):
  2772. """
  2773. Execute Task (optionally with a Timeout) and push Xcom results.
  2774. :param context: Jinja2 context
  2775. :param task_orig: origin task
  2776. """
  2777. return _execute_task(self, context, task_orig)
  2778. @provide_session
  2779. def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None:
  2780. """
2781. Mark the task as deferred and set up the trigger that is needed to resume it when TaskDeferred is raised.
2782. :meta private:
  2783. """
  2784. _defer_task(ti=self, exception=exception, session=session)
  2785. def _run_execute_callback(self, context: Context, task: BaseOperator) -> None:
  2786. """Functions that need to be run before a Task is executed."""
  2787. if not (callbacks := task.on_execute_callback):
  2788. return
  2789. for callback in callbacks if isinstance(callbacks, list) else [callbacks]:
  2790. try:
  2791. callback(context)
  2792. except Exception:
  2793. self.log.exception("Failed when executing execute callback")
  2794. @provide_session
  2795. def run(
  2796. self,
  2797. verbose: bool = True,
  2798. ignore_all_deps: bool = False,
  2799. ignore_depends_on_past: bool = False,
  2800. wait_for_past_depends_before_skipping: bool = False,
  2801. ignore_task_deps: bool = False,
  2802. ignore_ti_state: bool = False,
  2803. mark_success: bool = False,
  2804. test_mode: bool = False,
  2805. job_id: str | None = None,
  2806. pool: str | None = None,
  2807. session: Session = NEW_SESSION,
  2808. raise_on_defer: bool = False,
  2809. ) -> None:
  2810. """Run TaskInstance."""
  2811. res = self.check_and_change_state_before_execution(
  2812. verbose=verbose,
  2813. ignore_all_deps=ignore_all_deps,
  2814. ignore_depends_on_past=ignore_depends_on_past,
  2815. wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
  2816. ignore_task_deps=ignore_task_deps,
  2817. ignore_ti_state=ignore_ti_state,
  2818. mark_success=mark_success,
  2819. test_mode=test_mode,
  2820. job_id=job_id,
  2821. pool=pool,
  2822. session=session,
  2823. )
  2824. if not res:
  2825. return
  2826. self._run_raw_task(
  2827. mark_success=mark_success,
  2828. test_mode=test_mode,
  2829. job_id=job_id,
  2830. pool=pool,
  2831. session=session,
  2832. raise_on_defer=raise_on_defer,
  2833. )
  2834. def dry_run(self) -> None:
  2835. """Only Renders Templates for the TI."""
  2836. if TYPE_CHECKING:
  2837. assert self.task
  2838. self.task = self.task.prepare_for_execution()
  2839. self.render_templates()
  2840. if TYPE_CHECKING:
  2841. assert isinstance(self.task, BaseOperator)
  2842. self.task.dry_run()
  2843. @provide_session
  2844. def _handle_reschedule(
  2845. self,
  2846. actual_start_date: datetime,
  2847. reschedule_exception: AirflowRescheduleException,
  2848. test_mode: bool = False,
  2849. session: Session = NEW_SESSION,
  2850. ):
  2851. _handle_reschedule(
  2852. ti=self,
  2853. actual_start_date=actual_start_date,
  2854. reschedule_exception=reschedule_exception,
  2855. test_mode=test_mode,
  2856. session=session,
  2857. )
  2858. @staticmethod
  2859. def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
  2860. """
  2861. Truncate the traceback of an exception to the first frame called from within a given function.
  2862. :param error: exception to get traceback from
  2863. :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute
  2864. :meta private:
  2865. """
  2866. tb = error.__traceback__
  2867. code = truncate_to.__func__.__code__ # type: ignore[attr-defined]
  2868. while tb is not None:
  2869. if tb.tb_frame.f_code is code:
  2870. return tb.tb_next
  2871. tb = tb.tb_next
  2872. return tb or error.__traceback__
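# Illustrative behaviour (frame names hypothetical): for a call chain
# main -> _execute_task -> user_code, passing truncate_to=ti._execute_task
# drops everything up to and including the _execute_task frame, so the
# reported traceback starts at the user_code frame.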
  2873. @classmethod
  2874. def fetch_handle_failure_context(
  2875. cls,
  2876. ti: TaskInstance,
  2877. error: None | str | BaseException,
  2878. test_mode: bool | None = None,
  2879. context: Context | None = None,
  2880. force_fail: bool = False,
  2881. *,
  2882. session: Session,
  2883. fail_stop: bool = False,
  2884. ):
  2885. """
  2886. Handle Failure for the TaskInstance.
  2887. :param fail_stop: if true, stop remaining tasks in dag
  2888. """
  2889. if error:
  2890. if isinstance(error, BaseException):
  2891. tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task)
  2892. cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb))
  2893. else:
  2894. cls.logger().error("%s", error)
  2895. if not test_mode:
  2896. ti.refresh_from_db(session)
  2897. ti.end_date = timezone.utcnow()
  2898. ti.set_duration()
  2899. Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags)
  2900. # Same metric with tagging
  2901. Stats.incr("operator_failures", tags={**ti.stats_tags, "operator": ti.operator})
  2902. Stats.incr("ti_failures", tags=ti.stats_tags)
  2903. if not test_mode:
  2904. session.add(Log(TaskInstanceState.FAILED.value, ti))
  2905. # Log failure duration
  2906. session.add(TaskFail(ti=ti))
  2907. ti.clear_next_method_args()
2908. # In extreme cases (e.g. a zombie task from a DAG with a parse error) we might _not_ have a Task.
  2909. if context is None and getattr(ti, "task", None):
  2910. context = ti.get_template_context(session)
  2911. if context is not None:
  2912. context["exception"] = error
  2913. # Set state correctly and figure out how to log it and decide whether
  2914. # to email
  2915. # Note, callback invocation needs to be handled by caller of
  2916. # _run_raw_task to avoid race conditions which could lead to duplicate
2917. # invocations or missed invocations.
  2918. # Since this function is called only when the TaskInstance state is running,
  2919. # try_number contains the current try_number (not the next). We
  2920. # only mark task instance as FAILED if the next task instance
  2921. # try_number exceeds the max_tries ... or if force_fail is truthy
  2922. task: BaseOperator | None = None
  2923. try:
  2924. if getattr(ti, "task", None) and context:
  2925. if TYPE_CHECKING:
  2926. assert ti.task
  2927. task = ti.task.unmap((context, session))
  2928. except Exception:
  2929. cls.logger().error("Unable to unmap task to determine if we need to send an alert email")
  2930. if force_fail or not ti.is_eligible_to_retry():
  2931. ti.state = TaskInstanceState.FAILED
  2932. email_for_state = operator.attrgetter("email_on_failure")
  2933. callbacks = task.on_failure_callback if task else None
  2934. if task and fail_stop:
  2935. _stop_remaining_tasks(task_instance=ti, session=session)
  2936. else:
  2937. if ti.state == TaskInstanceState.RUNNING:
2938. # If the task instance is in the running state, it means it raised an exception and
2939. # is about to retry, so we record the task instance history. For other states, the task
  2940. # instance was cleared and already recorded in the task instance history.
  2941. from airflow.models.taskinstancehistory import TaskInstanceHistory
  2942. TaskInstanceHistory.record_ti(ti, session=session)
  2943. ti.state = State.UP_FOR_RETRY
  2944. email_for_state = operator.attrgetter("email_on_retry")
  2945. callbacks = task.on_retry_callback if task else None
  2946. get_listener_manager().hook.on_task_instance_failed(
  2947. previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
  2948. )
  2949. return {
  2950. "ti": ti,
  2951. "email_for_state": email_for_state,
  2952. "task": task,
  2953. "callbacks": callbacks,
  2954. "context": context,
  2955. }
  2956. @staticmethod
  2957. @internal_api_call
  2958. @provide_session
  2959. def save_to_db(
  2960. ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION, refresh_dag: bool = True
  2961. ):
  2962. if refresh_dag and isinstance(ti, TaskInstance):
  2963. ti.get_dagrun().refresh_from_db()
  2964. ti = _coalesce_to_orm_ti(ti=ti, session=session)
  2965. ti.updated_at = timezone.utcnow()
  2966. session.merge(ti)
  2967. session.flush()
  2968. session.commit()
  2969. @provide_session
  2970. def handle_failure(
  2971. self,
  2972. error: None | str | BaseException,
  2973. test_mode: bool | None = None,
  2974. context: Context | None = None,
  2975. force_fail: bool = False,
  2976. session: Session = NEW_SESSION,
  2977. ) -> None:
  2978. """
  2979. Handle Failure for a task instance.
  2980. :param error: if specified, log the specific exception if thrown
  2981. :param session: SQLAlchemy ORM Session
  2982. :param test_mode: doesn't record success or failure in the DB if True
  2983. :param context: Jinja2 context
  2984. :param force_fail: if True, task does not retry
  2985. """
  2986. if TYPE_CHECKING:
  2987. assert self.task
  2988. assert self.task.dag
  2989. try:
  2990. fail_stop = self.task.dag.fail_stop
  2991. except Exception:
  2992. fail_stop = False
  2993. _handle_failure(
  2994. task_instance=self,
  2995. error=error,
  2996. session=session,
  2997. test_mode=test_mode,
  2998. context=context,
  2999. force_fail=force_fail,
  3000. fail_stop=fail_stop,
  3001. )
  3002. def is_eligible_to_retry(self):
  3003. """Is task instance is eligible for retry."""
  3004. return _is_eligible_to_retry(task_instance=self)
  3005. def get_template_context(
  3006. self,
  3007. session: Session | None = None,
  3008. ignore_param_exceptions: bool = True,
  3009. ) -> Context:
  3010. """
  3011. Return TI Context.
  3012. :param session: SQLAlchemy ORM Session
  3013. :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
  3014. """
  3015. if TYPE_CHECKING:
  3016. assert self.task
  3017. assert self.task.dag
  3018. return _get_template_context(
  3019. task_instance=self,
  3020. dag=self.task.dag,
  3021. session=session,
  3022. ignore_param_exceptions=ignore_param_exceptions,
  3023. )
  3024. @provide_session
  3025. def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
  3026. """
  3027. Update task with rendered template fields for presentation in UI.
  3028. If task has already run, will fetch from DB; otherwise will render.
  3029. """
  3030. from airflow.models.renderedtifields import RenderedTaskInstanceFields
  3031. if TYPE_CHECKING:
  3032. assert self.task
  3033. rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
  3034. if rendered_task_instance_fields:
  3035. self.task = self.task.unmap(None)
  3036. for field_name, rendered_value in rendered_task_instance_fields.items():
  3037. setattr(self.task, field_name, rendered_value)
  3038. return
  3039. try:
  3040. # If we get here, either the task hasn't run or the RTIF record was purged.
  3041. from airflow.utils.log.secrets_masker import redact
  3042. self.render_templates()
  3043. for field_name in self.task.template_fields:
  3044. rendered_value = getattr(self.task, field_name)
  3045. setattr(self.task, field_name, redact(rendered_value, field_name))
  3046. except (TemplateAssertionError, UndefinedError) as e:
  3047. raise AirflowException(
  3048. "Webserver does not have access to User-defined Macros or Filters "
  3049. "when Dag Serialization is enabled. Hence for the task that have not yet "
  3050. "started running, please use 'airflow tasks render' for debugging the "
  3051. "rendering of template_fields."
  3052. ) from e
  3053. def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun):
  3054. """Overwrite Task Params with DagRun.conf."""
  3055. if dag_run and dag_run.conf:
  3056. self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
  3057. params.update(dag_run.conf)
  3058. def render_templates(
  3059. self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
  3060. ) -> Operator:
  3061. """
  3062. Render templates in the operator fields.
  3063. If the task was originally mapped, this may replace ``self.task`` with
  3064. the unmapped, fully rendered BaseOperator. The original ``self.task``
  3065. before replacement is returned.
  3066. """
  3067. from airflow.models.mappedoperator import MappedOperator
  3068. if not context:
  3069. context = self.get_template_context()
  3070. original_task = self.task
  3071. ti = context["ti"]
  3072. if TYPE_CHECKING:
  3073. assert original_task
  3074. assert self.task
  3075. assert ti.task
  3076. if ti.task.dag.__class__ is AttributeRemoved:
  3077. ti.task.dag = self.task.dag
  3078. # If self.task is mapped, this call replaces self.task to point to the
  3079. # unmapped BaseOperator created by this function! This is because the
  3080. # MappedOperator is useless for template rendering, and we need to be
  3081. # able to access the unmapped task instead.
  3082. original_task.render_template_fields(context, jinja_env)
  3083. if isinstance(self.task, MappedOperator):
  3084. self.task = context["ti"].task
  3085. return original_task
  3086. def render_k8s_pod_yaml(self) -> dict | None:
  3087. """Render the k8s pod yaml."""
  3088. try:
  3089. from airflow.providers.cncf.kubernetes.template_rendering import (
  3090. render_k8s_pod_yaml as render_k8s_pod_yaml_from_provider,
  3091. )
  3092. except ImportError:
  3093. raise RuntimeError(
  3094. "You need to have the `cncf.kubernetes` provider installed to use this feature. "
  3095. "Also rather than calling it directly you should import "
  3096. "render_k8s_pod_yaml from airflow.providers.cncf.kubernetes.template_rendering "
  3097. "and call it with TaskInstance as the first argument."
  3098. )
  3099. warnings.warn(
  3100. "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed"
  3101. "in Airflow 3. Rather than calling it directly you should import "
  3102. "`render_k8s_pod_yaml` from `airflow.providers.cncf.kubernetes.template_rendering` "
  3103. "and call it with `TaskInstance` as the first argument.",
  3104. DeprecationWarning,
  3105. stacklevel=2,
  3106. )
  3107. return render_k8s_pod_yaml_from_provider(self)
  3108. @provide_session
  3109. def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
  3110. """Render the k8s pod yaml."""
  3111. try:
  3112. from airflow.providers.cncf.kubernetes.template_rendering import (
  3113. get_rendered_k8s_spec as get_rendered_k8s_spec_from_provider,
  3114. )
  3115. except ImportError:
  3116. raise RuntimeError(
  3117. "You need to have the `cncf.kubernetes` provider installed to use this feature. "
  3118. "Also rather than calling it directly you should import "
  3119. "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` "
  3120. "and call it with `TaskInstance` as the first argument."
  3121. )
  3122. warnings.warn(
  3123. "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed"
  3124. "in Airflow 3. Rather than calling it directly you should import "
  3125. "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` "
  3126. "and call it with `TaskInstance` as the first argument.",
  3127. DeprecationWarning,
  3128. stacklevel=2,
  3129. )
  3130. return get_rendered_k8s_spec_from_provider(self, session=session)
  3131. def get_email_subject_content(
  3132. self, exception: BaseException, task: BaseOperator | None = None
  3133. ) -> tuple[str, str, str]:
  3134. """
  3135. Get the email subject content for exceptions.
  3136. :param exception: the exception sent in the email
3137. :param task: the task associated with the exception, if available
  3138. """
  3139. return _get_email_subject_content(task_instance=self, exception=exception, task=task)
  3140. def email_alert(self, exception, task: BaseOperator) -> None:
  3141. """
  3142. Send alert email with exception information.
  3143. :param exception: the exception
  3144. :param task: task related to the exception
  3145. """
  3146. _email_alert(task_instance=self, exception=exception, task=task)
  3147. def set_duration(self) -> None:
  3148. """Set task instance duration."""
  3149. _set_duration(task_instance=self)
  3150. @provide_session
  3151. def xcom_push(
  3152. self,
  3153. key: str,
  3154. value: Any,
  3155. execution_date: datetime | None = None,
  3156. session: Session = NEW_SESSION,
  3157. ) -> None:
  3158. """
  3159. Make an XCom available for tasks to pull.
  3160. :param key: Key to store the value under.
  3161. :param value: Value to store. What types are possible depends on whether
3162. ``enable_xcom_pickling`` is true or not. If so, this can be any
3163. picklable object; otherwise only JSON-serializable values may be used.
  3164. :param execution_date: Deprecated parameter that has no effect.
  3165. """
  3166. if execution_date is not None:
  3167. self_execution_date = self.get_dagrun(session).execution_date
  3168. if execution_date < self_execution_date:
  3169. raise ValueError(
  3170. f"execution_date can not be in the past (current execution_date is "
  3171. f"{self_execution_date}; received {execution_date})"
  3172. )
3173. # execution_date is known to be non-None here, so always emit the deprecation warning.
  3174. message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
  3175. warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
  3176. XCom.set(
  3177. key=key,
  3178. value=value,
  3179. task_id=self.task_id,
  3180. dag_id=self.dag_id,
  3181. run_id=self.run_id,
  3182. map_index=self.map_index,
  3183. session=session,
  3184. )
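# Usage sketch (illustrative, from inside a running task; the key name is
# hypothetical):
#
#     ti.xcom_push(key="row_count", value=42)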
  3185. @provide_session
  3186. def xcom_pull(
  3187. self,
  3188. task_ids: str | Iterable[str] | None = None,
  3189. dag_id: str | None = None,
  3190. key: str = XCOM_RETURN_KEY,
  3191. include_prior_dates: bool = False,
  3192. session: Session = NEW_SESSION,
  3193. *,
  3194. map_indexes: int | Iterable[int] | None = None,
  3195. default: Any = None,
  3196. ) -> Any:
  3197. """
  3198. Pull XComs that optionally meet certain criteria.
  3199. :param key: A key for the XCom. If provided, only XComs with matching
  3200. keys will be returned. The default key is ``'return_value'``, also
  3201. available as constant ``XCOM_RETURN_KEY``. This key is automatically
  3202. given to XComs returned by tasks (as opposed to being pushed
  3203. manually). To remove the filter, pass *None*.
  3204. :param task_ids: Only XComs from tasks with matching ids will be
  3205. pulled. Pass *None* to remove the filter.
  3206. :param dag_id: If provided, only pulls XComs from this DAG. If *None*
  3207. (default), the DAG of the calling task is used.
  3208. :param map_indexes: If provided, only pull XComs with matching indexes.
  3209. If *None* (default), this is inferred from the task(s) being pulled
  3210. (see below for details).
3211. :param include_prior_dates: If *False*, only XComs from the current
  3212. execution_date are returned. If *True*, XComs from previous dates
  3213. are returned as well.
  3214. When pulling one single task (``task_id`` is *None* or a str) without
  3215. specifying ``map_indexes``, the return value is inferred from whether
  3216. the specified task is mapped. If not, value from the one single task
  3217. instance is returned. If the task to pull is mapped, an iterator (not a
  3218. list) yielding XComs from mapped task instances is returned. In either
  3219. case, ``default`` (*None* if not specified) is returned if no matching
  3220. XComs are found.
3221. When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes`` is
3222. a non-str iterable), a list of matching XComs is returned. Elements in
3223. the list are ordered by item ordering in ``task_ids`` and ``map_indexes``.
  3224. """
  3225. return _xcom_pull(
  3226. ti=self,
  3227. task_ids=task_ids,
  3228. dag_id=dag_id,
  3229. key=key,
  3230. include_prior_dates=include_prior_dates,
  3231. session=session,
  3232. map_indexes=map_indexes,
  3233. default=default,
  3234. )
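# Usage sketch (illustrative; task ids are hypothetical): pulling a single
# task's return_value versus the lazily-yielded values of a mapped task.
#
#     count = ti.xcom_pull(task_ids="count_rows")            # one value
#     parts = list(ti.xcom_pull(task_ids="process_chunk"))   # mapped -> iterator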
  3235. @provide_session
  3236. def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
  3237. """Return Number of running TIs from the DB."""
  3238. # .count() is inefficient
  3239. num_running_task_instances_query = session.query(func.count()).filter(
  3240. TaskInstance.dag_id == self.dag_id,
  3241. TaskInstance.task_id == self.task_id,
  3242. TaskInstance.state == TaskInstanceState.RUNNING,
  3243. )
  3244. if same_dagrun:
  3245. num_running_task_instances_query = num_running_task_instances_query.filter(
  3246. TaskInstance.run_id == self.run_id
  3247. )
  3248. return num_running_task_instances_query.scalar()
  3249. def init_run_context(self, raw: bool = False) -> None:
  3250. """Set the log context."""
  3251. self.raw = raw
  3252. self._set_context(self)
  3253. @staticmethod
  3254. def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
  3255. """Return SQLAlchemy filter to query selected task instances."""
3256. # DictKeys type (what we often pass here from the scheduler) is not directly indexable :(
  3257. # Or it might be a generator, but we need to be able to iterate over it more than once
  3258. tis = list(tis)
  3259. if not tis:
  3260. return None
  3261. first = tis[0]
  3262. dag_id = first.dag_id
  3263. run_id = first.run_id
  3264. map_index = first.map_index
  3265. first_task_id = first.task_id
  3266. # pre-compute the set of dag_id, run_id, map_indices and task_ids
  3267. dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
  3268. for t in tis:
  3269. dag_ids.add(t.dag_id)
  3270. run_ids.add(t.run_id)
  3271. map_indices.add(t.map_index)
  3272. task_ids.add(t.task_id)
  3273. # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
  3274. # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
  3275. if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
  3276. return and_(
  3277. TaskInstance.dag_id == dag_id,
  3278. TaskInstance.run_id == run_id,
  3279. TaskInstance.map_index == map_index,
  3280. TaskInstance.task_id.in_(task_ids),
  3281. )
  3282. if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
  3283. return and_(
  3284. TaskInstance.dag_id == dag_id,
  3285. TaskInstance.run_id.in_(run_ids),
  3286. TaskInstance.map_index == map_index,
  3287. TaskInstance.task_id == first_task_id,
  3288. )
  3289. if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
  3290. return and_(
  3291. TaskInstance.dag_id == dag_id,
  3292. TaskInstance.run_id == run_id,
  3293. TaskInstance.map_index.in_(map_indices),
  3294. TaskInstance.task_id == first_task_id,
  3295. )
  3296. filter_condition = []
  3297. # create 2 nested groups, both primarily grouped by dag_id and run_id,
3298. # with one nested level grouped by task_id and the other by map_index.
  3299. task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
  3300. map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
  3301. for t in tis:
  3302. task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
  3303. map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
3304. # This assumes that most dags have dag_id as the largest grouping, followed by run_id. Even
3305. # if it's not, this is still a significant optimization over querying for every single tuple key
  3306. for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids):
  3307. # we compare the group size between task_id and map_index and use the smaller group
  3308. dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
  3309. dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
  3310. if len(dag_task_id_groups) <= len(dag_map_index_groups):
  3311. for cur_task_id, cur_map_indices in dag_task_id_groups.items():
  3312. filter_condition.append(
  3313. and_(
  3314. TaskInstance.dag_id == cur_dag_id,
  3315. TaskInstance.run_id == cur_run_id,
  3316. TaskInstance.task_id == cur_task_id,
  3317. TaskInstance.map_index.in_(cur_map_indices),
  3318. )
  3319. )
  3320. else:
  3321. for cur_map_index, cur_task_ids in dag_map_index_groups.items():
  3322. filter_condition.append(
  3323. and_(
  3324. TaskInstance.dag_id == cur_dag_id,
  3325. TaskInstance.run_id == cur_run_id,
  3326. TaskInstance.task_id.in_(cur_task_ids),
  3327. TaskInstance.map_index == cur_map_index,
  3328. )
  3329. )
  3330. return or_(*filter_condition)
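# Illustrative example of the first fast path above: if every TI shares
# dag_id "etl", run_id "2024-01-01" and map_index -1, the returned filter is
#     dag_id == "etl" AND run_id == "2024-01-01"
#     AND map_index == -1 AND task_id IN (<all task ids>)
# rather than one OR-ed clause per task instance.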
  3331. @classmethod
  3332. def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
  3333. """
  3334. Build an SQLAlchemy filter for a list of task_ids or tuples of (task_id,map_index).
  3335. :meta private:
  3336. """
  3337. # Compute a filter for TI.task_id and TI.map_index based on input values
  3338. # For each item, it will either be a task_id, or (task_id, map_index)
  3339. task_id_only = [v for v in vals if isinstance(v, str)]
  3340. with_map_index = [v for v in vals if not isinstance(v, str)]
  3341. filters: list[ColumnOperators] = []
  3342. if task_id_only:
  3343. filters.append(cls.task_id.in_(task_id_only))
  3344. if with_map_index:
  3345. filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
  3346. if not filters:
  3347. return false()
  3348. if len(filters) == 1:
  3349. return filters[0]
  3350. return or_(*filters)
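# Illustrative example (values hypothetical):
#     vals = ["extract", ("load", 0), ("load", 1)]
# yields: task_id IN ('extract')
#         OR (task_id, map_index) IN (('load', 0), ('load', 1))
# i.e. bare strings match every map index, tuples pin a specific one.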
  3351. @classmethod
  3352. @provide_session
  3353. def _schedule_downstream_tasks(
  3354. cls,
  3355. ti: TaskInstance | TaskInstancePydantic,
  3356. session: Session = NEW_SESSION,
  3357. max_tis_per_query: int | None = None,
  3358. ):
  3359. from sqlalchemy.exc import OperationalError
  3360. from airflow.models.dagrun import DagRun
  3361. try:
  3362. # Re-select the row with a lock
  3363. dag_run = with_row_locks(
  3364. session.query(DagRun).filter_by(
  3365. dag_id=ti.dag_id,
  3366. run_id=ti.run_id,
  3367. ),
  3368. session=session,
  3369. skip_locked=True,
  3370. ).one_or_none()
  3371. if not dag_run:
  3372. cls.logger().debug("Skip locked rows, rollback")
  3373. session.rollback()
  3374. return
  3375. task = ti.task
  3376. if TYPE_CHECKING:
  3377. assert task
  3378. assert task.dag
  3379. # Previously, this section used task.dag.partial_subset to retrieve a partial DAG.
  3380. # However, this approach is unsafe as it can result in incomplete or incorrect task execution,
3381. # leading to potentially incorrect behaviour. As a result, the operation has been removed.
3382. # For more details, refer to the discussion in https://github.com/apache/airflow/pull/42582.
  3383. dag_run.dag = task.dag
  3384. info = dag_run.task_instance_scheduling_decisions(session)
  3385. skippable_task_ids = {
  3386. task_id for task_id in task.dag.task_ids if task_id not in task.downstream_task_ids
  3387. }
  3388. schedulable_tis = [
  3389. ti
  3390. for ti in info.schedulable_tis
  3391. if ti.task_id not in skippable_task_ids
  3392. and not (
  3393. ti.task.inherits_from_empty_operator
  3394. and not ti.task.on_execute_callback
  3395. and not ti.task.on_success_callback
  3396. and not ti.task.outlets
  3397. )
  3398. ]
  3399. for schedulable_ti in schedulable_tis:
  3400. if getattr(schedulable_ti, "task", None) is None:
  3401. schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
  3402. num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
  3403. cls.logger().info("%d downstream tasks scheduled from follow-on schedule check", num)
  3404. session.flush()
  3405. except OperationalError as e:
  3406. # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
  3407. cls.logger().warning(
  3408. "Skipping mini scheduling run due to exception: %s",
  3409. e.statement,
  3410. exc_info=True,
  3411. )
  3412. session.rollback()
  3413. @provide_session
  3414. def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
  3415. """
  3416. Schedule downstream tasks of this task instance.
3417. :meta private:
  3418. """
  3419. try:
  3420. return TaskInstance._schedule_downstream_tasks(
  3421. ti=self, session=session, max_tis_per_query=max_tis_per_query
  3422. )
  3423. except Exception:
  3424. self.log.exception(
  3425. "Error scheduling downstream tasks. Skipping it as this is entirely optional optimisation. "
  3426. "There might be various reasons for it, please take a look at the stack trace to figure "
  3427. "out if the root cause can be diagnosed and fixed. See the issue "
  3428. "https://github.com/apache/airflow/issues/39717 for details and an example problem. If you "
  3429. "would like to get help in solving root cause, open discussion with all details with your "
  3430. "managed service support or in Airflow repository."
  3431. )
  3432. def get_relevant_upstream_map_indexes(
  3433. self,
  3434. upstream: Operator,
  3435. ti_count: int | None,
  3436. *,
  3437. session: Session,
  3438. ) -> int | range | None:
  3439. """
3440. Infer the map indexes of an upstream task "relevant" to this ti.
  3441. The bulk of the logic mainly exists to solve the problem described by
  3442. the following example, where 'val' must resolve to different values,
  3443. depending on where the reference is being used::
  3444. @task
  3445. def this_task(v): # This is self.task.
  3446. return v * 2
  3447. @task_group
  3448. def tg1(inp):
  3449. val = upstream(inp) # This is the upstream task.
  3450. this_task(val) # When inp is 1, val here should resolve to 2.
  3451. return val
  3452. # This val is the same object returned by tg1.
  3453. val = tg1.expand(inp=[1, 2, 3])
  3454. @task_group
  3455. def tg2(inp):
  3456. another_task(inp, val) # val here should resolve to [2, 4, 6].
  3457. tg2.expand(inp=["a", "b"])
  3458. The surrounding mapped task groups of ``upstream`` and ``self.task`` are
  3459. inspected to find a common "ancestor". If such an ancestor is found,
  3460. we need to return specific map indexes to pull a partial value from
  3461. upstream XCom.
  3462. :param upstream: The referenced upstream task.
3463. :param ti_count: The total count of task instances this task was expanded
3464. into by the scheduler, i.e. ``expanded_ti_count`` in the template context.
3465. :return: Specific map index or map indexes to pull, or ``None`` if we
3466. want the "whole" return value (i.e. no mapped task groups involved).
  3467. """
  3468. if TYPE_CHECKING:
  3469. assert self.task
3470. # This value should never be None since we already know the current task
3471. # is in a mapped task group and should have been expanded. Despite that,
3472. # we need to check that it is not None to satisfy Mypy.
3473. # The value can also be 0 when we expand an empty list; checking truthiness
3474. # (rather than just None) avoids dividing by zero below.
  3475. if not ti_count:
  3476. return None
3477. # Find the innermost common mapped task group between the current task
3478. # and the referenced upstream task. If they do not have a common mapped
3479. # task group, the two are in different task mapping contexts (like
3480. # another_task above), and we should use the "whole" value.
  3481. common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
  3482. if common_ancestor is None:
  3483. return None
  3484. # At this point we know the two tasks share a mapped task group, and we
  3485. # should use a "partial" value. Let's break down the mapped ti count
3486. # between the ancestor and any further expansion that happened inside it.
  3487. ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session)
  3488. ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
  3489. # If the task is NOT further expanded inside the common ancestor, we
  3490. # only want to reference one single ti. We must walk the actual DAG,
  3491. # and "ti_count == ancestor_ti_count" does not work, since the further
  3492. # expansion may be of length 1.
  3493. if not _is_further_mapped_inside(upstream, common_ancestor):
  3494. return ancestor_map_index
  3495. # Otherwise we need a partial aggregation for values from selected task
  3496. # instances in the ancestor's expansion context.
  3497. further_count = ti_count // ancestor_ti_count
  3498. map_index_start = ancestor_map_index * further_count
  3499. return range(map_index_start, map_index_start + further_count)
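# Worked example (illustrative numbers): with ti_count=6 and
# ancestor_ti_count=3, a ti with map_index=4 gets ancestor_map_index
# = 4 * 3 // 6 = 2; if the upstream is further mapped, further_count
# = 6 // 3 = 2 and the result is range(4, 6).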
  3500. def clear_db_references(self, session: Session):
  3501. """
  3502. Clear db tables that have a reference to this instance.
  3503. :param session: ORM Session
  3504. :meta private:
  3505. """
  3506. from airflow.models.renderedtifields import RenderedTaskInstanceFields
  3507. tables: list[type[TaskInstanceDependencies]] = [
  3508. TaskFail,
  3509. TaskInstanceNote,
  3510. TaskReschedule,
  3511. XCom,
  3512. RenderedTaskInstanceFields,
  3513. TaskMap,
  3514. ]
  3515. for table in tables:
  3516. session.execute(
  3517. delete(table).where(
  3518. table.dag_id == self.dag_id,
  3519. table.task_id == self.task_id,
  3520. table.run_id == self.run_id,
  3521. table.map_index == self.map_index,
  3522. )
  3523. )
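        # Each iteration issues one bulk DELETE keyed on this ti's composite
        # identity; roughly (illustrative SQL, parameter binding elided):
        #   DELETE FROM task_fail
        #   WHERE dag_id = ? AND task_id = ? AND run_id = ? AND map_index = ?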


def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
    """Given two operators, find their innermost common mapped task group."""
    if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
        return None
    parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
    common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
    return next(common_groups, None)
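

# A sketch of the behavior (hypothetical group and task names): since
# iter_mapped_task_groups() walks outward from the innermost group, if task_a
# is nested in mapped groups tg_inner and tg_outer while task_b sits only in
# tg_outer, the first of task_b's groups that also surrounds task_a is tg_outer:
#
#   _find_common_ancestor_mapped_group(task_a, task_b)  # -> tg_outer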


def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
    """Whether the given operator is *further* mapped inside a task group."""
    from airflow.models.mappedoperator import MappedOperator

    if isinstance(operator, MappedOperator):
        return True
    task_group = operator.task_group
    while task_group is not None and task_group.group_id != container.group_id:
        if isinstance(task_group, MappedTaskGroup):
            return True
        task_group = task_group.parent_group
    return False
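

# Sketch of the two outcomes (hypothetical layout): an operator that is itself
# a MappedOperator, or that sits in a MappedTaskGroup strictly inside
# ``container``, counts as "further mapped" (True); a plain operator placed
# directly in ``container`` does not (False).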


# State of the task instance.
# Stores string version of the task state.
TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]


class SimpleTaskInstance:
    """
    Simplified Task Instance.

    Used to send data between processes via Queues.
    """

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        start_date: datetime | None,
        end_date: datetime | None,
        try_number: int,
        map_index: int,
        state: str,
        executor: str | None,
        executor_config: Any,
        pool: str,
        queue: str,
        key: TaskInstanceKey,
        run_as_user: str | None = None,
        priority_weight: int | None = None,
    ):
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.start_date = start_date
        self.end_date = end_date
        self.try_number = try_number
        self.state = state
        self.executor = executor
        self.executor_config = executor_config
        self.run_as_user = run_as_user
        self.pool = pool
        self.priority_weight = priority_weight
        self.queue = queue
        self.key = key

    def __repr__(self) -> str:
        attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
        return f"SimpleTaskInstance({attrs})"

    def __eq__(self, other) -> bool:
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def as_dict(self):
        warnings.warn(
            "This method is deprecated. Use BaseSerialization.serialize.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        new_dict = dict(self.__dict__)
        for key in new_dict:
            if key in ["start_date", "end_date"]:
                val = new_dict[key]
                if not val or isinstance(val, str):
                    continue
                new_dict[key] = val.isoformat()
        return new_dict

    @classmethod
    def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
        return cls(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            map_index=ti.map_index,
            start_date=ti.start_date,
            end_date=ti.end_date,
            try_number=ti.try_number,
            state=ti.state,
            executor=ti.executor,
            executor_config=ti.executor_config,
            pool=ti.pool,
            queue=ti.queue,
            key=ti.key,
            run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
            priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
        )
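
    # Intended use (a sketch): snapshot a TaskInstance into a lightweight,
    # picklable value before handing it to another process, e.g.
    #   simple_ti = SimpleTaskInstance.from_ti(ti)
    #   some_queue.put(simple_ti)  # ``some_queue`` is a hypothetical queue.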
    @classmethod
    def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
        warnings.warn(
            "This method is deprecated. Use BaseSerialization.deserialize.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        ti_key = TaskInstanceKey(*obj_dict.pop("key"))
        start_date = None
        end_date = None
        start_date_str: str | None = obj_dict.pop("start_date")
        end_date_str: str | None = obj_dict.pop("end_date")
        if start_date_str:
            start_date = timezone.parse(start_date_str)
        if end_date_str:
            end_date = timezone.parse(end_date_str)
        return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)


class TaskInstanceNote(TaskInstanceDependencies):
    """For storage of arbitrary notes concerning the task instance."""

    __tablename__ = "task_instance_note"

    user_id = Column(Integer, ForeignKey("ab_user.id", name="task_instance_note_user_fkey"), nullable=True)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    run_id = Column(StringID(), primary_key=True, nullable=False)
    map_index = Column(Integer, primary_key=True, nullable=False)
    content = Column(String(1000).with_variant(Text(1000), "mysql"))
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    task_instance = relationship("TaskInstance", back_populates="task_instance_note")

    __table_args__ = (
        PrimaryKeyConstraint("task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey"),
        ForeignKeyConstraint(
            (dag_id, task_id, run_id, map_index),
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_instance_note_ti_fkey",
            ondelete="CASCADE",
        ),
    )

    def __init__(self, content, user_id=None):
        self.content = content
        self.user_id = user_id

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + ">"


STATICA_HACK = True
globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK:  # pragma: no cover
    # "kcah_acitats"[::-1].upper() spells STATICA_HACK, so the assignment above
    # flips the flag to False at runtime and this block never executes. Static
    # analyzers cannot see through the globals() trick, so they still pick up
    # the relationship defined below.
    from airflow.jobs.job import Job

    TaskInstance.queued_by_job = relationship(Job)