12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- import collections.abc
- import contextlib
- import hashlib
- import itertools
- import logging
- import math
- import operator
- import os
- import signal
- import traceback
- import warnings
- from collections import defaultdict
- from contextlib import nullcontext
- from datetime import timedelta
- from enum import Enum
- from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
- from urllib.parse import quote
- import dill
- import jinja2
- import lazy_object_proxy
- import pendulum
- from deprecated import deprecated
- from jinja2 import TemplateAssertionError, UndefinedError
- from sqlalchemy import (
- Column,
- DateTime,
- Float,
- ForeignKey,
- ForeignKeyConstraint,
- Index,
- Integer,
- PrimaryKeyConstraint,
- String,
- Text,
- and_,
- delete,
- false,
- func,
- inspect,
- or_,
- text,
- update,
- )
- from sqlalchemy.ext.associationproxy import association_proxy
- from sqlalchemy.ext.hybrid import hybrid_property
- from sqlalchemy.ext.mutable import MutableDict
- from sqlalchemy.orm import lazyload, reconstructor, relationship
- from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
- from sqlalchemy.sql.expression import case, select
- from airflow import settings
- from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
- from airflow.compat.functools import cache
- from airflow.configuration import conf
- from airflow.datasets import Dataset, DatasetAlias
- from airflow.datasets.manager import dataset_manager
- from airflow.exceptions import (
- AirflowException,
- AirflowFailException,
- AirflowRescheduleException,
- AirflowSensorTimeout,
- AirflowSkipException,
- AirflowTaskTerminated,
- AirflowTaskTimeout,
- DagRunNotFound,
- RemovedInAirflow3Warning,
- TaskDeferralError,
- TaskDeferred,
- UnmappableXComLengthPushed,
- UnmappableXComTypePushed,
- XComForMappingNotPushed,
- )
- from airflow.listeners.listener import get_listener_manager
- from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
- from airflow.models.dagbag import DagBag
- from airflow.models.dataset import DatasetAliasModel, DatasetModel
- from airflow.models.log import Log
- from airflow.models.param import process_params
- from airflow.models.renderedtifields import get_serialized_template_fields
- from airflow.models.taskfail import TaskFail
- from airflow.models.taskinstancekey import TaskInstanceKey
- from airflow.models.taskmap import TaskMap
- from airflow.models.taskreschedule import TaskReschedule
- from airflow.models.xcom import LazyXComSelectSequence, XCom
- from airflow.plugins_manager import integrate_macros_plugins
- from airflow.sentry import Sentry
- from airflow.settings import task_instance_mutation_hook
- from airflow.stats import Stats
- from airflow.templates import SandboxedEnvironment
- from airflow.ti_deps.dep_context import DepContext
- from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
- from airflow.traces.tracer import Trace
- from airflow.utils import timezone
- from airflow.utils.context import (
- ConnectionAccessor,
- Context,
- InletEventsAccessors,
- OutletEventAccessors,
- VariableAccessor,
- context_get_outlet_events,
- context_merge,
- )
- from airflow.utils.email import send_email
- from airflow.utils.helpers import prune_dict, render_template_to_string
- from airflow.utils.log.logging_mixin import LoggingMixin
- from airflow.utils.net import get_hostname
- from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars
- from airflow.utils.platform import getuser
- from airflow.utils.retries import run_with_db_retries
- from airflow.utils.session import NEW_SESSION, create_session, provide_session
- from airflow.utils.sqlalchemy import (
- ExecutorConfigType,
- ExtendedJSON,
- UtcDateTime,
- tuple_in_condition,
- with_row_locks,
- )
- from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
- from airflow.utils.task_group import MappedTaskGroup
- from airflow.utils.task_instance_session import set_current_task_instance_session
- from airflow.utils.timeout import timeout
- from airflow.utils.types import AttributeRemoved
- from airflow.utils.xcom import XCOM_RETURN_KEY
- TR = TaskReschedule
- _CURRENT_CONTEXT: list[Context] = []
- log = logging.getLogger(__name__)
- if TYPE_CHECKING:
- from datetime import datetime
- from pathlib import PurePath
- from types import TracebackType
- from sqlalchemy.orm.session import Session
- from sqlalchemy.sql.elements import BooleanClauseList
- from sqlalchemy.sql.expression import ColumnOperators
- from airflow.models.abstractoperator import TaskStateChangeCallback
- from airflow.models.baseoperator import BaseOperator
- from airflow.models.dag import DAG, DagModel
- from airflow.models.dagrun import DagRun
- from airflow.models.dataset import DatasetEvent
- from airflow.models.operator import Operator
- from airflow.serialization.pydantic.dag import DagModelPydantic
- from airflow.serialization.pydantic.dataset import DatasetEventPydantic
- from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
- from airflow.timetables.base import DataInterval
- from airflow.typing_compat import Literal, TypeGuard
- from airflow.utils.task_group import TaskGroup
- PAST_DEPENDS_MET = "past_depends_met"
- class TaskReturnCode(Enum):
- """
- Enum to signal manner of exit for task run command.
- :meta private:
- """
- DEFERRED = 100
- """When task exits with deferral to trigger."""
- @internal_api_call
- @provide_session
- def _merge_ti(ti, session: Session = NEW_SESSION):
- session.merge(ti)
- session.commit()
- @internal_api_call
- @provide_session
- def _add_log(
- event,
- task_instance=None,
- owner=None,
- owner_display_name=None,
- extra=None,
- session: Session = NEW_SESSION,
- **kwargs,
- ):
- session.add(
- Log(
- event,
- task_instance,
- owner,
- owner_display_name,
- extra,
- **kwargs,
- )
- )
- def _run_raw_task(
- ti: TaskInstance | TaskInstancePydantic,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- raise_on_defer: bool = False,
- session: Session | None = None,
- ) -> TaskReturnCode | None:
- """
- Run a task, update the state upon completion, and run any appropriate callbacks.
- Immediately runs the task (without checking or changing db state
- before execution) and then sets the appropriate final state after
- completion and runs any post-execute callbacks. Meant to be called
- only after another function changes the state to running.
- :param mark_success: Don't run the task, mark its state as success
- :param test_mode: Doesn't record success or failure in the DB
- :param pool: specifies the pool to use to run the task instance
- :param session: SQLAlchemy ORM Session
- """
- if TYPE_CHECKING:
- assert ti.task
- ti.test_mode = test_mode
- ti.refresh_from_task(ti.task, pool_override=pool)
- ti.refresh_from_db(session=session)
- ti.job_id = job_id
- ti.hostname = get_hostname()
- ti.pid = os.getpid()
- if not test_mode:
- TaskInstance.save_to_db(ti=ti, session=session, refresh_dag=False)
- actual_start_date = timezone.utcnow()
- Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", tags=ti.stats_tags)
- # Same metric with tagging
- Stats.incr("ti.start", tags=ti.stats_tags)
- # Initialize final state counters at zero
- for state in State.task_states:
- Stats.incr(
- f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}",
- count=0,
- tags=ti.stats_tags,
- )
- # Same metric with tagging
- Stats.incr(
- "ti.finish",
- count=0,
- tags={**ti.stats_tags, "state": str(state)},
- )
- with set_current_task_instance_session(session=session):
- ti.task = ti.task.prepare_for_execution()
- context = ti.get_template_context(ignore_param_exceptions=False, session=session)
- try:
- if not mark_success:
- TaskInstance._execute_task_with_callbacks(
- self=ti, # type: ignore[arg-type]
- context=context,
- test_mode=test_mode,
- session=session,
- )
- if not test_mode:
- ti.refresh_from_db(lock_for_update=True, session=session)
- ti.state = TaskInstanceState.SUCCESS
- except TaskDeferred as defer:
- # The task has signalled it wants to defer execution based on
- # a trigger.
- if raise_on_defer:
- raise
- ti.defer_task(exception=defer, session=session)
- ti.log.info(
- "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s",
- ti.dag_id,
- ti.task_id,
- ti.run_id,
- _date_or_empty(task_instance=ti, attr="execution_date"),
- _date_or_empty(task_instance=ti, attr="start_date"),
- )
- return TaskReturnCode.DEFERRED
- except AirflowSkipException as e:
- # Recording SKIP
- # log only if exception has any arguments to prevent log flooding
- if e.args:
- ti.log.info(e)
- if not test_mode:
- ti.refresh_from_db(lock_for_update=True, session=session)
- ti.state = TaskInstanceState.SKIPPED
- _run_finished_callback(callbacks=ti.task.on_skipped_callback, context=context)
- TaskInstance.save_to_db(ti=ti, session=session)
- except AirflowRescheduleException as reschedule_exception:
- ti._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
- ti.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
- return None
- except (AirflowFailException, AirflowSensorTimeout) as e:
- # If AirflowFailException is raised, task should not retry.
- # If a sensor in reschedule mode reaches timeout, task should not retry.
- ti.handle_failure(e, test_mode, context, force_fail=True, session=session) # already saves to db
- raise
- except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e:
- if not test_mode:
- ti.refresh_from_db(lock_for_update=True, session=session)
- # for case when task is marked as success/failed externally
- # or dagrun timed out and task is marked as skipped
- # current behavior doesn't hit the callbacks
- if ti.state in State.finished:
- ti.clear_next_method_args()
- TaskInstance.save_to_db(ti=ti, session=session)
- return None
- else:
- ti.handle_failure(e, test_mode, context, session=session)
- raise
- except SystemExit as e:
- # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
- # Therefore, here we must handle only error codes.
- msg = f"Task failed due to SystemExit({e.code})"
- ti.handle_failure(msg, test_mode, context, session=session)
- raise AirflowException(msg)
- except BaseException as e:
- ti.handle_failure(e, test_mode, context, session=session)
- raise
- finally:
- # Print a marker post execution for internals of post task processing
- log.info("::group::Post task execution logs")
- Stats.incr(
- f"ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}",
- tags=ti.stats_tags,
- )
- # Same metric with tagging
- Stats.incr("ti.finish", tags={**ti.stats_tags, "state": str(ti.state)})
- # Recording SKIPPED or SUCCESS
- ti.clear_next_method_args()
- ti.end_date = timezone.utcnow()
- _log_state(task_instance=ti)
- ti.set_duration()
- # run on_success_callback before db committing
- # otherwise, the LocalTaskJob sees the state is changed to `success`,
- # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
- _run_finished_callback(callbacks=ti.task.on_success_callback, context=context)
- if not test_mode:
- _add_log(event=ti.state, task_instance=ti, session=session)
- if ti.state == TaskInstanceState.SUCCESS:
- ti._register_dataset_changes(events=context["outlet_events"], session=session)
- TaskInstance.save_to_db(ti=ti, session=session)
- if ti.state == TaskInstanceState.SUCCESS:
- get_listener_manager().hook.on_task_instance_success(
- previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
- )
- return None
- @contextlib.contextmanager
- def set_current_context(context: Context) -> Generator[Context, None, None]:
- """
- Set the current execution context to the provided context object.
- This method should be called once per Task execution, before calling operator.execute.
- """
- _CURRENT_CONTEXT.append(context)
- try:
- yield context
- finally:
- expected_state = _CURRENT_CONTEXT.pop()
- if expected_state != context:
- log.warning(
- "Current context is not equal to the state at context stack. Expected=%s, got=%s",
- context,
- expected_state,
- )
- def _stop_remaining_tasks(*, task_instance: TaskInstance | TaskInstancePydantic, session: Session):
- """
- Stop non-teardown tasks in dag.
- :meta private:
- """
- if not task_instance.dag_run:
- raise ValueError("``task_instance`` must have ``dag_run`` set")
- tis = task_instance.dag_run.get_task_instances(session=session)
- if TYPE_CHECKING:
- assert task_instance.task
- assert isinstance(task_instance.task.dag, DAG)
- for ti in tis:
- if ti.task_id == task_instance.task_id or ti.state in (
- TaskInstanceState.SUCCESS,
- TaskInstanceState.FAILED,
- ):
- continue
- task = task_instance.task.dag.task_dict[ti.task_id]
- if not task.is_teardown:
- if ti.state == TaskInstanceState.RUNNING:
- log.info("Forcing task %s to fail due to dag's `fail_stop` setting", ti.task_id)
- ti.error(session)
- else:
- log.info("Setting task %s to SKIPPED due to dag's `fail_stop` setting.", ti.task_id)
- ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
- else:
- log.info("Not skipping teardown task '%s'", ti.task_id)
- def clear_task_instances(
- tis: list[TaskInstance],
- session: Session,
- activate_dag_runs: None = None,
- dag: DAG | None = None,
- dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
- ) -> None:
- """
- Clear a set of task instances, but make sure the running ones get killed.
- Also sets Dagrun's `state` to QUEUED and `start_date` to the time of execution.
- But only for finished DRs (SUCCESS and FAILED).
- Doesn't clear DR's `state` and `start_date`for running
- DRs (QUEUED and RUNNING) because clearing the state for already
- running DR is redundant and clearing `start_date` affects DR's duration.
- :param tis: a list of task instances
- :param session: current session
- :param dag_run_state: state to set finished DagRuns to.
- If set to False, DagRuns state will not be changed.
- :param dag: DAG object
- :param activate_dag_runs: Deprecated parameter, do not pass
- """
- job_ids = []
- # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
- task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
- lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
- )
- dag_bag = DagBag(read_dags_from_db=True)
- from airflow.models.taskinstancehistory import TaskInstanceHistory
- for ti in tis:
- TaskInstanceHistory.record_ti(ti, session)
- if ti.state == TaskInstanceState.RUNNING:
- if ti.job_id:
- # If a task is cleared when running, set its state to RESTARTING so that
- # the task is terminated and becomes eligible for retry.
- ti.state = TaskInstanceState.RESTARTING
- job_ids.append(ti.job_id)
- else:
- ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
- task_id = ti.task_id
- if ti_dag and ti_dag.has_task(task_id):
- task = ti_dag.get_task(task_id)
- ti.refresh_from_task(task)
- if TYPE_CHECKING:
- assert ti.task
- ti.max_tries = ti.try_number + task.retries
- else:
- # Ignore errors when updating max_tries if the DAG or
- # task are not found since database records could be
- # outdated. We make max_tries the maximum value of its
- # original max_tries or the last attempted try number.
- ti.max_tries = max(ti.max_tries, ti.try_number)
- ti.state = None
- ti.external_executor_id = None
- ti.clear_next_method_args()
- session.merge(ti)
- task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)
- if task_id_by_key:
- # Clear all reschedules related to the ti to clear
- # This is an optimization for the common case where all tis are for a small number
- # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
- # run_id, try_number, map_index, and task_id to construct the where clause in a
- # hierarchical manner. This speeds up the delete statement by more than 40x for
- # large number of tis (50k+).
- conditions = or_(
- and_(
- TR.dag_id == dag_id,
- or_(
- and_(
- TR.run_id == run_id,
- or_(
- and_(
- TR.map_index == map_index,
- or_(
- and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
- for try_number, task_ids in task_tries.items()
- ),
- )
- for map_index, task_tries in map_indexes.items()
- ),
- )
- for run_id, map_indexes in run_ids.items()
- ),
- )
- for dag_id, run_ids in task_id_by_key.items()
- )
- delete_qry = TR.__table__.delete().where(conditions)
- session.execute(delete_qry)
- if job_ids:
- from airflow.jobs.job import Job
- session.execute(update(Job).where(Job.id.in_(job_ids)).values(state=JobState.RESTARTING))
- if activate_dag_runs is not None:
- warnings.warn(
- "`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
- "Please use `dag_run_state`",
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- if not activate_dag_runs:
- dag_run_state = False
- if dag_run_state is not False and tis:
- from airflow.models.dagrun import DagRun # Avoid circular import
- run_ids_by_dag_id = defaultdict(set)
- for instance in tis:
- run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
- drs = (
- session.query(DagRun)
- .filter(
- or_(
- and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
- for dag_id, run_ids in run_ids_by_dag_id.items()
- )
- )
- .all()
- )
- dag_run_state = DagRunState(dag_run_state) # Validate the state value.
- for dr in drs:
- if dr.state in State.finished_dr_states:
- dr.state = dag_run_state
- dr.start_date = timezone.utcnow()
- if dag_run_state == DagRunState.QUEUED:
- dr.last_scheduling_decision = None
- dr.start_date = None
- dr.clear_number += 1
- session.flush()
- @internal_api_call
- @provide_session
- def _xcom_pull(
- *,
- ti,
- task_ids: str | Iterable[str] | None = None,
- dag_id: str | None = None,
- key: str = XCOM_RETURN_KEY,
- include_prior_dates: bool = False,
- session: Session = NEW_SESSION,
- map_indexes: int | Iterable[int] | None = None,
- default: Any = None,
- ) -> Any:
- """
- Pull XComs that optionally meet certain criteria.
- :param key: A key for the XCom. If provided, only XComs with matching
- keys will be returned. The default key is ``'return_value'``, also
- available as constant ``XCOM_RETURN_KEY``. This key is automatically
- given to XComs returned by tasks (as opposed to being pushed
- manually). To remove the filter, pass *None*.
- :param task_ids: Only XComs from tasks with matching ids will be
- pulled. Pass *None* to remove the filter.
- :param dag_id: If provided, only pulls XComs from this DAG. If *None*
- (default), the DAG of the calling task is used.
- :param map_indexes: If provided, only pull XComs with matching indexes.
- If *None* (default), this is inferred from the task(s) being pulled
- (see below for details).
- :param include_prior_dates: If False, only XComs from the current
- execution_date are returned. If *True*, XComs from previous dates
- are returned as well.
- When pulling one single task (``task_id`` is *None* or a str) without
- specifying ``map_indexes``, the return value is inferred from whether
- the specified task is mapped. If not, value from the one single task
- instance is returned. If the task to pull is mapped, an iterator (not a
- list) yielding XComs from mapped task instances is returned. In either
- case, ``default`` (*None* if not specified) is returned if no matching
- XComs are found.
- When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
- a non-str iterable), a list of matching XComs is returned. Elements in
- the list is ordered by item ordering in ``task_id`` and ``map_index``.
- """
- if dag_id is None:
- dag_id = ti.dag_id
- query = XCom.get_many(
- key=key,
- run_id=ti.run_id,
- dag_ids=dag_id,
- task_ids=task_ids,
- map_indexes=map_indexes,
- include_prior_dates=include_prior_dates,
- session=session,
- )
- # NOTE: Since we're only fetching the value field and not the whole
- # class, the @recreate annotation does not kick in. Therefore we need to
- # call XCom.deserialize_value() manually.
- # We are only pulling one single task.
- if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
- first = query.with_entities(
- XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
- ).first()
- if first is None: # No matching XCom at all.
- return default
- if map_indexes is not None or first.map_index < 0:
- return XCom.deserialize_value(first)
- return LazyXComSelectSequence.from_select(
- query.with_entities(XCom.value).order_by(None).statement,
- order_by=[XCom.map_index],
- session=session,
- )
- # At this point either task_ids or map_indexes is explicitly multi-value.
- # Order return values to match task_ids and map_indexes ordering.
- ordering = []
- if task_ids is None or isinstance(task_ids, str):
- ordering.append(XCom.task_id)
- elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
- ordering.append(case(task_id_whens, value=XCom.task_id))
- else:
- ordering.append(XCom.task_id)
- if map_indexes is None or isinstance(map_indexes, int):
- ordering.append(XCom.map_index)
- elif isinstance(map_indexes, range):
- order = XCom.map_index
- if map_indexes.step < 0:
- order = order.desc()
- ordering.append(order)
- elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
- ordering.append(case(map_index_whens, value=XCom.map_index))
- else:
- ordering.append(XCom.map_index)
- return LazyXComSelectSequence.from_select(
- query.with_entities(XCom.value).order_by(None).statement,
- order_by=ordering,
- session=session,
- )
- def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
- """
- Whether a value can be used for task mapping.
- We only allow collections with guaranteed ordering, but exclude character
- sequences since that's usually not what users would expect to be mappable.
- """
- if not isinstance(value, (collections.abc.Sequence, dict)):
- return False
- if isinstance(value, (bytearray, bytes, str)):
- return False
- return True
- def _creator_note(val):
- """Creator the ``note`` association proxy."""
- if isinstance(val, str):
- return TaskInstanceNote(content=val)
- elif isinstance(val, dict):
- return TaskInstanceNote(**val)
- else:
- return TaskInstanceNote(*val)
- def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
- """
- Execute Task (optionally with a Timeout) and push Xcom results.
- :param task_instance: the task instance
- :param context: Jinja2 context
- :param task_orig: origin task
- :meta private:
- """
- from airflow.models.mappedoperator import MappedOperator
- task_to_execute = task_instance.task
- if TYPE_CHECKING:
- assert task_to_execute
- if isinstance(task_to_execute, MappedOperator):
- raise AirflowException("MappedOperator cannot be executed.")
- # If the task has been deferred and is being executed due to a trigger,
- # then we need to pick the right method to come back to, otherwise
- # we go for the default execute
- execute_callable_kwargs: dict[str, Any] = {}
- execute_callable: Callable
- if task_instance.next_method:
- if task_instance.next_method == "execute":
- if not task_instance.next_kwargs:
- task_instance.next_kwargs = {}
- task_instance.next_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel
- execute_callable = task_to_execute.resume_execution
- execute_callable_kwargs["next_method"] = task_instance.next_method
- execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
- else:
- execute_callable = task_to_execute.execute
- if execute_callable.__name__ == "execute":
- execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel
- def _execute_callable(context: Context, **execute_callable_kwargs):
- try:
- # Print a marker for log grouping of details before task execution
- log.info("::endgroup::")
- return ExecutionCallableRunner(
- execute_callable,
- context_get_outlet_events(context),
- logger=log,
- ).run(context=context, **execute_callable_kwargs)
- except SystemExit as e:
- # Handle only successful cases here. Failure cases will be handled upper
- # in the exception chain.
- if e.code is not None and e.code != 0:
- raise
- return None
- # If a timeout is specified for the task, make it fail
- # if it goes beyond
- if task_to_execute.execution_timeout:
- # If we are coming in with a next_method (i.e. from a deferral),
- # calculate the timeout from our start_date.
- if task_instance.next_method and task_instance.start_date:
- timeout_seconds = (
- task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date)
- ).total_seconds()
- else:
- timeout_seconds = task_to_execute.execution_timeout.total_seconds()
- try:
- # It's possible we're already timed out, so fast-fail if true
- if timeout_seconds <= 0:
- raise AirflowTaskTimeout()
- # Run task in timeout wrapper
- with timeout(timeout_seconds):
- result = _execute_callable(context=context, **execute_callable_kwargs)
- except AirflowTaskTimeout:
- task_to_execute.on_kill()
- raise
- else:
- result = _execute_callable(context=context, **execute_callable_kwargs)
- cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session()
- with cm as session_or_null:
- if task_to_execute.do_xcom_push:
- xcom_value = result
- else:
- xcom_value = None
- if xcom_value is not None: # If the task returns a result, push an XCom containing it.
- if task_to_execute.multiple_outputs:
- if not isinstance(xcom_value, Mapping):
- raise AirflowException(
- f"Returned output was type {type(xcom_value)} "
- "expected dictionary for multiple_outputs"
- )
- for key in xcom_value.keys():
- if not isinstance(key, str):
- raise AirflowException(
- "Returned dictionary keys must be strings when using "
- f"multiple_outputs, found {key} ({type(key)}) instead"
- )
- for key, value in xcom_value.items():
- task_instance.xcom_push(key=key, value=value, session=session_or_null)
- task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
- if TYPE_CHECKING:
- assert task_orig.dag
- _record_task_map_for_downstreams(
- task_instance=task_instance,
- task=task_orig,
- dag=task_orig.dag,
- value=xcom_value,
- session=session_or_null,
- )
- return result
- def _set_ti_attrs(target, source, include_dag_run=False):
- # Fields ordered per model definition
- target.start_date = source.start_date
- target.end_date = source.end_date
- target.duration = source.duration
- target.state = source.state
- target.try_number = source.try_number
- target.max_tries = source.max_tries
- target.hostname = source.hostname
- target.unixname = source.unixname
- target.job_id = source.job_id
- target.pool = source.pool
- target.pool_slots = source.pool_slots or 1
- target.queue = source.queue
- target.priority_weight = source.priority_weight
- target.operator = source.operator
- target.custom_operator_name = source.custom_operator_name
- target.queued_dttm = source.queued_dttm
- target.queued_by_job_id = source.queued_by_job_id
- target.pid = source.pid
- target.executor = source.executor
- target.executor_config = source.executor_config
- target.external_executor_id = source.external_executor_id
- target.trigger_id = source.trigger_id
- target.next_method = source.next_method
- target.next_kwargs = source.next_kwargs
- if include_dag_run:
- target.execution_date = source.execution_date
- target.dag_run.id = source.dag_run.id
- target.dag_run.dag_id = source.dag_run.dag_id
- target.dag_run.queued_at = source.dag_run.queued_at
- target.dag_run.execution_date = source.dag_run.execution_date
- target.dag_run.start_date = source.dag_run.start_date
- target.dag_run.end_date = source.dag_run.end_date
- target.dag_run.state = source.dag_run.state
- target.dag_run.run_id = source.dag_run.run_id
- target.dag_run.creating_job_id = source.dag_run.creating_job_id
- target.dag_run.external_trigger = source.dag_run.external_trigger
- target.dag_run.run_type = source.dag_run.run_type
- target.dag_run.conf = source.dag_run.conf
- target.dag_run.data_interval_start = source.dag_run.data_interval_start
- target.dag_run.data_interval_end = source.dag_run.data_interval_end
- target.dag_run.last_scheduling_decision = source.dag_run.last_scheduling_decision
- target.dag_run.dag_hash = source.dag_run.dag_hash
- target.dag_run.updated_at = source.dag_run.updated_at
- target.dag_run.log_template_id = source.dag_run.log_template_id
- def _refresh_from_db(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- session: Session | None = None,
- lock_for_update: bool = False,
- ) -> None:
- """
- Refresh the task instance from the database based on the primary key.
- :param task_instance: the task instance
- :param session: SQLAlchemy ORM Session
- :param lock_for_update: if True, indicates that the database should
- lock the TaskInstance (issuing a FOR UPDATE clause) until the
- session is committed.
- :meta private:
- """
- if not InternalApiConfig.get_use_internal_api():
- if session and task_instance in session:
- session.refresh(task_instance, TaskInstance.__mapper__.column_attrs.keys())
- ti = TaskInstance.get_task_instance(
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- run_id=task_instance.run_id,
- map_index=task_instance.map_index,
- lock_for_update=lock_for_update,
- session=session,
- )
- if ti:
- from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
- include_dag_run = isinstance(ti, TaskInstancePydantic)
- # in case of internal API, what we get is TaskInstancePydantic value, and we are supposed
- # to also update dag_run information as it might not be available. We cannot always do it in
- # case ti is TaskInstance, because it might be detached/ not loaded yet and dag_run might
- # not be available.
- _set_ti_attrs(task_instance, ti, include_dag_run=include_dag_run)
- else:
- task_instance.state = None
- def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
- """
- Set task instance duration.
- :param task_instance: the task instance
- :meta private:
- """
- if task_instance.end_date and task_instance.start_date:
- task_instance.duration = (task_instance.end_date - task_instance.start_date).total_seconds()
- else:
- task_instance.duration = None
- log.debug("Task Duration set to %s", task_instance.duration)
- def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]:
- """
- Return task instance tags.
- :param task_instance: the task instance
- :meta private:
- """
- return prune_dict({"dag_id": task_instance.dag_id, "task_id": task_instance.task_id})
- def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
- """
- Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them.
- :param task_instance: the task instance
- :meta private:
- """
- log.debug("Clearing next_method and next_kwargs.")
- task_instance.next_method = None
- task_instance.next_kwargs = None
- @internal_api_call
- def _get_template_context(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- dag: DAG,
- session: Session | None = None,
- ignore_param_exceptions: bool = True,
- ) -> Context:
- """
- Return TI Context.
- :param task_instance: the task instance for the task
- :param dag: dag for the task
- :param session: SQLAlchemy ORM Session
- :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
- :meta private:
- """
- # Do not use provide_session here -- it expunges everything on exit!
- if not session:
- session = settings.Session()
- from airflow import macros
- from airflow.models.abstractoperator import NotMapped
- integrate_macros_plugins()
- task = task_instance.task
- if TYPE_CHECKING:
- assert task_instance.task
- assert task
- assert task.dag
- if task.dag.__class__ is AttributeRemoved:
- task.dag = dag # required after deserialization
- dag_run = task_instance.get_dagrun(session)
- data_interval = dag.get_run_data_interval(dag_run)
- validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
- logical_date: DateTime = timezone.coerce_datetime(task_instance.execution_date)
- ds = logical_date.strftime("%Y-%m-%d")
- ds_nodash = ds.replace("-", "")
- ts = logical_date.isoformat()
- ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
- ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
- @cache # Prevent multiple database access.
- def _get_previous_dagrun_success() -> DagRun | None:
- return task_instance.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
- def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
- dagrun = _get_previous_dagrun_success()
- if dagrun is None:
- return None
- return dag.get_run_data_interval(dagrun)
- def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
- data_interval = _get_previous_dagrun_data_interval_success()
- if data_interval is None:
- return None
- return data_interval.start
- def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
- data_interval = _get_previous_dagrun_data_interval_success()
- if data_interval is None:
- return None
- return data_interval.end
- def get_prev_start_date_success() -> pendulum.DateTime | None:
- dagrun = _get_previous_dagrun_success()
- if dagrun is None:
- return None
- return timezone.coerce_datetime(dagrun.start_date)
- def get_prev_end_date_success() -> pendulum.DateTime | None:
- dagrun = _get_previous_dagrun_success()
- if dagrun is None:
- return None
- return timezone.coerce_datetime(dagrun.end_date)
- @cache
- def get_yesterday_ds() -> str:
- return (logical_date - timedelta(1)).strftime("%Y-%m-%d")
- def get_yesterday_ds_nodash() -> str:
- return get_yesterday_ds().replace("-", "")
- @cache
- def get_tomorrow_ds() -> str:
- return (logical_date + timedelta(1)).strftime("%Y-%m-%d")
- def get_tomorrow_ds_nodash() -> str:
- return get_tomorrow_ds().replace("-", "")
- @cache
- def get_next_execution_date() -> pendulum.DateTime | None:
- # For manually triggered dagruns that aren't run on a schedule,
- # the "next" execution date doesn't make sense, and should be set
- # to execution date for consistency with how execution_date is set
- # for manually triggered tasks, i.e. triggered_date == execution_date.
- if dag_run.external_trigger:
- return logical_date
- if dag is None:
- return None
- next_info = dag.next_dagrun_info(data_interval, restricted=False)
- if next_info is None:
- return None
- return timezone.coerce_datetime(next_info.logical_date)
- def get_next_ds() -> str | None:
- execution_date = get_next_execution_date()
- if execution_date is None:
- return None
- return execution_date.strftime("%Y-%m-%d")
- def get_next_ds_nodash() -> str | None:
- ds = get_next_ds()
- if ds is None:
- return ds
- return ds.replace("-", "")
- @cache
- def get_prev_execution_date():
- # For manually triggered dagruns that aren't run on a schedule,
- # the "previous" execution date doesn't make sense, and should be set
- # to execution date for consistency with how execution_date is set
- # for manually triggered tasks, i.e. triggered_date == execution_date.
- if dag_run.external_trigger:
- return logical_date
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", RemovedInAirflow3Warning)
- return dag.previous_schedule(logical_date)
- @cache
- def get_prev_ds() -> str | None:
- execution_date = get_prev_execution_date()
- if execution_date is None:
- return None
- return execution_date.strftime("%Y-%m-%d")
- def get_prev_ds_nodash() -> str | None:
- prev_ds = get_prev_ds()
- if prev_ds is None:
- return None
- return prev_ds.replace("-", "")
- def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]:
- if TYPE_CHECKING:
- assert session is not None
- # The dag_run may not be attached to the session anymore since the
- # code base is over-zealous with use of session.expunge_all().
- # Re-attach it if we get called.
- nonlocal dag_run
- if dag_run not in session:
- dag_run = session.merge(dag_run, load=False)
- dataset_events = dag_run.consumed_dataset_events
- triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list)
- for event in dataset_events:
- if event.dataset:
- triggering_events[event.dataset.uri].append(event)
- return triggering_events
- try:
- expanded_ti_count: int | None = task.get_mapped_ti_count(task_instance.run_id, session=session)
- except NotMapped:
- expanded_ti_count = None
- # NOTE: If you add to this dict, make sure to also update the following:
- # * Context in airflow/utils/context.pyi
- # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
- # * Table in docs/apache-airflow/templates-ref.rst
- context: dict[str, Any] = {
- "conf": conf,
- "dag": dag,
- "dag_run": dag_run,
- "data_interval_end": timezone.coerce_datetime(data_interval.end),
- "data_interval_start": timezone.coerce_datetime(data_interval.start),
- "outlet_events": OutletEventAccessors(),
- "ds": ds,
- "ds_nodash": ds_nodash,
- "execution_date": logical_date,
- "expanded_ti_count": expanded_ti_count,
- "inlets": task.inlets,
- "inlet_events": InletEventsAccessors(task.inlets, session=session),
- "logical_date": logical_date,
- "macros": macros,
- "map_index_template": task.map_index_template,
- "next_ds": get_next_ds(),
- "next_ds_nodash": get_next_ds_nodash(),
- "next_execution_date": get_next_execution_date(),
- "outlets": task.outlets,
- "params": validated_params,
- "prev_data_interval_start_success": get_prev_data_interval_start_success(),
- "prev_data_interval_end_success": get_prev_data_interval_end_success(),
- "prev_ds": get_prev_ds(),
- "prev_ds_nodash": get_prev_ds_nodash(),
- "prev_execution_date": get_prev_execution_date(),
- "prev_execution_date_success": task_instance.get_previous_execution_date(
- state=DagRunState.SUCCESS,
- session=session,
- ),
- "prev_start_date_success": get_prev_start_date_success(),
- "prev_end_date_success": get_prev_end_date_success(),
- "run_id": task_instance.run_id,
- "task": task,
- "task_instance": task_instance,
- "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
- "test_mode": task_instance.test_mode,
- "ti": task_instance,
- "tomorrow_ds": get_tomorrow_ds(),
- "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
- "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
- "ts": ts,
- "ts_nodash": ts_nodash,
- "ts_nodash_with_tz": ts_nodash_with_tz,
- "var": {
- "json": VariableAccessor(deserialize_json=True),
- "value": VariableAccessor(deserialize_json=False),
- },
- "conn": ConnectionAccessor(),
- "yesterday_ds": get_yesterday_ds(),
- "yesterday_ds_nodash": get_yesterday_ds_nodash(),
- }
- # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
- # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
- return Context(context) # type: ignore
- def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic):
- """
- Is task instance is eligible for retry.
- :param task_instance: the task instance
- :meta private:
- """
- if task_instance.state == TaskInstanceState.RESTARTING:
- # If a task is cleared when running, it goes into RESTARTING state and is always
- # eligible for retry
- return True
- if not getattr(task_instance, "task", None):
- # Couldn't load the task, don't know number of retries, guess:
- return task_instance.try_number <= task_instance.max_tries
- if TYPE_CHECKING:
- assert task_instance.task
- return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries
- @provide_session
- @internal_api_call
- def _handle_failure(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- error: None | str | BaseException,
- session: Session,
- test_mode: bool | None = None,
- context: Context | None = None,
- force_fail: bool = False,
- fail_stop: bool = False,
- ) -> None:
- """
- Handle Failure for a task instance.
- :param task_instance: the task instance
- :param error: if specified, log the specific exception if thrown
- :param session: SQLAlchemy ORM Session
- :param test_mode: doesn't record success or failure in the DB if True
- :param context: Jinja2 context
- :param force_fail: if True, task does not retry
- :meta private:
- """
- if test_mode is None:
- test_mode = task_instance.test_mode
- task_instance = _coalesce_to_orm_ti(ti=task_instance, session=session)
- failure_context = TaskInstance.fetch_handle_failure_context(
- ti=task_instance, # type: ignore[arg-type]
- error=error,
- test_mode=test_mode,
- context=context,
- force_fail=force_fail,
- session=session,
- fail_stop=fail_stop,
- )
- _log_state(task_instance=task_instance, lead_msg="Immediate failure requested. " if force_fail else "")
- if (
- failure_context["task"]
- and failure_context["email_for_state"](failure_context["task"])
- and failure_context["task"].email
- ):
- try:
- task_instance.email_alert(error, failure_context["task"])
- except Exception:
- log.exception("Failed to send email to: %s", failure_context["task"].email)
- if failure_context["callbacks"] and failure_context["context"]:
- _run_finished_callback(
- callbacks=failure_context["callbacks"],
- context=failure_context["context"],
- )
- if not test_mode:
- TaskInstance.save_to_db(task_instance, session)
- with Trace.start_span_from_taskinstance(ti=task_instance) as span:
- # ---- error info ----
- span.set_attribute("error", "true")
- span.set_attribute("error_msg", str(error))
- span.set_attribute("context", context)
- span.set_attribute("force_fail", force_fail)
- # ---- common info ----
- span.set_attribute("category", "DAG runs")
- span.set_attribute("task_id", task_instance.task_id)
- span.set_attribute("dag_id", task_instance.dag_id)
- span.set_attribute("state", task_instance.state)
- span.set_attribute("start_date", str(task_instance.start_date))
- span.set_attribute("end_date", str(task_instance.end_date))
- span.set_attribute("duration", task_instance.duration)
- span.set_attribute("executor_config", str(task_instance.executor_config))
- span.set_attribute("execution_date", str(task_instance.execution_date))
- span.set_attribute("hostname", task_instance.hostname)
- if isinstance(task_instance, TaskInstance):
- span.set_attribute("log_url", task_instance.log_url)
- span.set_attribute("operator", str(task_instance.operator))
- def _refresh_from_task(
- *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, pool_override: str | None = None
- ) -> None:
- """
- Copy common attributes from the given task.
- :param task_instance: the task instance
- :param task: The task object to copy from
- :param pool_override: Use the pool_override instead of task's pool
- :meta private:
- """
- task_instance.task = task
- task_instance.queue = task.queue
- task_instance.pool = pool_override or task.pool
- task_instance.pool_slots = task.pool_slots
- with contextlib.suppress(Exception):
- # This method is called from the different places, and sometimes the TI is not fully initialized
- task_instance.priority_weight = task_instance.task.weight_rule.get_weight(
- task_instance # type: ignore[arg-type]
- )
- task_instance.run_as_user = task.run_as_user
- # Do not set max_tries to task.retries here because max_tries is a cumulative
- # value that needs to be stored in the db.
- task_instance.executor = task.executor
- task_instance.executor_config = task.executor_config
- task_instance.operator = task.task_type
- task_instance.custom_operator_name = getattr(task, "custom_operator_name", None)
- # Re-apply cluster policy here so that task default do not overload previous data
- task_instance_mutation_hook(task_instance)
- @internal_api_call
- @provide_session
- def _record_task_map_for_downstreams(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- task: Operator,
- dag: DAG,
- value: Any,
- session: Session,
- ) -> None:
- """
- Record the task map for downstream tasks.
- :param task_instance: the task instance
- :param task: The task object
- :param dag: the dag associated with the task
- :param value: The value
- :param session: SQLAlchemy ORM Session
- :meta private:
- """
- from airflow.models.mappedoperator import MappedOperator
- if task.dag.__class__ is AttributeRemoved:
- task.dag = dag # required after deserialization
- if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
- return
- # TODO: We don't push TaskMap for mapped task instances because it's not
- # currently possible for a downstream to depend on one individual mapped
- # task instance. This will change when we implement task mapping inside
- # a mapped task group, and we'll need to further analyze the case.
- if isinstance(task, MappedOperator):
- return
- if value is None:
- raise XComForMappingNotPushed()
- if not _is_mappable_value(value):
- raise UnmappableXComTypePushed(value)
- task_map = TaskMap.from_task_instance_xcom(task_instance, value)
- max_map_length = conf.getint("core", "max_map_length", fallback=1024)
- if task_map.length > max_map_length:
- raise UnmappableXComLengthPushed(value, max_map_length)
- session.merge(task_map)
- def _get_previous_dagrun(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- state: DagRunState | None = None,
- session: Session | None = None,
- ) -> DagRun | None:
- """
- Return the DagRun that ran prior to this task instance's DagRun.
- :param task_instance: the task instance
- :param state: If passed, it only take into account instances of a specific state.
- :param session: SQLAlchemy ORM Session.
- :meta private:
- """
- if TYPE_CHECKING:
- assert task_instance.task
- dag = task_instance.task.dag
- if dag is None:
- return None
- dr = task_instance.get_dagrun(session=session)
- dr.dag = dag
- from airflow.models.dagrun import DagRun # Avoid circular import
- # We always ignore schedule in dagrun lookup when `state` is given
- # or the DAG is never scheduled. For legacy reasons, when
- # `catchup=True`, we use `get_previous_scheduled_dagrun` unless
- # `ignore_schedule` is `True`.
- ignore_schedule = state is not None or not dag.timetable.can_be_scheduled
- if dag.catchup is True and not ignore_schedule:
- last_dagrun = DagRun.get_previous_scheduled_dagrun(dr.id, session=session)
- else:
- last_dagrun = DagRun.get_previous_dagrun(dag_run=dr, session=session, state=state)
- if last_dagrun:
- return last_dagrun
- return None
- def _get_previous_execution_date(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- state: DagRunState | None,
- session: Session,
- ) -> pendulum.DateTime | None:
- """
- Get execution date from property previous_ti_success.
- :param task_instance: the task instance
- :param session: SQLAlchemy ORM Session
- :param state: If passed, it only take into account instances of a specific state.
- :meta private:
- """
- log.debug("previous_execution_date was called")
- prev_ti = task_instance.get_previous_ti(state=state, session=session)
- return pendulum.instance(prev_ti.execution_date) if prev_ti and prev_ti.execution_date else None
- def _get_previous_start_date(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- state: DagRunState | None,
- session: Session,
- ) -> pendulum.DateTime | None:
- """
- Return the start date from property previous_ti_success.
- :param task_instance: the task instance
- :param state: If passed, it only take into account instances of a specific state.
- :param session: SQLAlchemy ORM Session
- """
- log.debug("previous_start_date was called")
- prev_ti = task_instance.get_previous_ti(state=state, session=session)
- # prev_ti may not exist and prev_ti.start_date may be None.
- return pendulum.instance(prev_ti.start_date) if prev_ti and prev_ti.start_date else None
- def _email_alert(
- *, task_instance: TaskInstance | TaskInstancePydantic, exception, task: BaseOperator
- ) -> None:
- """
- Send alert email with exception information.
- :param task_instance: the task instance
- :param exception: the exception
- :param task: task related to the exception
- :meta private:
- """
- subject, html_content, html_content_err = task_instance.get_email_subject_content(exception, task=task)
- if TYPE_CHECKING:
- assert task.email
- try:
- send_email(task.email, subject, html_content)
- except Exception:
- send_email(task.email, subject, html_content_err)
- def _get_email_subject_content(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- exception: BaseException,
- task: BaseOperator | None = None,
- ) -> tuple[str, str, str]:
- """
- Get the email subject content for exceptions.
- :param task_instance: the task instance
- :param exception: the exception sent in the email
- :param task:
- :meta private:
- """
- # For a ti from DB (without ti.task), return the default value
- if task is None:
- task = getattr(task_instance, "task")
- use_default = task is None
- exception_html = str(exception).replace("\n", "<br>")
- default_subject = "Airflow alert: {{ti}}"
- # For reporting purposes, we report based on 1-indexed,
- # not 0-indexed lists (i.e. Try 1 instead of
- # Try 0 for the first attempt).
- default_html_content = (
- "Try {{try_number}} out of {{max_tries + 1}}<br>"
- "Exception:<br>{{exception_html}}<br>"
- 'Log: <a href="{{ti.log_url}}">Link</a><br>'
- "Host: {{ti.hostname}}<br>"
- 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
- )
- default_html_content_err = (
- "Try {{try_number}} out of {{max_tries + 1}}<br>"
- "Exception:<br>Failed attempt to attach error logs<br>"
- 'Log: <a href="{{ti.log_url}}">Link</a><br>'
- "Host: {{ti.hostname}}<br>"
- 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
- )
- additional_context: dict[str, Any] = {
- "exception": exception,
- "exception_html": exception_html,
- "try_number": task_instance.try_number,
- "max_tries": task_instance.max_tries,
- }
- if use_default:
- default_context = {"ti": task_instance, **additional_context}
- jinja_env = jinja2.Environment(
- loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
- )
- subject = jinja_env.from_string(default_subject).render(**default_context)
- html_content = jinja_env.from_string(default_html_content).render(**default_context)
- html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)
- else:
- if TYPE_CHECKING:
- assert task_instance.task
- # Use the DAG's get_template_env() to set force_sandboxed. Don't add
- # the flag to the function on task object -- that function can be
- # overridden, and adding a flag breaks backward compatibility.
- dag = task_instance.task.get_dag()
- if dag:
- jinja_env = dag.get_template_env(force_sandboxed=True)
- else:
- jinja_env = SandboxedEnvironment(cache_size=0)
- jinja_context = task_instance.get_template_context()
- context_merge(jinja_context, additional_context)
- def render(key: str, content: str) -> str:
- if conf.has_option("email", key):
- path = conf.get_mandatory_value("email", key)
- try:
- with open(path) as f:
- content = f.read()
- except FileNotFoundError:
- log.warning("Could not find email template file '%s'. Using defaults...", path)
- except OSError:
- log.exception("Error while using email template %s. Using defaults...", path)
- return render_template_to_string(jinja_env.from_string(content), jinja_context)
- subject = render("subject_template", default_subject)
- html_content = render("html_content_template", default_html_content)
- html_content_err = render("html_content_template", default_html_content_err)
- return subject, html_content, html_content_err
- def _run_finished_callback(
- *,
- callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
- context: Context,
- ) -> None:
- """
- Run callback after task finishes.
- :param callbacks: callbacks to run
- :param context: callbacks context
- :meta private:
- """
- if callbacks:
- callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
- def get_callback_representation(callback: TaskStateChangeCallback) -> Any:
- with contextlib.suppress(AttributeError):
- return callback.__name__
- with contextlib.suppress(AttributeError):
- return callback.__class__.__name__
- return callback
- for idx, callback in enumerate(callbacks):
- callback_repr = get_callback_representation(callback)
- log.info("Executing callback at index %d: %s", idx, callback_repr)
- try:
- callback(context)
- except Exception:
- log.exception("Error in callback at index %d: %s", idx, callback_repr)
- def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg: str = "") -> None:
- """
- Log task state.
- :param task_instance: the task instance
- :param lead_msg: lead message
- :meta private:
- """
- params = [
- lead_msg,
- str(task_instance.state).upper(),
- task_instance.dag_id,
- task_instance.task_id,
- task_instance.run_id,
- ]
- message = "%sMarking task as %s. dag_id=%s, task_id=%s, run_id=%s, "
- if task_instance.map_index >= 0:
- params.append(task_instance.map_index)
- message += "map_index=%d, "
- message += "execution_date=%s, start_date=%s, end_date=%s"
- log.info(
- message,
- *params,
- _date_or_empty(task_instance=task_instance, attr="execution_date"),
- _date_or_empty(task_instance=task_instance, attr="start_date"),
- _date_or_empty(task_instance=task_instance, attr="end_date"),
- stacklevel=2,
- )
- def _date_or_empty(*, task_instance: TaskInstance | TaskInstancePydantic, attr: str) -> str:
- """
- Fetch a date attribute or None of it does not exist.
- :param task_instance: the task instance
- :param attr: the attribute name
- :meta private:
- """
- result: datetime | None = getattr(task_instance, attr, None)
- return result.strftime("%Y%m%dT%H%M%S") if result else ""
- def _get_previous_ti(
- *,
- task_instance: TaskInstance | TaskInstancePydantic,
- session: Session,
- state: DagRunState | None = None,
- ) -> TaskInstance | TaskInstancePydantic | None:
- """
- Get task instance for the task that ran before this task instance.
- :param task_instance: the task instance
- :param state: If passed, it only take into account instances of a specific state.
- :param session: SQLAlchemy ORM Session
- :meta private:
- """
- dagrun = task_instance.get_previous_dagrun(state, session=session)
- if dagrun is None:
- return None
- return dagrun.get_task_instance(task_instance.task_id, session=session)
- @internal_api_call
- @provide_session
- def _update_rtif(ti, rendered_fields, session: Session = NEW_SESSION):
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
- rtif = RenderedTaskInstanceFields(ti=ti, render_templates=False, rendered_fields=rendered_fields)
- RenderedTaskInstanceFields.write(rtif, session=session)
- session.flush()
- RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
- def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session):
- from airflow.models.dagrun import DagRun
- from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
- if isinstance(ti, TaskInstancePydantic):
- orm_ti = DagRun.fetch_task_instance(
- dag_id=ti.dag_id,
- dag_run_id=ti.run_id,
- task_id=ti.task_id,
- map_index=ti.map_index,
- session=session,
- )
- if TYPE_CHECKING:
- assert orm_ti
- ti, pydantic_ti = orm_ti, ti
- _set_ti_attrs(ti, pydantic_ti)
- ti.task = pydantic_ti.task
- return ti
- @internal_api_call
- @provide_session
- def _defer_task(
- ti: TaskInstance | TaskInstancePydantic,
- exception: TaskDeferred | None = None,
- session: Session = NEW_SESSION,
- ) -> TaskInstancePydantic | TaskInstance:
- from airflow.models.trigger import Trigger
- if exception is not None:
- trigger_row = Trigger.from_object(exception.trigger)
- next_method = exception.method_name
- next_kwargs = exception.kwargs
- timeout = exception.timeout
- elif ti.task is not None and ti.task.start_trigger_args is not None:
- context = ti.get_template_context()
- start_trigger_args = ti.task.expand_start_trigger_args(context=context, session=session)
- if start_trigger_args is None:
- raise TaskDeferralError(
- "A none 'None' start_trigger_args has been change to 'None' during expandion"
- )
- trigger_kwargs = start_trigger_args.trigger_kwargs or {}
- next_kwargs = start_trigger_args.next_kwargs
- next_method = start_trigger_args.next_method
- timeout = start_trigger_args.timeout
- trigger_row = Trigger(
- classpath=ti.task.start_trigger_args.trigger_cls,
- kwargs=trigger_kwargs,
- )
- else:
- raise TaskDeferralError("exception and ti.task.start_trigger_args cannot both be None")
- # First, make the trigger entry
- session.add(trigger_row)
- session.flush()
- ti = _coalesce_to_orm_ti(ti=ti, session=session) # ensure orm obj in case it's pydantic
- if TYPE_CHECKING:
- assert ti.task
- # Then, update ourselves so it matches the deferral request
- # Keep an eye on the logic in `check_and_change_state_before_execution()`
- # depending on self.next_method semantics
- ti.state = TaskInstanceState.DEFERRED
- ti.trigger_id = trigger_row.id
- ti.next_method = next_method
- ti.next_kwargs = next_kwargs or {}
- # Calculate timeout too if it was passed
- if timeout is not None:
- ti.trigger_timeout = timezone.utcnow() + timeout
- else:
- ti.trigger_timeout = None
- # If an execution_timeout is set, set the timeout to the minimum of
- # it and the trigger timeout
- execution_timeout = ti.task.execution_timeout
- if execution_timeout:
- if TYPE_CHECKING:
- assert ti.start_date
- if ti.trigger_timeout:
- ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
- else:
- ti.trigger_timeout = ti.start_date + execution_timeout
- if ti.test_mode:
- _add_log(event=ti.state, task_instance=ti, session=session)
- if exception is not None:
- session.merge(ti)
- session.commit()
- return ti
- @internal_api_call
- @provide_session
- def _handle_reschedule(
- ti,
- actual_start_date: datetime,
- reschedule_exception: AirflowRescheduleException,
- test_mode: bool = False,
- session: Session = NEW_SESSION,
- ):
- # Don't record reschedule request in test mode
- if test_mode:
- return
- ti = _coalesce_to_orm_ti(ti=ti, session=session)
- from airflow.models.dagrun import DagRun # Avoid circular import
- ti.refresh_from_db(session)
- if TYPE_CHECKING:
- assert ti.task
- ti.end_date = timezone.utcnow()
- ti.set_duration()
- # Lock DAG run to be sure not to get into a deadlock situation when trying to insert
- # TaskReschedule which apparently also creates lock on corresponding DagRun entity
- with_row_locks(
- session.query(DagRun).filter_by(
- dag_id=ti.dag_id,
- run_id=ti.run_id,
- ),
- session=session,
- ).one()
- # Log reschedule request
- session.add(
- TaskReschedule(
- ti.task_id,
- ti.dag_id,
- ti.run_id,
- ti.try_number,
- actual_start_date,
- ti.end_date,
- reschedule_exception.reschedule_date,
- ti.map_index,
- )
- )
- # set state
- ti.state = TaskInstanceState.UP_FOR_RESCHEDULE
- ti.clear_next_method_args()
- session.merge(ti)
- session.commit()
- return ti
- class TaskInstance(Base, LoggingMixin):
- """
- Task instances store the state of a task instance.
- This table is the authority and single source of truth around what tasks
- have run and the state they are in.
- The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
- dag model deliberately to have more control over transactions.
- Database transactions on this table should insure double triggers and
- any confusion around what task instances are or aren't ready to run
- even while multiple schedulers may be firing task instances.
- A value of -1 in map_index represents any of: a TI without mapped tasks;
- a TI with mapped tasks that has yet to be expanded (state=pending);
- a TI with mapped tasks that expanded to an empty list (state=skipped).
- """
- __tablename__ = "task_instance"
- task_id = Column(StringID(), primary_key=True, nullable=False)
- dag_id = Column(StringID(), primary_key=True, nullable=False)
- run_id = Column(StringID(), primary_key=True, nullable=False)
- map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
- start_date = Column(UtcDateTime)
- end_date = Column(UtcDateTime)
- duration = Column(Float)
- state = Column(String(20))
- try_number = Column(Integer, default=0)
- max_tries = Column(Integer, server_default=text("-1"))
- hostname = Column(String(1000))
- unixname = Column(String(1000))
- job_id = Column(Integer)
- pool = Column(String(256), nullable=False)
- pool_slots = Column(Integer, default=1, nullable=False)
- queue = Column(String(256))
- priority_weight = Column(Integer)
- operator = Column(String(1000))
- custom_operator_name = Column(String(1000))
- queued_dttm = Column(UtcDateTime)
- queued_by_job_id = Column(Integer)
- pid = Column(Integer)
- executor = Column(String(1000))
- executor_config = Column(ExecutorConfigType(pickler=dill))
- updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
- rendered_map_index = Column(String(250))
- external_executor_id = Column(StringID())
- # The trigger to resume on if we are in state DEFERRED
- trigger_id = Column(Integer)
- # Optional timeout datetime for the trigger (past this, we'll fail)
- trigger_timeout = Column(DateTime)
- # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of
- # migration, we are keeping it as DateTime pending a change where expensive
- # migration is inevitable.
- # The method to call next, and any extra arguments to pass to it.
- # Usually used when resuming from DEFERRED.
- next_method = Column(String(1000))
- next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
- _task_display_property_value = Column("task_display_name", String(2000), nullable=True)
- # If adding new fields here then remember to add them to
- # refresh_from_db() or they won't display in the UI correctly
- __table_args__ = (
- Index("ti_dag_state", dag_id, state),
- Index("ti_dag_run", dag_id, run_id),
- Index("ti_state", state),
- Index("ti_state_lkp", dag_id, task_id, run_id, state),
- Index("ti_pool", pool, state, priority_weight),
- Index("ti_job_id", job_id),
- Index("ti_trigger_id", trigger_id),
- PrimaryKeyConstraint("dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey"),
- ForeignKeyConstraint(
- [trigger_id],
- ["trigger.id"],
- name="task_instance_trigger_id_fkey",
- ondelete="CASCADE",
- ),
- ForeignKeyConstraint(
- [dag_id, run_id],
- ["dag_run.dag_id", "dag_run.run_id"],
- name="task_instance_dag_run_fkey",
- ondelete="CASCADE",
- ),
- )
- dag_model: DagModel = relationship(
- "DagModel",
- primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
- foreign_keys=dag_id,
- uselist=False,
- innerjoin=True,
- viewonly=True,
- )
- trigger = relationship("Trigger", uselist=False, back_populates="task_instance")
- triggerer_job = association_proxy("trigger", "triggerer_job")
- dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
- rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
- execution_date = association_proxy("dag_run", "execution_date")
- task_instance_note = relationship(
- "TaskInstanceNote",
- back_populates="task_instance",
- uselist=False,
- cascade="all, delete, delete-orphan",
- )
- note = association_proxy("task_instance_note", "content", creator=_creator_note)
- task: Operator | None = None
- test_mode: bool = False
- is_trigger_log_context: bool = False
- run_as_user: str | None = None
- raw: bool | None = None
- """Indicate to FileTaskHandler that logging context should be set up for trigger logging.
- :meta private:
- """
- _logger_name = "airflow.task"
- def __init__(
- self,
- task: Operator,
- execution_date: datetime | None = None,
- run_id: str | None = None,
- state: str | None = None,
- map_index: int = -1,
- ):
- super().__init__()
- self.dag_id = task.dag_id
- self.task_id = task.task_id
- self.map_index = map_index
- self.refresh_from_task(task)
- if TYPE_CHECKING:
- assert self.task
- # init_on_load will config the log
- self.init_on_load()
- if run_id is None and execution_date is not None:
- from airflow.models.dagrun import DagRun # Avoid circular import
- warnings.warn(
- "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
- RemovedInAirflow3Warning,
- # Stack level is 4 because SQLA adds some wrappers around the constructor
- stacklevel=4,
- )
- # make sure we have a localized execution_date stored in UTC
- if execution_date and not timezone.is_localized(execution_date):
- self.log.warning(
- "execution date %s has no timezone information. Using default from dag or system",
- execution_date,
- )
- if self.task.has_dag():
- if TYPE_CHECKING:
- assert self.task.dag
- execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
- else:
- execution_date = timezone.make_aware(execution_date)
- execution_date = timezone.convert_to_utc(execution_date)
- with create_session() as session:
- run_id = (
- session.query(DagRun.run_id)
- .filter_by(dag_id=self.dag_id, execution_date=execution_date)
- .scalar()
- )
- if run_id is None:
- raise DagRunNotFound(
- f"DagRun for {self.dag_id!r} with date {execution_date} not found"
- ) from None
- self.run_id = run_id
- self.try_number = 0
- self.max_tries = self.task.retries
- self.unixname = getuser()
- if state:
- self.state = state
- self.hostname = ""
- # Is this TaskInstance being currently running within `airflow tasks run --raw`.
- # Not persisted to the database so only valid for the current process
- self.raw = False
- # can be changed when calling 'run'
- self.test_mode = False
- def __hash__(self):
- return hash((self.task_id, self.dag_id, self.run_id, self.map_index))
- @property
- @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
- def _try_number(self):
- """
- Do not use. For semblance of backcompat.
- :meta private:
- """
- return self.try_number
- @_try_number.setter
- @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
- def _try_number(self, val):
- """
- Do not use. For semblance of backcompat.
- :meta private:
- """
- self.try_number = val
- @property
- def stats_tags(self) -> dict[str, str]:
- """Returns task instance tags."""
- return _stats_tags(task_instance=self)
- @staticmethod
- def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
- """
- Insert mapping.
- :meta private:
- """
- priority_weight = task.weight_rule.get_weight(
- TaskInstance(task=task, run_id=run_id, map_index=map_index)
- )
- return {
- "dag_id": task.dag_id,
- "task_id": task.task_id,
- "run_id": run_id,
- "try_number": 0,
- "hostname": "",
- "unixname": getuser(),
- "queue": task.queue,
- "pool": task.pool,
- "pool_slots": task.pool_slots,
- "priority_weight": priority_weight,
- "run_as_user": task.run_as_user,
- "max_tries": task.retries,
- "executor": task.executor,
- "executor_config": task.executor_config,
- "operator": task.task_type,
- "custom_operator_name": getattr(task, "custom_operator_name", None),
- "map_index": map_index,
- "_task_display_property_value": task.task_display_name,
- }
- @reconstructor
- def init_on_load(self) -> None:
- """Initialize the attributes that aren't stored in the DB."""
- self.test_mode = False # can be changed when calling 'run'
- @property
- @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
- def prev_attempted_tries(self) -> int:
- """
- Calculate the total number of attempted tries, defaulting to 0.
- This used to be necessary because try_number did not always tell the truth.
- :meta private:
- """
- return self.try_number
- @property
- def next_try_number(self) -> int:
- # todo (dstandish): deprecate this property; we don't need a property that is just + 1
- return self.try_number + 1
- @property
- def operator_name(self) -> str | None:
- """@property: use a more friendly display name for the operator, if set."""
- return self.custom_operator_name or self.operator
- @hybrid_property
- def task_display_name(self) -> str:
- return self._task_display_property_value or self.task_id
- @staticmethod
- def _command_as_list(
- ti: TaskInstance | TaskInstancePydantic,
- mark_success: bool = False,
- ignore_all_deps: bool = False,
- ignore_task_deps: bool = False,
- ignore_depends_on_past: bool = False,
- wait_for_past_depends_before_skipping: bool = False,
- ignore_ti_state: bool = False,
- local: bool = False,
- pickle_id: int | None = None,
- raw: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- cfg_path: str | None = None,
- ) -> list[str]:
- dag: DAG | DagModel | DagModelPydantic | None
- # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
- if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
- if TYPE_CHECKING:
- assert ti.task
- dag = ti.task.dag
- else:
- dag = ti.dag_model
- if dag is None:
- raise ValueError("DagModel is empty")
- should_pass_filepath = not pickle_id and dag
- path: PurePath | None = None
- if should_pass_filepath:
- if dag.is_subdag:
- if TYPE_CHECKING:
- assert dag.parent_dag is not None
- path = dag.parent_dag.relative_fileloc
- else:
- path = dag.relative_fileloc
- if path:
- if not path.is_absolute():
- path = "DAGS_FOLDER" / path
- return TaskInstance.generate_command(
- ti.dag_id,
- ti.task_id,
- run_id=ti.run_id,
- mark_success=mark_success,
- ignore_all_deps=ignore_all_deps,
- ignore_task_deps=ignore_task_deps,
- ignore_depends_on_past=ignore_depends_on_past,
- wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
- ignore_ti_state=ignore_ti_state,
- local=local,
- pickle_id=pickle_id,
- file_path=path,
- raw=raw,
- job_id=job_id,
- pool=pool,
- cfg_path=cfg_path,
- map_index=ti.map_index,
- )
- def command_as_list(
- self,
- mark_success: bool = False,
- ignore_all_deps: bool = False,
- ignore_task_deps: bool = False,
- ignore_depends_on_past: bool = False,
- wait_for_past_depends_before_skipping: bool = False,
- ignore_ti_state: bool = False,
- local: bool = False,
- pickle_id: int | None = None,
- raw: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- cfg_path: str | None = None,
- ) -> list[str]:
- """
- Return a command that can be executed anywhere where airflow is installed.
- This command is part of the message sent to executors by the orchestrator.
- """
- return TaskInstance._command_as_list(
- ti=self,
- mark_success=mark_success,
- ignore_all_deps=ignore_all_deps,
- ignore_task_deps=ignore_task_deps,
- ignore_depends_on_past=ignore_depends_on_past,
- wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
- ignore_ti_state=ignore_ti_state,
- local=local,
- pickle_id=pickle_id,
- raw=raw,
- job_id=job_id,
- pool=pool,
- cfg_path=cfg_path,
- )
- @staticmethod
- def generate_command(
- dag_id: str,
- task_id: str,
- run_id: str,
- mark_success: bool = False,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- wait_for_past_depends_before_skipping: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- local: bool = False,
- pickle_id: int | None = None,
- file_path: PurePath | str | None = None,
- raw: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- cfg_path: str | None = None,
- map_index: int = -1,
- ) -> list[str]:
- """
- Generate the shell command required to execute this task instance.
- :param dag_id: DAG ID
- :param task_id: Task ID
- :param run_id: The run_id of this task's DagRun
- :param mark_success: Whether to mark the task as successful
- :param ignore_all_deps: Ignore all ignorable dependencies.
- Overrides the other ignore_* parameters.
- :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
- (e.g. for Backfills)
- :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
- :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
- and trigger rule
- :param ignore_ti_state: Ignore the task instance's previous failure/success
- :param local: Whether to run the task locally
- :param pickle_id: If the DAG was serialized to the DB, the ID
- associated with the pickled DAG
- :param file_path: path to the file containing the DAG definition
- :param raw: raw mode (needs more details)
- :param job_id: job ID (needs more details)
- :param pool: the Airflow pool that the task should run in
- :param cfg_path: the Path to the configuration file
- :return: shell command that can be used to run the task instance
- """
- cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
- if mark_success:
- cmd.extend(["--mark-success"])
- if pickle_id:
- cmd.extend(["--pickle", str(pickle_id)])
- if job_id:
- cmd.extend(["--job-id", str(job_id)])
- if ignore_all_deps:
- cmd.extend(["--ignore-all-dependencies"])
- if ignore_task_deps:
- cmd.extend(["--ignore-dependencies"])
- if ignore_depends_on_past:
- cmd.extend(["--depends-on-past", "ignore"])
- elif wait_for_past_depends_before_skipping:
- cmd.extend(["--depends-on-past", "wait"])
- if ignore_ti_state:
- cmd.extend(["--force"])
- if local:
- cmd.extend(["--local"])
- if pool:
- cmd.extend(["--pool", pool])
- if raw:
- cmd.extend(["--raw"])
- if file_path:
- cmd.extend(["--subdir", os.fspath(file_path)])
- if cfg_path:
- cmd.extend(["--cfg-path", cfg_path])
- if map_index != -1:
- cmd.extend(["--map-index", str(map_index)])
- return cmd
- @property
- def log_url(self) -> str:
- """Log URL for TaskInstance."""
- run_id = quote(self.run_id)
- base_date = quote(self.execution_date.strftime("%Y-%m-%dT%H:%M:%S%z"))
- base_url = conf.get_mandatory_value("webserver", "BASE_URL")
- map_index = f"&map_index={self.map_index}" if self.map_index >= 0 else ""
- return (
- f"{base_url}"
- f"/dags"
- f"/{self.dag_id}"
- f"/grid"
- f"?dag_run_id={run_id}"
- f"&task_id={self.task_id}"
- f"{map_index}"
- f"&base_date={base_date}"
- "&tab=logs"
- )
- @property
- def mark_success_url(self) -> str:
- """URL to mark TI success."""
- base_url = conf.get_mandatory_value("webserver", "BASE_URL")
- return (
- f"{base_url}"
- "/confirm"
- f"?task_id={self.task_id}"
- f"&dag_id={self.dag_id}"
- f"&dag_run_id={quote(self.run_id)}"
- "&upstream=false"
- "&downstream=false"
- "&state=success"
- )
- @provide_session
- def current_state(self, session: Session = NEW_SESSION) -> str:
- """
- Get the very latest state from the database.
- If a session is passed, we use and looking up the state becomes part of the session,
- otherwise a new session is used.
- sqlalchemy.inspect is used here to get the primary keys ensuring that if they change
- it will not regress
- :param session: SQLAlchemy ORM Session
- """
- filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
- return session.query(TaskInstance.state).filter(*filters).scalar()
- @provide_session
- def error(self, session: Session = NEW_SESSION) -> None:
- """
- Force the task instance's state to FAILED in the database.
- :param session: SQLAlchemy ORM Session
- """
- self.log.error("Recording the task instance as FAILED")
- self.state = TaskInstanceState.FAILED
- session.merge(self)
- session.commit()
- @classmethod
- @internal_api_call
- @provide_session
- def get_task_instance(
- cls,
- dag_id: str,
- run_id: str,
- task_id: str,
- map_index: int,
- lock_for_update: bool = False,
- session: Session = NEW_SESSION,
- ) -> TaskInstance | TaskInstancePydantic | None:
- query = (
- session.query(TaskInstance)
- .options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it
- .filter_by(
- dag_id=dag_id,
- run_id=run_id,
- task_id=task_id,
- map_index=map_index,
- )
- )
- if lock_for_update:
- for attempt in run_with_db_retries(logger=cls.logger()):
- with attempt:
- return query.with_for_update().one_or_none()
- else:
- return query.one_or_none()
- return None
- @provide_session
- def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
- """
- Refresh the task instance from the database based on the primary key.
- :param session: SQLAlchemy ORM Session
- :param lock_for_update: if True, indicates that the database should
- lock the TaskInstance (issuing a FOR UPDATE clause) until the
- session is committed.
- """
- _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)
- def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
- """
- Copy common attributes from the given task.
- :param task: The task object to copy from
- :param pool_override: Use the pool_override instead of task's pool
- """
- _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)
- @staticmethod
- @internal_api_call
- @provide_session
- def _clear_xcom_data(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION) -> None:
- """
- Clear all XCom data from the database for the task instance.
- If the task is unmapped, all XComs matching this task ID in the same DAG
- run are removed. If the task is mapped, only the one with matching map
- index is removed.
- :param ti: The TI for which we need to clear xcoms.
- :param session: SQLAlchemy ORM Session
- """
- ti.log.debug("Clearing XCom data")
- if ti.map_index < 0:
- map_index: int | None = None
- else:
- map_index = ti.map_index
- XCom.clear(
- dag_id=ti.dag_id,
- task_id=ti.task_id,
- run_id=ti.run_id,
- map_index=map_index,
- session=session,
- )
- @provide_session
- def clear_xcom_data(self, session: Session = NEW_SESSION):
- self._clear_xcom_data(ti=self, session=session)
- @property
- def key(self) -> TaskInstanceKey:
- """Returns a tuple that identifies the task instance uniquely."""
- return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
- @staticmethod
- @internal_api_call
- def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool:
- if not isinstance(ti, TaskInstance):
- ti = session.scalars(
- select(TaskInstance).where(
- TaskInstance.task_id == ti.task_id,
- TaskInstance.dag_id == ti.dag_id,
- TaskInstance.run_id == ti.run_id,
- TaskInstance.map_index == ti.map_index,
- )
- ).one()
- if ti.state == state:
- return False
- current_time = timezone.utcnow()
- ti.log.debug("Setting task state for %s to %s", ti, state)
- ti.state = state
- ti.start_date = ti.start_date or current_time
- if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY:
- ti.end_date = ti.end_date or current_time
- ti.duration = (ti.end_date - ti.start_date).total_seconds()
- session.merge(ti)
- return True
- @provide_session
- def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
- """
- Set TaskInstance state.
- :param state: State to set for the TI
- :param session: SQLAlchemy ORM Session
- :return: Was the state changed
- """
- return self._set_state(ti=self, state=state, session=session)
- @property
- def is_premature(self) -> bool:
- """Returns whether a task is in UP_FOR_RETRY state and its retry interval has elapsed."""
- # is the task still in the retry waiting period?
- return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry()
- @provide_session
- def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
- """
- Check whether the immediate dependents of this task instance have succeeded or have been skipped.
- This is meant to be used by wait_for_downstream.
- This is useful when you do not want to start processing the next
- schedule of a task until the dependents are done. For instance,
- if the task DROPs and recreates a table.
- :param session: SQLAlchemy ORM Session
- """
- task = self.task
- if TYPE_CHECKING:
- assert task
- if not task.downstream_task_ids:
- return True
- ti = session.query(func.count(TaskInstance.task_id)).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id.in_(task.downstream_task_ids),
- TaskInstance.run_id == self.run_id,
- TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
- )
- count = ti[0][0]
- return count == len(task.downstream_task_ids)
- @provide_session
- def get_previous_dagrun(
- self,
- state: DagRunState | None = None,
- session: Session | None = None,
- ) -> DagRun | None:
- """
- Return the DagRun that ran before this task instance's DagRun.
- :param state: If passed, it only take into account instances of a specific state.
- :param session: SQLAlchemy ORM Session.
- """
- return _get_previous_dagrun(task_instance=self, state=state, session=session)
- @provide_session
- def get_previous_ti(
- self,
- state: DagRunState | None = None,
- session: Session = NEW_SESSION,
- ) -> TaskInstance | TaskInstancePydantic | None:
- """
- Return the task instance for the task that ran before this task instance.
- :param session: SQLAlchemy ORM Session
- :param state: If passed, it only take into account instances of a specific state.
- """
- return _get_previous_ti(task_instance=self, state=state, session=session)
- @property
- def previous_ti(self) -> TaskInstance | TaskInstancePydantic | None:
- """
- This attribute is deprecated.
- Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`.
- """
- warnings.warn(
- """
- This attribute is deprecated.
- Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
- """,
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- return self.get_previous_ti()
- @property
- def previous_ti_success(self) -> TaskInstance | TaskInstancePydantic | None:
- """
- This attribute is deprecated.
- Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`.
- """
- warnings.warn(
- """
- This attribute is deprecated.
- Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
- """,
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- return self.get_previous_ti(state=DagRunState.SUCCESS)
- @provide_session
- def get_previous_execution_date(
- self,
- state: DagRunState | None = None,
- session: Session = NEW_SESSION,
- ) -> pendulum.DateTime | None:
- """
- Return the execution date from property previous_ti_success.
- :param state: If passed, it only take into account instances of a specific state.
- :param session: SQLAlchemy ORM Session
- """
- return _get_previous_execution_date(task_instance=self, state=state, session=session)
- @provide_session
- def get_previous_start_date(
- self, state: DagRunState | None = None, session: Session = NEW_SESSION
- ) -> pendulum.DateTime | None:
- """
- Return the start date from property previous_ti_success.
- :param state: If passed, it only take into account instances of a specific state.
- :param session: SQLAlchemy ORM Session
- """
- return _get_previous_start_date(task_instance=self, state=state, session=session)
- @property
- def previous_start_date_success(self) -> pendulum.DateTime | None:
- """
- This attribute is deprecated.
- Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_start_date`.
- """
- warnings.warn(
- """
- This attribute is deprecated.
- Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
- """,
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- return self.get_previous_start_date(state=DagRunState.SUCCESS)
- @provide_session
- def are_dependencies_met(
- self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
- ) -> bool:
- """
- Are all conditions met for this task instance to be run given the context for the dependencies.
- (e.g. a task instance being force run from the UI will ignore some dependencies).
- :param dep_context: The execution context that determines the dependencies that should be evaluated.
- :param session: database session
- :param verbose: whether log details on failed dependencies on info or debug log level
- """
- dep_context = dep_context or DepContext()
- failed = False
- verbose_aware_logger = self.log.info if verbose else self.log.debug
- for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
- failed = True
- verbose_aware_logger(
- "Dependencies not met for %s, dependency '%s' FAILED: %s",
- self,
- dep_status.dep_name,
- dep_status.reason,
- )
- if failed:
- return False
- verbose_aware_logger("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self)
- return True
- @provide_session
- def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
- """Get failed Dependencies."""
- if TYPE_CHECKING:
- assert self.task
- dep_context = dep_context or DepContext()
- for dep in dep_context.deps | self.task.deps:
- for dep_status in dep.get_dep_statuses(self, session, dep_context):
- self.log.debug(
- "%s dependency '%s' PASSED: %s, %s",
- self,
- dep_status.dep_name,
- dep_status.passed,
- dep_status.reason,
- )
- if not dep_status.passed:
- yield dep_status
- def __repr__(self) -> str:
- prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
- if self.map_index != -1:
- prefix += f"map_index={self.map_index} "
- return prefix + f"[{self.state}]>"
- def next_retry_datetime(self):
- """
- Get datetime of the next retry if the task instance fails.
- For exponential backoff, retry_delay is used as base and will be converted to seconds.
- """
- from airflow.models.abstractoperator import MAX_RETRY_DELAY
- delay = self.task.retry_delay
- if self.task.retry_exponential_backoff:
- # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
- # we must round up prior to converting to an int, otherwise a divide by zero error
- # will occur in the modded_hash calculation.
- # this probably gives unexpected results if a task instance has previously been cleared,
- # because try_number can increase without bound
- min_backoff = math.ceil(delay.total_seconds() * (2 ** (self.try_number - 1)))
- # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
- # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
- # the ceiling function unnecessary, but the ceiling function was retained to avoid
- # introducing a breaking change.
- if min_backoff < 1:
- min_backoff = 1
- # deterministic per task instance
- ti_hash = int(
- hashlib.sha1(
- f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode()
- ).hexdigest(),
- 16,
- )
- # between 1 and 1.0 * delay * (2^retry_number)
- modded_hash = min_backoff + ti_hash % min_backoff
- # timedelta has a maximum representable value. The exponentiation
- # here means this value can be exceeded after a certain number
- # of tries (around 50 if the initial delay is 1s, even fewer if
- # the delay is larger). Cap the value here before creating a
- # timedelta object so the operation doesn't fail with "OverflowError".
- delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY)
- delay = timedelta(seconds=delay_backoff_in_seconds)
- if self.task.max_retry_delay:
- delay = min(self.task.max_retry_delay, delay)
- return self.end_date + delay
- def ready_for_retry(self) -> bool:
- """Check on whether the task instance is in the right state and timeframe to be retried."""
- return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
- @staticmethod
- @internal_api_call
- def _get_dagrun(dag_id, run_id, session) -> DagRun:
- from airflow.models.dagrun import DagRun # Avoid circular import
- dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
- return dr
- @provide_session
- def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
- """
- Return the DagRun for this TaskInstance.
- :param session: SQLAlchemy ORM Session
- :return: DagRun
- """
- info = inspect(self)
- if info.attrs.dag_run.loaded_value is not NO_VALUE:
- if getattr(self, "task", None) is not None:
- if TYPE_CHECKING:
- assert self.task
- self.dag_run.dag = self.task.dag
- return self.dag_run
- dr = self._get_dagrun(self.dag_id, self.run_id, session)
- if getattr(self, "task", None) is not None:
- if TYPE_CHECKING:
- assert self.task
- dr.dag = self.task.dag
- # Record it in the instance for next time. This means that `self.execution_date` will work correctly
- set_committed_value(self, "dag_run", dr)
- return dr
- @classmethod
- @provide_session
- def ensure_dag(
- cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
- ) -> DAG:
- """Ensure that task has a dag object associated, might have been removed by serialization."""
- if TYPE_CHECKING:
- assert task_instance.task
- if task_instance.task.dag is None or task_instance.task.dag.__class__ is AttributeRemoved:
- task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag(
- dag_id=task_instance.dag_id, session=session
- )
- if TYPE_CHECKING:
- assert task_instance.task.dag
- return task_instance.task.dag
- @classmethod
- @internal_api_call
- @provide_session
- def _check_and_change_state_before_execution(
- cls,
- task_instance: TaskInstance | TaskInstancePydantic,
- verbose: bool = True,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- wait_for_past_depends_before_skipping: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- mark_success: bool = False,
- test_mode: bool = False,
- hostname: str = "",
- job_id: str | None = None,
- pool: str | None = None,
- external_executor_id: str | None = None,
- session: Session = NEW_SESSION,
- ) -> bool:
- """
- Check dependencies and then sets state to RUNNING if they are met.
- Returns True if and only if state is set to RUNNING, which implies that task should be
- executed, in preparation for _run_raw_task.
- :param verbose: whether to turn on more verbose logging
- :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
- :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
- :param wait_for_past_depends_before_skipping: Wait for past depends before mark the ti as skipped
- :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
- :param ignore_ti_state: Disregards previous task instance state
- :param mark_success: Don't run the task, mark its state as success
- :param test_mode: Doesn't record success or failure in the DB
- :param hostname: The hostname of the worker running the task instance.
- :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
- :param pool: specifies the pool to use to run the task instance
- :param external_executor_id: The identifier of the celery executor
- :param session: SQLAlchemy ORM Session
- :return: whether the state was changed to running or not
- """
- if TYPE_CHECKING:
- assert task_instance.task
- if isinstance(task_instance, TaskInstance):
- ti: TaskInstance = task_instance
- else: # isinstance(task_instance, TaskInstancePydantic)
- filters = (col == getattr(task_instance, col.name) for col in inspect(TaskInstance).primary_key)
- ti = session.query(TaskInstance).filter(*filters).scalar()
- dag = DagBag(read_dags_from_db=True).get_dag(task_instance.dag_id, session=session)
- task_instance.task = dag.task_dict[ti.task_id]
- ti.task = task_instance.task
- task = task_instance.task
- if TYPE_CHECKING:
- assert task
- ti.refresh_from_task(task, pool_override=pool)
- ti.test_mode = test_mode
- ti.refresh_from_db(session=session, lock_for_update=True)
- ti.job_id = job_id
- ti.hostname = hostname
- ti.pid = None
- if not ignore_all_deps and not ignore_ti_state and ti.state == TaskInstanceState.SUCCESS:
- Stats.incr("previously_succeeded", tags=ti.stats_tags)
- if not mark_success:
- # Firstly find non-runnable and non-requeueable tis.
- # Since mark_success is not set, we do nothing.
- non_requeueable_dep_context = DepContext(
- deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
- ignore_all_deps=ignore_all_deps,
- ignore_ti_state=ignore_ti_state,
- ignore_depends_on_past=ignore_depends_on_past,
- wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
- ignore_task_deps=ignore_task_deps,
- description="non-requeueable deps",
- )
- if not ti.are_dependencies_met(
- dep_context=non_requeueable_dep_context, session=session, verbose=True
- ):
- session.commit()
- return False
- # For reporting purposes, we report based on 1-indexed,
- # not 0-indexed lists (i.e. Attempt 1 instead of
- # Attempt 0 for the first attempt).
- # Set the task start date. In case it was re-scheduled use the initial
- # start date that is recorded in task_reschedule table
- # If the task continues after being deferred (next_method is set), use the original start_date
- ti.start_date = ti.start_date if ti.next_method else timezone.utcnow()
- if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
- tr_start_date = session.scalar(
- TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1)
- )
- if tr_start_date:
- ti.start_date = tr_start_date
- # Secondly we find non-runnable but requeueable tis. We reset its state.
- # This is because we might have hit concurrency limits,
- # e.g. because of backfilling.
- dep_context = DepContext(
- deps=REQUEUEABLE_DEPS,
- ignore_all_deps=ignore_all_deps,
- ignore_depends_on_past=ignore_depends_on_past,
- wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
- ignore_task_deps=ignore_task_deps,
- ignore_ti_state=ignore_ti_state,
- description="requeueable deps",
- )
- if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
- ti.state = None
- cls.logger().warning(
- "Rescheduling due to concurrency limits reached "
- "at task runtime. Attempt %s of "
- "%s. State set to NONE.",
- ti.try_number,
- ti.max_tries + 1,
- )
- ti.queued_dttm = timezone.utcnow()
- session.merge(ti)
- session.commit()
- return False
- if ti.next_kwargs is not None:
- cls.logger().info("Resuming after deferral")
- else:
- cls.logger().info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1)
- if not test_mode:
- session.add(Log(TaskInstanceState.RUNNING.value, ti))
- ti.state = TaskInstanceState.RUNNING
- ti.emit_state_change_metric(TaskInstanceState.RUNNING)
- if external_executor_id:
- ti.external_executor_id = external_executor_id
- ti.end_date = None
- if not test_mode:
- session.merge(ti).task = task
- session.commit()
- # Closing all pooled connections to prevent
- # "max number of connections reached"
- settings.engine.dispose() # type: ignore
- if verbose:
- if mark_success:
- cls.logger().info("Marking success for %s on %s", ti.task, ti.execution_date)
- else:
- cls.logger().info("Executing %s on %s", ti.task, ti.execution_date)
- return True
- @provide_session
- def check_and_change_state_before_execution(
- self,
- verbose: bool = True,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- wait_for_past_depends_before_skipping: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- external_executor_id: str | None = None,
- session: Session = NEW_SESSION,
- ) -> bool:
- return TaskInstance._check_and_change_state_before_execution(
- task_instance=self,
- verbose=verbose,
- ignore_all_deps=ignore_all_deps,
- ignore_depends_on_past=ignore_depends_on_past,
- wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
- ignore_task_deps=ignore_task_deps,
- ignore_ti_state=ignore_ti_state,
- mark_success=mark_success,
- test_mode=test_mode,
- hostname=get_hostname(),
- job_id=job_id,
- pool=pool,
- external_executor_id=external_executor_id,
- session=session,
- )
- def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
- """
- Send a time metric representing how much time a given state transition took.
- The previous state and metric name is deduced from the state the task was put in.
- :param new_state: The state that has just been set for this task.
- We do not use `self.state`, because sometimes the state is updated directly in the DB and not in
- the local TaskInstance object.
- Supported states: QUEUED and RUNNING
- """
- if self.end_date:
- # if the task has an end date, it means that this is not its first round.
- # we send the state transition time metric only on the first try, otherwise it gets more complex.
- return
- # switch on state and deduce which metric to send
- if new_state == TaskInstanceState.RUNNING:
- metric_name = "queued_duration"
- if self.queued_dttm is None:
- # this should not really happen except in tests or rare cases,
- # but we don't want to create errors just for a metric, so we just skip it
- self.log.warning(
- "cannot record %s for task %s because previous state change time has not been saved",
- metric_name,
- self.task_id,
- )
- return
- timing = timezone.utcnow() - self.queued_dttm
- elif new_state == TaskInstanceState.QUEUED:
- metric_name = "scheduled_duration"
- if self.start_date is None:
- # This check does not work correctly before fields like `scheduled_dttm` are implemented.
- # TODO: Change the level to WARNING once it's viable.
- # see #30612 #34493 and #34771 for more details
- self.log.debug(
- "cannot record %s for task %s because previous state change time has not been saved",
- metric_name,
- self.task_id,
- )
- return
- timing = timezone.utcnow() - self.start_date
- else:
- raise NotImplementedError("no metric emission setup for state %s", new_state)
- # send metric twice, once (legacy) with tags in the name and once with tags as tags
- Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing)
- Stats.timing(f"task.{metric_name}", timing, tags={"task_id": self.task_id, "dag_id": self.dag_id})
- def clear_next_method_args(self) -> None:
- """Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them."""
- _clear_next_method_args(task_instance=self)
- @provide_session
- @Sentry.enrich_errors
- def _run_raw_task(
- self,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- raise_on_defer: bool = False,
- session: Session = NEW_SESSION,
- ) -> TaskReturnCode | None:
- """
- Run a task, update the state upon completion, and run any appropriate callbacks.
- Immediately runs the task (without checking or changing db state
- before execution) and then sets the appropriate final state after
- completion and runs any post-execute callbacks. Meant to be called
- only after another function changes the state to running.
- :param mark_success: Don't run the task, mark its state as success
- :param test_mode: Doesn't record success or failure in the DB
- :param pool: specifies the pool to use to run the task instance
- :param session: SQLAlchemy ORM Session
- """
- if TYPE_CHECKING:
- assert self.task
- return _run_raw_task(
- ti=self,
- mark_success=mark_success,
- test_mode=test_mode,
- job_id=job_id,
- pool=pool,
- raise_on_defer=raise_on_defer,
- session=session,
- )
- def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None:
- if TYPE_CHECKING:
- assert self.task
- # One task only triggers one dataset event for each dataset with the same extra.
- # This tuple[dataset uri, extra] to sets alias names mapping is used to find whether
- # there're datasets with same uri but different extra that we need to emit more than one dataset events.
- dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
- for obj in self.task.outlets or []:
- self.log.debug("outlet obj %s", obj)
- # Lineage can have other types of objects besides datasets
- if isinstance(obj, Dataset):
- dataset_manager.register_dataset_change(
- task_instance=self,
- dataset=obj,
- extra=events[obj].extra,
- session=session,
- )
- elif isinstance(obj, DatasetAlias):
- for dataset_alias_event in events[obj].dataset_alias_events:
- dataset_alias_name = dataset_alias_event["source_alias_name"]
- dataset_uri = dataset_alias_event["dest_dataset_uri"]
- extra = dataset_alias_event["extra"]
- frozen_extra = frozenset(extra.items())
- dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)
- dataset_objs_cache: dict[str, DatasetModel] = {}
- for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items():
- if uri not in dataset_objs_cache:
- dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
- dataset_objs_cache[uri] = dataset_obj
- else:
- dataset_obj = dataset_objs_cache[uri]
- if not dataset_obj:
- dataset_obj = DatasetModel(uri=uri)
- dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
- self.log.warning("Created a new %r as it did not exist.", dataset_obj)
- dataset_objs_cache[uri] = dataset_obj
- for alias in alias_names:
- alias_obj = session.scalar(
- select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
- )
- dataset_obj.aliases.append(alias_obj)
- extra = {k: v for k, v in extra_items}
- self.log.info(
- 'Creating event for %r through aliases "%s"',
- dataset_obj,
- ", ".join(alias_names),
- )
- dataset_manager.register_dataset_change(
- task_instance=self,
- dataset=dataset_obj,
- extra=extra,
- session=session,
- source_alias_names=alias_names,
- )
- def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
- """Prepare Task for Execution."""
- if TYPE_CHECKING:
- assert self.task
- parent_pid = os.getpid()
- def signal_handler(signum, frame):
- pid = os.getpid()
- # If a task forks during execution (from DAG code) for whatever
- # reason, we want to make sure that we react to the signal only in
- # the process that we've spawned ourselves (referred to here as the
- # parent process).
- if pid != parent_pid:
- os._exit(1)
- return
- self.log.error("Received SIGTERM. Terminating subprocesses.")
- self.log.error("Stacktrace: \n%s", "".join(traceback.format_stack()))
- self.task.on_kill()
- raise AirflowTaskTerminated("Task received SIGTERM signal")
- signal.signal(signal.SIGTERM, signal_handler)
- # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
- if not self.next_method:
- self.clear_xcom_data()
- with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"), Stats.timer(
- "task.duration", tags=self.stats_tags
- ):
- # Set the validated/merged params on the task object.
- self.task.params = context["params"]
- with set_current_context(context):
- dag = self.task.get_dag()
- if dag is not None:
- jinja_env = dag.get_template_env()
- else:
- jinja_env = None
- task_orig = self.render_templates(context=context, jinja_env=jinja_env)
- # The task is never MappedOperator at this point.
- if TYPE_CHECKING:
- assert isinstance(self.task, BaseOperator)
- if not test_mode:
- rendered_fields = get_serialized_template_fields(task=self.task)
- _update_rtif(ti=self, rendered_fields=rendered_fields)
- # Export context to make it available for operators to use.
- airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
- os.environ.update(airflow_context_vars)
- # Log context only for the default execution method, the assumption
- # being that otherwise we're resuming a deferred task (in which
- # case there's no need to log these again).
- if not self.next_method:
- self.log.info(
- "Exporting env vars: %s",
- " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()),
- )
- # Run pre_execute callback
- self.task.pre_execute(context=context)
- # Run on_execute callback
- self._run_execute_callback(context, self.task)
- # Run on_task_instance_running event
- get_listener_manager().hook.on_task_instance_running(
- previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
- )
- def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
- """Render named map index if the DAG author defined map_index_template at the task level."""
- if jinja_env is None or (template := context.get("map_index_template")) is None:
- return None
- rendered_map_index = jinja_env.from_string(template).render(context)
- log.debug("Map index rendered as %s", rendered_map_index)
- return rendered_map_index
- # Execute the task.
- with set_current_context(context):
- try:
- result = self._execute_task(context, task_orig)
- except Exception:
- # If the task failed, swallow rendering error so it doesn't mask the main error.
- with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
- self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
- raise
- else: # If the task succeeded, render normally to let rendering error bubble up.
- self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
- # Run post_execute callback
- self.task.post_execute(context=context, result=result)
- Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
- # Same metric with tagging
- Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
- Stats.incr("ti_successes", tags=self.stats_tags)
- def _execute_task(self, context: Context, task_orig: Operator):
- """
- Execute Task (optionally with a Timeout) and push Xcom results.
- :param context: Jinja2 context
- :param task_orig: origin task
- """
- return _execute_task(self, context, task_orig)
- @provide_session
- def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None:
- """
- Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised.
- :meta: private
- """
- _defer_task(ti=self, exception=exception, session=session)
- def _run_execute_callback(self, context: Context, task: BaseOperator) -> None:
- """Functions that need to be run before a Task is executed."""
- if not (callbacks := task.on_execute_callback):
- return
- for callback in callbacks if isinstance(callbacks, list) else [callbacks]:
- try:
- callback(context)
- except Exception:
- self.log.exception("Failed when executing execute callback")
- @provide_session
- def run(
- self,
- verbose: bool = True,
- ignore_all_deps: bool = False,
- ignore_depends_on_past: bool = False,
- wait_for_past_depends_before_skipping: bool = False,
- ignore_task_deps: bool = False,
- ignore_ti_state: bool = False,
- mark_success: bool = False,
- test_mode: bool = False,
- job_id: str | None = None,
- pool: str | None = None,
- session: Session = NEW_SESSION,
- raise_on_defer: bool = False,
- ) -> None:
- """Run TaskInstance."""
- res = self.check_and_change_state_before_execution(
- verbose=verbose,
- ignore_all_deps=ignore_all_deps,
- ignore_depends_on_past=ignore_depends_on_past,
- wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
- ignore_task_deps=ignore_task_deps,
- ignore_ti_state=ignore_ti_state,
- mark_success=mark_success,
- test_mode=test_mode,
- job_id=job_id,
- pool=pool,
- session=session,
- )
- if not res:
- return
- self._run_raw_task(
- mark_success=mark_success,
- test_mode=test_mode,
- job_id=job_id,
- pool=pool,
- session=session,
- raise_on_defer=raise_on_defer,
- )
- def dry_run(self) -> None:
- """Only Renders Templates for the TI."""
- if TYPE_CHECKING:
- assert self.task
- self.task = self.task.prepare_for_execution()
- self.render_templates()
- if TYPE_CHECKING:
- assert isinstance(self.task, BaseOperator)
- self.task.dry_run()
- @provide_session
- def _handle_reschedule(
- self,
- actual_start_date: datetime,
- reschedule_exception: AirflowRescheduleException,
- test_mode: bool = False,
- session: Session = NEW_SESSION,
- ):
- _handle_reschedule(
- ti=self,
- actual_start_date=actual_start_date,
- reschedule_exception=reschedule_exception,
- test_mode=test_mode,
- session=session,
- )
- @staticmethod
- def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
- """
- Truncate the traceback of an exception to the first frame called from within a given function.
- :param error: exception to get traceback from
- :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute
- :meta private:
- """
- tb = error.__traceback__
- code = truncate_to.__func__.__code__ # type: ignore[attr-defined]
- while tb is not None:
- if tb.tb_frame.f_code is code:
- return tb.tb_next
- tb = tb.tb_next
- return tb or error.__traceback__
- @classmethod
- def fetch_handle_failure_context(
- cls,
- ti: TaskInstance,
- error: None | str | BaseException,
- test_mode: bool | None = None,
- context: Context | None = None,
- force_fail: bool = False,
- *,
- session: Session,
- fail_stop: bool = False,
- ):
- """
- Handle Failure for the TaskInstance.
- :param fail_stop: if true, stop remaining tasks in dag
- """
- if error:
- if isinstance(error, BaseException):
- tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task)
- cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb))
- else:
- cls.logger().error("%s", error)
- if not test_mode:
- ti.refresh_from_db(session)
- ti.end_date = timezone.utcnow()
- ti.set_duration()
- Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags)
- # Same metric with tagging
- Stats.incr("operator_failures", tags={**ti.stats_tags, "operator": ti.operator})
- Stats.incr("ti_failures", tags=ti.stats_tags)
- if not test_mode:
- session.add(Log(TaskInstanceState.FAILED.value, ti))
- # Log failure duration
- session.add(TaskFail(ti=ti))
- ti.clear_next_method_args()
- # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task.
- if context is None and getattr(ti, "task", None):
- context = ti.get_template_context(session)
- if context is not None:
- context["exception"] = error
- # Set state correctly and figure out how to log it and decide whether
- # to email
- # Note, callback invocation needs to be handled by caller of
- # _run_raw_task to avoid race conditions which could lead to duplicate
- # invocations or miss invocation.
- # Since this function is called only when the TaskInstance state is running,
- # try_number contains the current try_number (not the next). We
- # only mark task instance as FAILED if the next task instance
- # try_number exceeds the max_tries ... or if force_fail is truthy
- task: BaseOperator | None = None
- try:
- if getattr(ti, "task", None) and context:
- if TYPE_CHECKING:
- assert ti.task
- task = ti.task.unmap((context, session))
- except Exception:
- cls.logger().error("Unable to unmap task to determine if we need to send an alert email")
- if force_fail or not ti.is_eligible_to_retry():
- ti.state = TaskInstanceState.FAILED
- email_for_state = operator.attrgetter("email_on_failure")
- callbacks = task.on_failure_callback if task else None
- if task and fail_stop:
- _stop_remaining_tasks(task_instance=ti, session=session)
- else:
- if ti.state == TaskInstanceState.RUNNING:
- # If the task instance is in the running state, it means it raised an exception and
- # about to retry so we record the task instance history. For other states, the task
- # instance was cleared and already recorded in the task instance history.
- from airflow.models.taskinstancehistory import TaskInstanceHistory
- TaskInstanceHistory.record_ti(ti, session=session)
- ti.state = State.UP_FOR_RETRY
- email_for_state = operator.attrgetter("email_on_retry")
- callbacks = task.on_retry_callback if task else None
- get_listener_manager().hook.on_task_instance_failed(
- previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
- )
- return {
- "ti": ti,
- "email_for_state": email_for_state,
- "task": task,
- "callbacks": callbacks,
- "context": context,
- }
- @staticmethod
- @internal_api_call
- @provide_session
- def save_to_db(
- ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION, refresh_dag: bool = True
- ):
- if refresh_dag and isinstance(ti, TaskInstance):
- ti.get_dagrun().refresh_from_db()
- ti = _coalesce_to_orm_ti(ti=ti, session=session)
- ti.updated_at = timezone.utcnow()
- session.merge(ti)
- session.flush()
- session.commit()
- @provide_session
- def handle_failure(
- self,
- error: None | str | BaseException,
- test_mode: bool | None = None,
- context: Context | None = None,
- force_fail: bool = False,
- session: Session = NEW_SESSION,
- ) -> None:
- """
- Handle Failure for a task instance.
- :param error: if specified, log the specific exception if thrown
- :param session: SQLAlchemy ORM Session
- :param test_mode: doesn't record success or failure in the DB if True
- :param context: Jinja2 context
- :param force_fail: if True, task does not retry
- """
- if TYPE_CHECKING:
- assert self.task
- assert self.task.dag
- try:
- fail_stop = self.task.dag.fail_stop
- except Exception:
- fail_stop = False
- _handle_failure(
- task_instance=self,
- error=error,
- session=session,
- test_mode=test_mode,
- context=context,
- force_fail=force_fail,
- fail_stop=fail_stop,
- )
- def is_eligible_to_retry(self):
- """Is task instance is eligible for retry."""
- return _is_eligible_to_retry(task_instance=self)
- def get_template_context(
- self,
- session: Session | None = None,
- ignore_param_exceptions: bool = True,
- ) -> Context:
- """
- Return TI Context.
- :param session: SQLAlchemy ORM Session
- :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
- """
- if TYPE_CHECKING:
- assert self.task
- assert self.task.dag
- return _get_template_context(
- task_instance=self,
- dag=self.task.dag,
- session=session,
- ignore_param_exceptions=ignore_param_exceptions,
- )
- @provide_session
- def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
- """
- Update task with rendered template fields for presentation in UI.
- If task has already run, will fetch from DB; otherwise will render.
- """
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
- if TYPE_CHECKING:
- assert self.task
- rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
- if rendered_task_instance_fields:
- self.task = self.task.unmap(None)
- for field_name, rendered_value in rendered_task_instance_fields.items():
- setattr(self.task, field_name, rendered_value)
- return
- try:
- # If we get here, either the task hasn't run or the RTIF record was purged.
- from airflow.utils.log.secrets_masker import redact
- self.render_templates()
- for field_name in self.task.template_fields:
- rendered_value = getattr(self.task, field_name)
- setattr(self.task, field_name, redact(rendered_value, field_name))
- except (TemplateAssertionError, UndefinedError) as e:
- raise AirflowException(
- "Webserver does not have access to User-defined Macros or Filters "
- "when Dag Serialization is enabled. Hence for the task that have not yet "
- "started running, please use 'airflow tasks render' for debugging the "
- "rendering of template_fields."
- ) from e
- def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun):
- """Overwrite Task Params with DagRun.conf."""
- if dag_run and dag_run.conf:
- self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
- params.update(dag_run.conf)
- def render_templates(
- self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
- ) -> Operator:
- """
- Render templates in the operator fields.
- If the task was originally mapped, this may replace ``self.task`` with
- the unmapped, fully rendered BaseOperator. The original ``self.task``
- before replacement is returned.
- """
- from airflow.models.mappedoperator import MappedOperator
- if not context:
- context = self.get_template_context()
- original_task = self.task
- ti = context["ti"]
- if TYPE_CHECKING:
- assert original_task
- assert self.task
- assert ti.task
- if ti.task.dag.__class__ is AttributeRemoved:
- ti.task.dag = self.task.dag
- # If self.task is mapped, this call replaces self.task to point to the
- # unmapped BaseOperator created by this function! This is because the
- # MappedOperator is useless for template rendering, and we need to be
- # able to access the unmapped task instead.
- original_task.render_template_fields(context, jinja_env)
- if isinstance(self.task, MappedOperator):
- self.task = context["ti"].task
- return original_task
- def render_k8s_pod_yaml(self) -> dict | None:
- """Render the k8s pod yaml."""
- try:
- from airflow.providers.cncf.kubernetes.template_rendering import (
- render_k8s_pod_yaml as render_k8s_pod_yaml_from_provider,
- )
- except ImportError:
- raise RuntimeError(
- "You need to have the `cncf.kubernetes` provider installed to use this feature. "
- "Also rather than calling it directly you should import "
- "render_k8s_pod_yaml from airflow.providers.cncf.kubernetes.template_rendering "
- "and call it with TaskInstance as the first argument."
- )
- warnings.warn(
- "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed"
- "in Airflow 3. Rather than calling it directly you should import "
- "`render_k8s_pod_yaml` from `airflow.providers.cncf.kubernetes.template_rendering` "
- "and call it with `TaskInstance` as the first argument.",
- DeprecationWarning,
- stacklevel=2,
- )
- return render_k8s_pod_yaml_from_provider(self)
- @provide_session
- def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
- """Render the k8s pod yaml."""
- try:
- from airflow.providers.cncf.kubernetes.template_rendering import (
- get_rendered_k8s_spec as get_rendered_k8s_spec_from_provider,
- )
- except ImportError:
- raise RuntimeError(
- "You need to have the `cncf.kubernetes` provider installed to use this feature. "
- "Also rather than calling it directly you should import "
- "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` "
- "and call it with `TaskInstance` as the first argument."
- )
- warnings.warn(
- "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed"
- "in Airflow 3. Rather than calling it directly you should import "
- "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` "
- "and call it with `TaskInstance` as the first argument.",
- DeprecationWarning,
- stacklevel=2,
- )
- return get_rendered_k8s_spec_from_provider(self, session=session)
- def get_email_subject_content(
- self, exception: BaseException, task: BaseOperator | None = None
- ) -> tuple[str, str, str]:
- """
- Get the email subject content for exceptions.
- :param exception: the exception sent in the email
- :param task:
- """
- return _get_email_subject_content(task_instance=self, exception=exception, task=task)
- def email_alert(self, exception, task: BaseOperator) -> None:
- """
- Send alert email with exception information.
- :param exception: the exception
- :param task: task related to the exception
- """
- _email_alert(task_instance=self, exception=exception, task=task)
- def set_duration(self) -> None:
- """Set task instance duration."""
- _set_duration(task_instance=self)
- @provide_session
- def xcom_push(
- self,
- key: str,
- value: Any,
- execution_date: datetime | None = None,
- session: Session = NEW_SESSION,
- ) -> None:
- """
- Make an XCom available for tasks to pull.
- :param key: Key to store the value under.
- :param value: Value to store. What types are possible depends on whether
- ``enable_xcom_pickling`` is true or not. If so, this can be any
- picklable object; only be JSON-serializable may be used otherwise.
- :param execution_date: Deprecated parameter that has no effect.
- """
- if execution_date is not None:
- self_execution_date = self.get_dagrun(session).execution_date
- if execution_date < self_execution_date:
- raise ValueError(
- f"execution_date can not be in the past (current execution_date is "
- f"{self_execution_date}; received {execution_date})"
- )
- elif execution_date is not None:
- message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
- warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
- XCom.set(
- key=key,
- value=value,
- task_id=self.task_id,
- dag_id=self.dag_id,
- run_id=self.run_id,
- map_index=self.map_index,
- session=session,
- )
- @provide_session
- def xcom_pull(
- self,
- task_ids: str | Iterable[str] | None = None,
- dag_id: str | None = None,
- key: str = XCOM_RETURN_KEY,
- include_prior_dates: bool = False,
- session: Session = NEW_SESSION,
- *,
- map_indexes: int | Iterable[int] | None = None,
- default: Any = None,
- ) -> Any:
- """
- Pull XComs that optionally meet certain criteria.
- :param key: A key for the XCom. If provided, only XComs with matching
- keys will be returned. The default key is ``'return_value'``, also
- available as constant ``XCOM_RETURN_KEY``. This key is automatically
- given to XComs returned by tasks (as opposed to being pushed
- manually). To remove the filter, pass *None*.
- :param task_ids: Only XComs from tasks with matching ids will be
- pulled. Pass *None* to remove the filter.
- :param dag_id: If provided, only pulls XComs from this DAG. If *None*
- (default), the DAG of the calling task is used.
- :param map_indexes: If provided, only pull XComs with matching indexes.
- If *None* (default), this is inferred from the task(s) being pulled
- (see below for details).
- :param include_prior_dates: If False, only XComs from the current
- execution_date are returned. If *True*, XComs from previous dates
- are returned as well.
- When pulling one single task (``task_id`` is *None* or a str) without
- specifying ``map_indexes``, the return value is inferred from whether
- the specified task is mapped. If not, value from the one single task
- instance is returned. If the task to pull is mapped, an iterator (not a
- list) yielding XComs from mapped task instances is returned. In either
- case, ``default`` (*None* if not specified) is returned if no matching
- XComs are found.
- When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
- a non-str iterable), a list of matching XComs is returned. Elements in
- the list is ordered by item ordering in ``task_id`` and ``map_index``.
- """
- return _xcom_pull(
- ti=self,
- task_ids=task_ids,
- dag_id=dag_id,
- key=key,
- include_prior_dates=include_prior_dates,
- session=session,
- map_indexes=map_indexes,
- default=default,
- )
- @provide_session
- def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
- """Return Number of running TIs from the DB."""
- # .count() is inefficient
- num_running_task_instances_query = session.query(func.count()).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.state == TaskInstanceState.RUNNING,
- )
- if same_dagrun:
- num_running_task_instances_query = num_running_task_instances_query.filter(
- TaskInstance.run_id == self.run_id
- )
- return num_running_task_instances_query.scalar()
- def init_run_context(self, raw: bool = False) -> None:
- """Set the log context."""
- self.raw = raw
- self._set_context(self)
- @staticmethod
- def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
- """Return SQLAlchemy filter to query selected task instances."""
- # DictKeys type, (what we often pass here from the scheduler) is not directly indexable :(
- # Or it might be a generator, but we need to be able to iterate over it more than once
- tis = list(tis)
- if not tis:
- return None
- first = tis[0]
- dag_id = first.dag_id
- run_id = first.run_id
- map_index = first.map_index
- first_task_id = first.task_id
- # pre-compute the set of dag_id, run_id, map_indices and task_ids
- dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
- for t in tis:
- dag_ids.add(t.dag_id)
- run_ids.add(t.run_id)
- map_indices.add(t.map_index)
- task_ids.add(t.task_id)
- # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
- # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
- if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
- return and_(
- TaskInstance.dag_id == dag_id,
- TaskInstance.run_id == run_id,
- TaskInstance.map_index == map_index,
- TaskInstance.task_id.in_(task_ids),
- )
- if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
- return and_(
- TaskInstance.dag_id == dag_id,
- TaskInstance.run_id.in_(run_ids),
- TaskInstance.map_index == map_index,
- TaskInstance.task_id == first_task_id,
- )
- if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
- return and_(
- TaskInstance.dag_id == dag_id,
- TaskInstance.run_id == run_id,
- TaskInstance.map_index.in_(map_indices),
- TaskInstance.task_id == first_task_id,
- )
- filter_condition = []
- # create 2 nested groups, both primarily grouped by dag_id and run_id,
- # and in the nested group 1 grouped by task_id the other by map_index.
- task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
- map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
- for t in tis:
- task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
- map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
- # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
- # if its not, this is still a significant optimization over querying for every single tuple key
- for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids):
- # we compare the group size between task_id and map_index and use the smaller group
- dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
- dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
- if len(dag_task_id_groups) <= len(dag_map_index_groups):
- for cur_task_id, cur_map_indices in dag_task_id_groups.items():
- filter_condition.append(
- and_(
- TaskInstance.dag_id == cur_dag_id,
- TaskInstance.run_id == cur_run_id,
- TaskInstance.task_id == cur_task_id,
- TaskInstance.map_index.in_(cur_map_indices),
- )
- )
- else:
- for cur_map_index, cur_task_ids in dag_map_index_groups.items():
- filter_condition.append(
- and_(
- TaskInstance.dag_id == cur_dag_id,
- TaskInstance.run_id == cur_run_id,
- TaskInstance.task_id.in_(cur_task_ids),
- TaskInstance.map_index == cur_map_index,
- )
- )
- return or_(*filter_condition)
- @classmethod
- def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
- """
- Build an SQLAlchemy filter for a list of task_ids or tuples of (task_id,map_index).
- :meta private:
- """
- # Compute a filter for TI.task_id and TI.map_index based on input values
- # For each item, it will either be a task_id, or (task_id, map_index)
- task_id_only = [v for v in vals if isinstance(v, str)]
- with_map_index = [v for v in vals if not isinstance(v, str)]
- filters: list[ColumnOperators] = []
- if task_id_only:
- filters.append(cls.task_id.in_(task_id_only))
- if with_map_index:
- filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
- if not filters:
- return false()
- if len(filters) == 1:
- return filters[0]
- return or_(*filters)
- @classmethod
- @provide_session
- def _schedule_downstream_tasks(
- cls,
- ti: TaskInstance | TaskInstancePydantic,
- session: Session = NEW_SESSION,
- max_tis_per_query: int | None = None,
- ):
- from sqlalchemy.exc import OperationalError
- from airflow.models.dagrun import DagRun
- try:
- # Re-select the row with a lock
- dag_run = with_row_locks(
- session.query(DagRun).filter_by(
- dag_id=ti.dag_id,
- run_id=ti.run_id,
- ),
- session=session,
- skip_locked=True,
- ).one_or_none()
- if not dag_run:
- cls.logger().debug("Skip locked rows, rollback")
- session.rollback()
- return
- task = ti.task
- if TYPE_CHECKING:
- assert task
- assert task.dag
- # Previously, this section used task.dag.partial_subset to retrieve a partial DAG.
- # However, this approach is unsafe as it can result in incomplete or incorrect task execution,
- # leading to potential bad cases. As a result, the operation has been removed.
- # For more details, refer to the discussion in PR #[https://github.com/apache/airflow/pull/42582].
- dag_run.dag = task.dag
- info = dag_run.task_instance_scheduling_decisions(session)
- skippable_task_ids = {
- task_id for task_id in task.dag.task_ids if task_id not in task.downstream_task_ids
- }
- schedulable_tis = [
- ti
- for ti in info.schedulable_tis
- if ti.task_id not in skippable_task_ids
- and not (
- ti.task.inherits_from_empty_operator
- and not ti.task.on_execute_callback
- and not ti.task.on_success_callback
- and not ti.task.outlets
- )
- ]
- for schedulable_ti in schedulable_tis:
- if getattr(schedulable_ti, "task", None) is None:
- schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
- num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
- cls.logger().info("%d downstream tasks scheduled from follow-on schedule check", num)
- session.flush()
- except OperationalError as e:
- # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
- cls.logger().warning(
- "Skipping mini scheduling run due to exception: %s",
- e.statement,
- exc_info=True,
- )
- session.rollback()
- @provide_session
- def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
- """
- Schedule downstream tasks of this task instance.
- :meta: private
- """
- try:
- return TaskInstance._schedule_downstream_tasks(
- ti=self, session=session, max_tis_per_query=max_tis_per_query
- )
- except Exception:
- self.log.exception(
- "Error scheduling downstream tasks. Skipping it as this is entirely optional optimisation. "
- "There might be various reasons for it, please take a look at the stack trace to figure "
- "out if the root cause can be diagnosed and fixed. See the issue "
- "https://github.com/apache/airflow/issues/39717 for details and an example problem. If you "
- "would like to get help in solving root cause, open discussion with all details with your "
- "managed service support or in Airflow repository."
- )
- def get_relevant_upstream_map_indexes(
- self,
- upstream: Operator,
- ti_count: int | None,
- *,
- session: Session,
- ) -> int | range | None:
- """
- Infer the map indexes of an upstream "relevant" to this ti.
- The bulk of the logic mainly exists to solve the problem described by
- the following example, where 'val' must resolve to different values,
- depending on where the reference is being used::
- @task
- def this_task(v): # This is self.task.
- return v * 2
- @task_group
- def tg1(inp):
- val = upstream(inp) # This is the upstream task.
- this_task(val) # When inp is 1, val here should resolve to 2.
- return val
- # This val is the same object returned by tg1.
- val = tg1.expand(inp=[1, 2, 3])
- @task_group
- def tg2(inp):
- another_task(inp, val) # val here should resolve to [2, 4, 6].
- tg2.expand(inp=["a", "b"])
- The surrounding mapped task groups of ``upstream`` and ``self.task`` are
- inspected to find a common "ancestor". If such an ancestor is found,
- we need to return specific map indexes to pull a partial value from
- upstream XCom.
- :param upstream: The referenced upstream task.
- :param ti_count: The total count of task instance this task was expanded
- by the scheduler, i.e. ``expanded_ti_count`` in the template context.
- :return: Specific map index or map indexes to pull, or ``None`` if we
- want to "whole" return value (i.e. no mapped task groups involved).
- """
- if TYPE_CHECKING:
- assert self.task
- # This value should never be None since we already know the current task
- # is in a mapped task group, and should have been expanded, despite that,
- # we need to check that it is not None to satisfy Mypy.
- # But this value can be 0 when we expand an empty list, for that it is
- # necessary to check that ti_count is not 0 to avoid dividing by 0.
- if not ti_count:
- return None
- # Find the innermost common mapped task group between the current task
- # If the current task and the referenced task does not have a common
- # mapped task group, the two are in different task mapping contexts
- # (like another_task above), and we should use the "whole" value.
- common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
- if common_ancestor is None:
- return None
- # At this point we know the two tasks share a mapped task group, and we
- # should use a "partial" value. Let's break down the mapped ti count
- # between the ancestor and further expansion happened inside it.
- ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session)
- ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
- # If the task is NOT further expanded inside the common ancestor, we
- # only want to reference one single ti. We must walk the actual DAG,
- # and "ti_count == ancestor_ti_count" does not work, since the further
- # expansion may be of length 1.
- if not _is_further_mapped_inside(upstream, common_ancestor):
- return ancestor_map_index
- # Otherwise we need a partial aggregation for values from selected task
- # instances in the ancestor's expansion context.
- further_count = ti_count // ancestor_ti_count
- map_index_start = ancestor_map_index * further_count
- return range(map_index_start, map_index_start + further_count)
- def clear_db_references(self, session: Session):
- """
- Clear db tables that have a reference to this instance.
- :param session: ORM Session
- :meta private:
- """
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
- tables: list[type[TaskInstanceDependencies]] = [
- TaskFail,
- TaskInstanceNote,
- TaskReschedule,
- XCom,
- RenderedTaskInstanceFields,
- TaskMap,
- ]
- for table in tables:
- session.execute(
- delete(table).where(
- table.dag_id == self.dag_id,
- table.task_id == self.task_id,
- table.run_id == self.run_id,
- table.map_index == self.map_index,
- )
- )
- def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
- """Given two operators, find their innermost common mapped task group."""
- if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
- return None
- parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
- common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
- return next(common_groups, None)
- def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
- """Whether given operator is *further* mapped inside a task group."""
- from airflow.models.mappedoperator import MappedOperator
- if isinstance(operator, MappedOperator):
- return True
- task_group = operator.task_group
- while task_group is not None and task_group.group_id != container.group_id:
- if isinstance(task_group, MappedTaskGroup):
- return True
- task_group = task_group.parent_group
- return False
- # State of the task instance.
- # Stores string version of the task state.
- TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]
- class SimpleTaskInstance:
- """
- Simplified Task Instance.
- Used to send data between processes via Queues.
- """
- def __init__(
- self,
- dag_id: str,
- task_id: str,
- run_id: str,
- start_date: datetime | None,
- end_date: datetime | None,
- try_number: int,
- map_index: int,
- state: str,
- executor: str | None,
- executor_config: Any,
- pool: str,
- queue: str,
- key: TaskInstanceKey,
- run_as_user: str | None = None,
- priority_weight: int | None = None,
- ):
- self.dag_id = dag_id
- self.task_id = task_id
- self.run_id = run_id
- self.map_index = map_index
- self.start_date = start_date
- self.end_date = end_date
- self.try_number = try_number
- self.state = state
- self.executor = executor
- self.executor_config = executor_config
- self.run_as_user = run_as_user
- self.pool = pool
- self.priority_weight = priority_weight
- self.queue = queue
- self.key = key
- def __repr__(self) -> str:
- attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
- return f"SimpleTaskInstance({attrs})"
- def __eq__(self, other) -> bool:
- if isinstance(other, self.__class__):
- return self.__dict__ == other.__dict__
- return NotImplemented
- def as_dict(self):
- warnings.warn(
- "This method is deprecated. Use BaseSerialization.serialize.",
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- new_dict = dict(self.__dict__)
- for key in new_dict:
- if key in ["start_date", "end_date"]:
- val = new_dict[key]
- if not val or isinstance(val, str):
- continue
- new_dict.update({key: val.isoformat()})
- return new_dict
- @classmethod
- def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
- return cls(
- dag_id=ti.dag_id,
- task_id=ti.task_id,
- run_id=ti.run_id,
- map_index=ti.map_index,
- start_date=ti.start_date,
- end_date=ti.end_date,
- try_number=ti.try_number,
- state=ti.state,
- executor=ti.executor,
- executor_config=ti.executor_config,
- pool=ti.pool,
- queue=ti.queue,
- key=ti.key,
- run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
- priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
- )
- @classmethod
- def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
- warnings.warn(
- "This method is deprecated. Use BaseSerialization.deserialize.",
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- ti_key = TaskInstanceKey(*obj_dict.pop("key"))
- start_date = None
- end_date = None
- start_date_str: str | None = obj_dict.pop("start_date")
- end_date_str: str | None = obj_dict.pop("end_date")
- if start_date_str:
- start_date = timezone.parse(start_date_str)
- if end_date_str:
- end_date = timezone.parse(end_date_str)
- return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
- class TaskInstanceNote(TaskInstanceDependencies):
- """For storage of arbitrary notes concerning the task instance."""
- __tablename__ = "task_instance_note"
- user_id = Column(Integer, ForeignKey("ab_user.id", name="task_instance_note_user_fkey"), nullable=True)
- task_id = Column(StringID(), primary_key=True, nullable=False)
- dag_id = Column(StringID(), primary_key=True, nullable=False)
- run_id = Column(StringID(), primary_key=True, nullable=False)
- map_index = Column(Integer, primary_key=True, nullable=False)
- content = Column(String(1000).with_variant(Text(1000), "mysql"))
- created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
- updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
- task_instance = relationship("TaskInstance", back_populates="task_instance_note")
- __table_args__ = (
- PrimaryKeyConstraint("task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey"),
- ForeignKeyConstraint(
- (dag_id, task_id, run_id, map_index),
- [
- "task_instance.dag_id",
- "task_instance.task_id",
- "task_instance.run_id",
- "task_instance.map_index",
- ],
- name="task_instance_note_ti_fkey",
- ondelete="CASCADE",
- ),
- )
- def __init__(self, content, user_id=None):
- self.content = content
- self.user_id = user_id
- def __repr__(self):
- prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
- if self.map_index != -1:
- prefix += f" map_index={self.map_index}"
- return prefix + ">"
- STATICA_HACK = True
- globals()["kcah_acitats"[::-1].upper()] = False
- if STATICA_HACK: # pragma: no cover
- from airflow.jobs.job import Job
- TaskInstance.queued_by_job = relationship(Job)
|