# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Serialized DAG and BaseOperator."""

from __future__ import annotations

import collections.abc
import datetime
import enum
import inspect
import logging
import warnings
import weakref
from inspect import signature
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union, cast

import attrs
import lazy_object_proxy
from dateutil import relativedelta
from pendulum.tz.timezone import FixedTimezone, Timezone

from airflow import macros
from airflow.callbacks.callback_requests import DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import (
    BaseDataset,
    Dataset,
    DatasetAlias,
    DatasetAll,
    DatasetAny,
    _DatasetAliasCondition,
)
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred
from airflow.jobs.job import Job
from airflow.models import Trigger
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel, create_timetable
from airflow.models.dagrun import DagRun
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.tasklog import LogTemplate
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import load_dag_schema
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.dataset import DatasetPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
from airflow.serialization.pydantic.trigger import TriggerPydantic
from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
from airflow.task.priority_strategy import (
    PriorityWeightStrategy,
    airflow_priority_weight_strategies,
    airflow_priority_weight_strategies_classes,
)
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import (
    ConnectionAccessor,
    Context,
    OutletEventAccessor,
    OutletEventAccessors,
    VariableAccessor,
)
from airflow.utils.db import LazySelectSequence
from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
from airflow.utils.timezone import from_timestamp, parse_timezone
from airflow.utils.types import NOTSET, ArgNotSet, AttributeRemoved

if TYPE_CHECKING:
    from inspect import Parameter

    from airflow.models.baseoperatorlink import BaseOperatorLink
    from airflow.models.expandinput import ExpandInput
    from airflow.models.operator import Operator
    from airflow.models.taskmixin import DAGNode
    from airflow.serialization.json_schema import Validator
    from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
    from airflow.timetables.base import Timetable
    from airflow.utils.pydantic import BaseModel

    HAS_KUBERNETES: bool
    try:
        from kubernetes.client import models as k8s  # noqa: TCH004

        from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator  # noqa: TCH004
    except ImportError:
        pass

log = logging.getLogger(__name__)

_OPERATOR_EXTRA_LINKS: set[str] = {
    "airflow.operators.trigger_dagrun.TriggerDagRunLink",
    "airflow.sensors.external_task.ExternalDagLink",
    # Deprecated names, so that existing serialized dags load straight away.
    "airflow.sensors.external_task.ExternalTaskSensorLink",
    "airflow.operators.dagrun_operator.TriggerDagRunLink",
    "airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
}


@cache
def get_operator_extra_links() -> set[str]:
    """
    Get the operator extra links.

    This includes both the built-in ones and those that come from providers.
    """
    _OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names)
    return _OPERATOR_EXTRA_LINKS


@cache
def _get_default_mapped_partial() -> dict[str, Any]:
    """
    Get default partial kwargs in a mapped operator.

    This is used to simplify a serialized mapped operator by excluding default
    values supplied in the implementation from the serialized dict. Since those
    are defaults, they are automatically supplied on de-serialization, so we
    don't need to store them.
    """
    # Use the private _expand() method to avoid the empty kwargs check.
    default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
    return BaseSerialization.serialize(default)[Encoding.VAR]


def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]:
    """Encode a relativedelta object."""
    encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v}
    if var.weekday and var.weekday.n:
        # Every n'th Friday for example
        encoded["weekday"] = [var.weekday.weekday, var.weekday.n]
    elif var.weekday:
        encoded["weekday"] = [var.weekday.weekday]
    return encoded


def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta:
    """Decode a previously encoded relativedelta object."""
    if "weekday" in var:
        var["weekday"] = relativedelta.weekday(*var["weekday"])  # type: ignore
    return relativedelta.relativedelta(**var)
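
# Illustrative round-trip for the two helpers above (a sketch, not exercised by
# the module itself; the exact dict contents depend on which fields are set):
#
#   rd = relativedelta.relativedelta(months=1, weekday=relativedelta.FR(2))
#   encoded = encode_relativedelta(rd)   # e.g. {"months": 1, "weekday": [4, 2]}
#   decode_relativedelta(encoded) == rd  # True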


def encode_timezone(var: Timezone | FixedTimezone) -> str | int:
    """
    Encode a Pendulum Timezone for serialization.

    Airflow only supports timezone objects that implement Pendulum's Timezone
    interface. We try to keep as much information as possible to make conversion
    round-tripping possible (see ``decode_timezone``). We need to special-case
    UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as
    0 without the special case), but passing 0 into ``pendulum.timezone`` does
    not give us UTC (but ``+00:00``).
    """
    if isinstance(var, FixedTimezone):
        if var.offset == 0:
            return "UTC"
        return var.offset
    if isinstance(var, Timezone):
        return var.name
    raise ValueError(
        f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. "
        f"See {get_docs_url('timezone.html#time-zone-aware-dags')}"
    )


def decode_timezone(var: str | int) -> Timezone | FixedTimezone:
    """Decode a previously serialized Pendulum Timezone."""
    return parse_timezone(var)
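
# Sketch of the encoding rules above (fixed offsets are seconds east of UTC):
#
#   encode_timezone(parse_timezone("Europe/Paris"))  # -> "Europe/Paris"
#   encode_timezone(parse_timezone("UTC"))           # -> "UTC" (special-cased)
#   encode_timezone(parse_timezone(3600))            # -> 3600
#   decode_timezone("Europe/Paris")                  # -> Timezone("Europe/Paris")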


def _get_registered_timetable(importable_string: str) -> type[Timetable] | None:
    from airflow import plugins_manager

    if importable_string.startswith("airflow.timetables."):
        return import_string(importable_string)
    plugins_manager.initialize_timetables_plugins()
    if plugins_manager.timetable_classes:
        return plugins_manager.timetable_classes.get(importable_string)
    else:
        return None


def _get_registered_priority_weight_strategy(importable_string: str) -> type[PriorityWeightStrategy] | None:
    from airflow import plugins_manager

    if importable_string in airflow_priority_weight_strategies:
        return airflow_priority_weight_strategies[importable_string]
    plugins_manager.initialize_priority_weight_strategy_plugins()
    if plugins_manager.priority_weight_strategy_classes:
        return plugins_manager.priority_weight_strategy_classes.get(importable_string)
    else:
        return None


class _TimetableNotRegistered(ValueError):
    def __init__(self, type_string: str) -> None:
        self.type_string = type_string

    def __str__(self) -> str:
        return (
            f"Timetable class {self.type_string!r} is not registered or "
            "you have a top level database access that disrupted the session. "
            "Please check the airflow best practices documentation."
        )


class _PriorityWeightStrategyNotRegistered(AirflowException):
    def __init__(self, type_string: str) -> None:
        self.type_string = type_string

    def __str__(self) -> str:
        return (
            f"Priority weight strategy class {self.type_string!r} is not registered or "
            "you have a top level database access that disrupted the session. "
            "Please check the airflow best practices documentation."
        )


def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]:
    """
    Encode a dataset condition.

    :meta private:
    """
    if isinstance(var, Dataset):
        return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra}
    if isinstance(var, DatasetAlias):
        return {"__type": DAT.DATASET_ALIAS, "name": var.name}
    if isinstance(var, DatasetAll):
        return {"__type": DAT.DATASET_ALL, "objects": [encode_dataset_condition(x) for x in var.objects]}
    if isinstance(var, DatasetAny):
        return {"__type": DAT.DATASET_ANY, "objects": [encode_dataset_condition(x) for x in var.objects]}
    raise ValueError(f"serialization not implemented for {type(var).__name__!r}")


def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset:
    """
    Decode a previously serialized dataset condition.

    :meta private:
    """
    dat = var["__type"]
    if dat == DAT.DATASET:
        return Dataset(var["uri"], extra=var["extra"])
    if dat == DAT.DATASET_ALL:
        return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
    if dat == DAT.DATASET_ANY:
        return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
    if dat == DAT.DATASET_ALIAS:
        return DatasetAlias(name=var["name"])
    raise ValueError(f"deserialization not implemented for DAT {dat!r}")
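
# Sketch of how a nested condition encodes, recursing through "objects":
#
#   cond = DatasetAny(Dataset("s3://a"), DatasetAll(Dataset("s3://b"), Dataset("s3://c")))
#   encode_dataset_condition(cond)
#   # -> {"__type": DAT.DATASET_ANY, "objects": [
#   #        {"__type": DAT.DATASET, "uri": "s3://a", "extra": None},
#   #        {"__type": DAT.DATASET_ALL, "objects": [...]},
#   #    ]}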


def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
    raw_key = var.raw_key
    return {
        "extra": var.extra,
        "dataset_alias_events": var.dataset_alias_events,
        "raw_key": BaseSerialization.serialize(raw_key),
    }


def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
    # This is added for compatibility. The attribute used to be dataset_alias_event
    # and is now dataset_alias_events.
    if dataset_alias_event := var.get("dataset_alias_event", None):
        dataset_alias_events = [dataset_alias_event]
    else:
        dataset_alias_events = var.get("dataset_alias_events", [])

    outlet_event_accessor = OutletEventAccessor(
        extra=var["extra"],
        raw_key=BaseSerialization.deserialize(var["raw_key"]),
        dataset_alias_events=dataset_alias_events,
    )
    return outlet_event_accessor


def encode_timetable(var: Timetable) -> dict[str, Any]:
    """
    Encode a timetable instance.

    This delegates most of the serialization work to the type, so the behavior
    can be completely controlled by a custom subclass.

    :meta private:
    """
    timetable_class = type(var)
    importable_string = qualname(timetable_class)
    if _get_registered_timetable(importable_string) is None:
        raise _TimetableNotRegistered(importable_string)
    return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()}


def decode_timetable(var: dict[str, Any]) -> Timetable:
    """
    Decode a previously serialized timetable.

    Most of the deserialization logic is delegated to the actual type, which
    we import from string.

    :meta private:
    """
    importable_string = var[Encoding.TYPE]
    timetable_class = _get_registered_timetable(importable_string)
    if timetable_class is None:
        raise _TimetableNotRegistered(importable_string)
    return timetable_class.deserialize(var[Encoding.VAR])
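
# The resulting envelope pairs the importable class path with whatever dict the
# timetable's own serialize() returns, e.g. (shape only, field values illustrative):
#
#   {Encoding.TYPE: "airflow.timetables.trigger.CronTriggerTimetable",
#    Encoding.VAR: {"expression": "0 0 * * *", "timezone": "UTC", ...}}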


def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
    """
    Encode a priority weight strategy instance.

    In this version, we only store the importable string, so the class must be
    instantiable without any parameters being passed to it. If you need to store
    parameters, you should store them in the class itself.
    """
    priority_weight_strategy_class = type(var)
    if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
        return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
    importable_string = qualname(priority_weight_strategy_class)
    if _get_registered_priority_weight_strategy(importable_string) is None:
        raise _PriorityWeightStrategyNotRegistered(importable_string)
    return importable_string


def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
    """
    Decode a previously serialized priority weight strategy.

    In this version, we only store the importable string, so we just need to get
    the class from the dictionary of registered classes and instantiate it with
    no parameters.
    """
    priority_weight_strategy_class = _get_registered_priority_weight_strategy(var)
    if priority_weight_strategy_class is None:
        raise _PriorityWeightStrategyNotRegistered(var)
    return priority_weight_strategy_class()


def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
    """
    Encode a StartTriggerArgs.

    :meta private:
    """
    serialize_kwargs = (
        lambda key: BaseSerialization.serialize(getattr(var, key)) if getattr(var, key) is not None else None
    )
    return {
        "__type": "START_TRIGGER_ARGS",
        "trigger_cls": var.trigger_cls,
        "trigger_kwargs": serialize_kwargs("trigger_kwargs"),
        "next_method": var.next_method,
        "next_kwargs": serialize_kwargs("next_kwargs"),
        "timeout": var.timeout.total_seconds() if var.timeout else None,
    }


def decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
    """
    Decode a StartTriggerArgs.

    :meta private:
    """
    deserialize_kwargs = lambda key: BaseSerialization.deserialize(var[key]) if var[key] is not None else None

    return StartTriggerArgs(
        trigger_cls=var["trigger_cls"],
        trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
        next_method=var["next_method"],
        next_kwargs=deserialize_kwargs("next_kwargs"),
        timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
    )
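
# Shape of the encoded form (a sketch; trigger_cls stays a dotted-path string,
# timeout flattens to seconds, and the kwargs go through BaseSerialization):
#
#   {"__type": "START_TRIGGER_ARGS",
#    "trigger_cls": "airflow.triggers.temporal.TimeDeltaTrigger",
#    "trigger_kwargs": {...}, "next_method": "execute_complete",
#    "next_kwargs": None, "timeout": 600.0}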


class _XComRef(NamedTuple):
    """
    Store info needed to create XComArg.

    We can't turn it into an XComArg until we've loaded _all_ the tasks, so when
    deserializing an operator, we need to create something in its place, and
    post-process it in ``deserialize_dag``.
    """

    data: dict

    def deref(self, dag: DAG) -> XComArg:
        return deserialize_xcom_arg(self.data, dag)


# These two should be kept in sync. Note that these are intentionally not using
# the type declarations in expandinput.py so we always remember to update
# serialization logic when adding new ExpandInput variants. If you add things to
# the unions, be sure to update _ExpandInputRef to match.
_ExpandInputOriginalValue = Union[
    # For .expand(**kwargs).
    Mapping[str, Any],
    # For expand_kwargs(arg).
    XComArg,
    Collection[Union[XComArg, Mapping[str, Any]]],
]
_ExpandInputSerializedValue = Union[
    # For .expand(**kwargs).
    Mapping[str, Any],
    # For expand_kwargs(arg).
    _XComRef,
    Collection[Union[_XComRef, Mapping[str, Any]]],
]


class _ExpandInputRef(NamedTuple):
    """
    Store info needed to create a mapped operator's expand input.

    This references a ``ExpandInput`` type, but replaces ``XComArg`` objects
    with ``_XComRef`` (see documentation on the latter type for reasoning).
    """

    key: str
    value: _ExpandInputSerializedValue

    @classmethod
    def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
        """
        Validate we've covered all ``ExpandInput.value`` types.

        This function does not actually do anything, but is called during
        serialization so Mypy will *statically* check we have handled all
        possible ExpandInput cases.
        """

    def deref(self, dag: DAG) -> ExpandInput:
        """
        De-reference into a concrete ExpandInput object.

        If you add more cases here, be sure to update _ExpandInputOriginalValue
        and _ExpandInputSerializedValue to match the logic.
        """
        if isinstance(self.value, _XComRef):
            value: Any = self.value.deref(dag)
        elif isinstance(self.value, collections.abc.Mapping):
            value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
        else:
            value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value]
        return create_expand_input(self.key, value)
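
# Illustration of the three deref() branches above (hypothetical values):
#
#   _ExpandInputRef("dict-of-lists", {"x": [1, 2]}).deref(dag)                # Mapping branch
#   _ExpandInputRef("list-of-dicts", _XComRef({...})).deref(dag)              # _XComRef branch
#   _ExpandInputRef("list-of-dicts", [{"a": 1}, _XComRef({...})]).deref(dag)  # list branch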


_orm_to_model = {
    Job: JobPydantic,
    TaskInstance: TaskInstancePydantic,
    DagRun: DagRunPydantic,
    DagModel: DagModelPydantic,
    LogTemplate: LogTemplatePydantic,
    Dataset: DatasetPydantic,
    Trigger: TriggerPydantic,
}
_type_to_class: dict[DAT | str, list] = {
    DAT.BASE_JOB: [JobPydantic, Job],
    DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance],
    DAT.DAG_RUN: [DagRunPydantic, DagRun],
    DAT.DAG_MODEL: [DagModelPydantic, DagModel],
    DAT.LOG_TEMPLATE: [LogTemplatePydantic, LogTemplate],
    DAT.DATA_SET: [DatasetPydantic, Dataset],
    DAT.TRIGGER: [TriggerPydantic, Trigger],
}
_class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes}


def add_pydantic_class_type_mapping(attribute_type: str, orm_class, pydantic_class):
    _orm_to_model[orm_class] = pydantic_class
    _type_to_class[attribute_type] = [pydantic_class, orm_class]
    _class_to_type[pydantic_class] = attribute_type
    _class_to_type[orm_class] = attribute_type
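
# A hypothetical registration, mirroring the built-in entries above; after this,
# serialize(..., use_pydantic_models=True) would emit MyModel under "my_type":
#
#   add_pydantic_class_type_mapping("my_type", MyModel, MyModelPydantic)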


class BaseSerialization:
    """BaseSerialization provides utils for serialization."""

    # JSON primitive types.
    _primitive_types = (int, bool, float, str)

    # Time types.
    # datetime.date and datetime.time are converted to strings.
    _datetime_types = (datetime.datetime,)

    # Object types that are always excluded in serialization.
    _excluded_types = (logging.Logger, Connection, type, property)

    _json_schema: Validator | None = None

    # Should the extra operator link be loaded via plugins when
    # de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links
    # are not loaded to not run User code in Scheduler.
    _load_operator_extra_links = True

    _CONSTRUCTOR_PARAMS: dict[str, Parameter] = {}

    SERIALIZER_VERSION = 1

    @classmethod
    def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str:
        """Stringify DAGs and operators contained by var and return a JSON string of var."""
        return json.dumps(cls.to_dict(var), ensure_ascii=True)

    @classmethod
    def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict:
        """Stringify DAGs and operators contained by var and return a dict of var."""
        # Don't call on this class directly - only SerializedDAG or
        # SerializedBaseOperator should be used as the "entrypoint"
        raise NotImplementedError()

    @classmethod
    def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple:
        """Deserialize json_str and reconstruct all DAGs and operators it contains."""
        return cls.from_dict(json.loads(serialized_obj))

    @classmethod
    def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple:
        """Deserialize a dict stored with type decorators and reconstruct all DAGs and operators it contains."""
        return cls.deserialize(serialized_obj)

    @classmethod
    def validate_schema(cls, serialized_obj: str | dict) -> None:
        """Validate serialized_obj satisfies JSON schema."""
        if cls._json_schema is None:
            raise AirflowException(f"JSON schema of {cls.__name__:s} is not set.")
        if isinstance(serialized_obj, dict):
            cls._json_schema.validate(serialized_obj)
        elif isinstance(serialized_obj, str):
            cls._json_schema.validate(json.loads(serialized_obj))
        else:
            raise TypeError("Invalid type: Only dict and str are supported.")

    @staticmethod
    def _encode(x: Any, type_: Any) -> dict[Encoding, Any]:
        """Encode data by a JSON dict."""
        return {Encoding.VAR: x, Encoding.TYPE: type_}

    @classmethod
    def _is_primitive(cls, var: Any) -> bool:
        """Primitive types."""
        return var is None or isinstance(var, cls._primitive_types)

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
        """Check if type is excluded from serialization."""
        if var is None:
            if not cls._is_constructor_param(attrname, instance):
                # For any instance attribute that is not a constructor argument,
                # exclude None since it is the implied default.
                return True
            return cls._value_is_hardcoded_default(attrname, var, instance)
        return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default(
            attrname, var, instance
        )

    @classmethod
    def serialize_to_json(
        cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
    ) -> dict[str, Any]:
        """Serialize an object to JSON."""
        serialized_object: dict[str, Any] = {}
        keys_to_serialize = object_to_serialize.get_serialized_fields()
        for key in keys_to_serialize:
            # None is ignored in serialized form and is added back in deserialization.
            value = getattr(object_to_serialize, key, None)
            if cls._is_excluded(value, key, object_to_serialize):
                continue

            if key == "_operator_name":
                # when operator_name matches task_type, we can remove
                # it to reduce the JSON payload
                task_type = getattr(object_to_serialize, "_task_type", None)
                if value != task_type:
                    serialized_object[key] = cls.serialize(value)
            elif key in decorated_fields:
                serialized_object[key] = cls.serialize(value)
            elif key == "timetable" and value is not None:
                serialized_object[key] = encode_timetable(value)
            elif key == "weight_rule" and value is not None:
                serialized_object[key] = encode_priority_weight_strategy(value)
            else:
                value = cls.serialize(value)
                if isinstance(value, dict) and Encoding.TYPE in value:
                    value = value[Encoding.VAR]
                serialized_object[key] = value
        return serialized_object

    @classmethod
    def serialize(
        cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False
    ) -> Any:  # Unfortunately there is no support for recursive types in mypy
        """
        Serialize an object; helper function of depth first search for serialization.

        The serialization protocol is:

        (1) keeping JSON supported types: primitives, dict, list;
        (2) encoding other types as ``{TYPE: 'foo', VAR: 'bar'}``, the deserialization
            step decodes VAR according to TYPE;
        (3) Operator has a special field CLASS to record the original class
            name for displaying in UI.

        :meta private:
        """
        if use_pydantic_models and not _ENABLE_AIP_44:
            raise RuntimeError(
                "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
                "This parameter will be removed eventually when new serialization is used by AIP-44"
            )
        if cls._is_primitive(var):
            # enum.IntEnum is an int instance, it causes json dumps error so we use its value.
            if isinstance(var, enum.Enum):
                return var.value
            return var
        elif isinstance(var, dict):
            return cls._encode(
                {
                    str(k): cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
                    for k, v in var.items()
                },
                type_=DAT.DICT,
            )
        elif isinstance(var, list):
            return [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var]
        elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
            json_pod = PodGenerator.serialize_pod(var)
            return cls._encode(json_pod, type_=DAT.POD)
        elif isinstance(var, OutletEventAccessors):
            return cls._encode(
                cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models),  # type: ignore[attr-defined]
                type_=DAT.DATASET_EVENT_ACCESSORS,
            )
        elif isinstance(var, OutletEventAccessor):
            return cls._encode(
                encode_outlet_event_accessor(var),
                type_=DAT.DATASET_EVENT_ACCESSOR,
            )
        elif isinstance(var, DAG):
            return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
        elif isinstance(var, Resources):
            return var.to_dict()
        elif isinstance(var, MappedOperator):
            return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP)
        elif isinstance(var, BaseOperator):
            var._needs_expansion = var.get_needs_expansion()
            return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP)
        elif isinstance(var, cls._datetime_types):
            return cls._encode(var.timestamp(), type_=DAT.DATETIME)
        elif isinstance(var, datetime.timedelta):
            return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA)
        elif isinstance(var, (Timezone, FixedTimezone)):
            return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE)
        elif isinstance(var, relativedelta.relativedelta):
            return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA)
        elif isinstance(var, TaskInstanceKey):
            return cls._encode(
                var._asdict(),
                type_=DAT.TASK_INSTANCE_KEY,
            )
        elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"):
            exc_cls_name, args, kwargs = var.serialize()
            return cls._encode(
                cls.serialize(
                    {"exc_cls_name": exc_cls_name, "args": args, "kwargs": kwargs},
                    use_pydantic_models=use_pydantic_models,
                    strict=strict,
                ),
                type_=DAT.AIRFLOW_EXC_SER,
            )
        elif isinstance(var, (KeyError, AttributeError)):
            return cls._encode(
                cls.serialize(
                    {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
                    use_pydantic_models=use_pydantic_models,
                    strict=strict,
                ),
                type_=DAT.BASE_EXC_SER,
            )
        elif isinstance(var, BaseTrigger):
            return cls._encode(
                cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
                type_=DAT.BASE_TRIGGER,
            )
        elif callable(var):
            return str(get_python_source(var))
        elif isinstance(var, set):
            # FIXME: casts set to list in customized serialization in future.
            try:
                return cls._encode(
                    sorted(
                        cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var
                    ),
                    type_=DAT.SET,
                )
            except TypeError:
                return cls._encode(
                    [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
                    type_=DAT.SET,
                )
        elif isinstance(var, tuple):
            # FIXME: casts tuple to list in customized serialization in future.
            return cls._encode(
                [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
                type_=DAT.TUPLE,
            )
        elif isinstance(var, TaskGroup):
            return TaskGroupSerialization.serialize_task_group(var)
        elif isinstance(var, Param):
            return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
        elif isinstance(var, XComArg):
            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
        elif isinstance(var, LazySelectSequence):
            return cls.serialize(list(var))
        elif isinstance(var, BaseDataset):
            serialized_dataset = encode_dataset_condition(var)
            return cls._encode(serialized_dataset, type_=serialized_dataset.pop("__type"))
        elif isinstance(var, SimpleTaskInstance):
            return cls._encode(
                cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
                type_=DAT.SIMPLE_TASK_INSTANCE,
            )
        elif isinstance(var, Connection):
            return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
        elif isinstance(var, TaskCallbackRequest):
            return cls._encode(var.to_json(), type_=DAT.TASK_CALLBACK_REQUEST)
        elif isinstance(var, DagCallbackRequest):
            return cls._encode(var.to_json(), type_=DAT.DAG_CALLBACK_REQUEST)
        elif isinstance(var, SlaCallbackRequest):
            return cls._encode(var.to_json(), type_=DAT.SLA_CALLBACK_REQUEST)
        elif var.__class__ == Context:
            d = {}
            for k, v in var._context.items():
                obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
                d[str(k)] = obj
            return cls._encode(d, type_=DAT.TASK_CONTEXT)
        elif use_pydantic_models and _ENABLE_AIP_44:

            def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
                return model_cls.model_validate(var).model_dump(mode="json")  # type: ignore[attr-defined]

            if var.__class__ in _class_to_type:
                pyd_mod = _orm_to_model.get(var.__class__, var)
                mod = _pydantic_model_dump(pyd_mod, var)
                type_ = _class_to_type[var.__class__]
                return cls._encode(mod, type_=type_)
            else:
                return cls.default_serialization(strict, var)
        elif isinstance(var, ArgNotSet):
            return cls._encode(None, type_=DAT.ARG_NOT_SET)
        else:
            return cls.default_serialization(strict, var)
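
    # Example envelopes produced by serialize() (a sketch; Encoding.VAR and
    # Encoding.TYPE render as "__var" and "__type" in the stored JSON):
    #
    #   serialize(datetime.timedelta(minutes=5))  # -> {"__var": 300.0, "__type": "timedelta"}
    #   serialize({"a": 1})                       # -> {"__var": {"a": 1}, "__type": "dict"}
    #   serialize([1, "x"])                       # -> [1, "x"]  (lists stay bare)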

    @classmethod
    def default_serialization(cls, strict, var) -> str:
        log.debug("Cast type %s to str in serialization.", type(var))
        if strict:
            raise SerializationError("Encountered unexpected type")
        return str(var)

    @classmethod
    def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
        """
        Deserialize an object; helper function of depth first search for deserialization.

        :meta private:
        """
        # JSON primitives (except for dict) are not encoded.
        if use_pydantic_models and not _ENABLE_AIP_44:
            raise RuntimeError(
                "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
                "This parameter will be removed eventually when new serialization is used by AIP-44"
            )
        if cls._is_primitive(encoded_var):
            return encoded_var
        elif isinstance(encoded_var, list):
            return [cls.deserialize(v, use_pydantic_models) for v in encoded_var]

        if not isinstance(encoded_var, dict):
            raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
        var = encoded_var[Encoding.VAR]
        type_ = encoded_var[Encoding.TYPE]
        if type_ == DAT.TASK_CONTEXT:
            d = {}
            for k, v in var.items():
                if k == "task":  # todo: add `_encode` of Operator so we don't need this
                    continue
                d[k] = cls.deserialize(v, use_pydantic_models=True)
            d["task"] = d["task_instance"].task  # todo: add `_encode` of Operator so we don't need this
            d["macros"] = macros
            d["var"] = {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            }
            d["conn"] = ConnectionAccessor()
            return Context(**d)
        elif type_ == DAT.DICT:
            return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
        elif type_ == DAT.DATASET_EVENT_ACCESSORS:
            d = OutletEventAccessors()  # type: ignore[assignment]
            d._dict = cls.deserialize(var)  # type: ignore[attr-defined]
            return d
        elif type_ == DAT.DATASET_EVENT_ACCESSOR:
            return decode_outlet_event_accessor(var)
        elif type_ == DAT.DAG:
            return SerializedDAG.deserialize_dag(var)
        elif type_ == DAT.OP:
            return SerializedBaseOperator.deserialize_operator(var)
        elif type_ == DAT.DATETIME:
            return from_timestamp(var)
        elif type_ == DAT.POD:
            if not _has_kubernetes():
                raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
            pod = PodGenerator.deserialize_model_dict(var)
            return pod
        elif type_ == DAT.TIMEDELTA:
            return datetime.timedelta(seconds=var)
        elif type_ == DAT.TIMEZONE:
            return decode_timezone(var)
        elif type_ == DAT.RELATIVEDELTA:
            return decode_relativedelta(var)
        elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
            deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
            exc_cls_name = deser["exc_cls_name"]
            args = deser["args"]
            kwargs = deser["kwargs"]
            del deser
            if type_ == DAT.AIRFLOW_EXC_SER:
                exc_cls = import_string(exc_cls_name)
            else:
                exc_cls = import_string(f"builtins.{exc_cls_name}")
            return exc_cls(*args, **kwargs)
        elif type_ == DAT.BASE_TRIGGER:
            tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
            tr_cls = import_string(tr_cls_name)
            return tr_cls(**kwargs)
        elif type_ == DAT.SET:
            return {cls.deserialize(v, use_pydantic_models) for v in var}
        elif type_ == DAT.TUPLE:
            return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
        elif type_ == DAT.PARAM:
            return cls._deserialize_param(var)
        elif type_ == DAT.XCOM_REF:
            return _XComRef(var)  # Delay deserializing XComArg objects until we have the entire DAG.
        elif type_ == DAT.DATASET:
            return Dataset(**var)
        elif type_ == DAT.DATASET_ALIAS:
            return DatasetAlias(**var)
        elif type_ == DAT.DATASET_ANY:
            return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
        elif type_ == DAT.DATASET_ALL:
            return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
        elif type_ == DAT.SIMPLE_TASK_INSTANCE:
            return SimpleTaskInstance(**cls.deserialize(var))
        elif type_ == DAT.CONNECTION:
            return Connection(**var)
        elif type_ == DAT.TASK_CALLBACK_REQUEST:
            return TaskCallbackRequest.from_json(var)
        elif type_ == DAT.DAG_CALLBACK_REQUEST:
            return DagCallbackRequest.from_json(var)
        elif type_ == DAT.SLA_CALLBACK_REQUEST:
            return SlaCallbackRequest.from_json(var)
        elif type_ == DAT.TASK_INSTANCE_KEY:
            return TaskInstanceKey(**var)
        elif use_pydantic_models and _ENABLE_AIP_44:
            return _type_to_class[type_][0].model_validate(var)
        elif type_ == DAT.ARG_NOT_SET:
            return NOTSET
        else:
            raise TypeError(f"Invalid type {type_!s} in deserialization.")

    _deserialize_datetime = from_timestamp
    _deserialize_timezone = parse_timezone

    @classmethod
    def _deserialize_timedelta(cls, seconds: int) -> datetime.timedelta:
        return datetime.timedelta(seconds=seconds)

    @classmethod
    def _is_constructor_param(cls, attrname: str, instance: Any) -> bool:
        return attrname in cls._CONSTRUCTOR_PARAMS

    @classmethod
    def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -> bool:
        """
        Return true if ``value`` is the hard-coded default for the given attribute.

        This takes into account cases where the ``max_active_tasks`` parameter is
        stored in the ``_max_active_tasks`` attribute.

        And by using ``is`` here only and not ``==`` this copes with the case a
        user explicitly specifies an attribute with the same "value" as the
        default. (This is because ``"default" is "default"`` will be False as
        they are different strings with the same characters.)

        Also returns True if the value is an empty list or empty dict. This is done
        to account for the case where the default value of the field is None but has the
        ``field = field or {}`` set.
        """
        if attrname in cls._CONSTRUCTOR_PARAMS and (
            cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
        ):
            return True
        return False

    @classmethod
    def _serialize_param(cls, param: Param):
        return {
            "__class": f"{param.__module__}.{param.__class__.__name__}",
            "default": cls.serialize(param.value),
            "description": cls.serialize(param.description),
            "schema": cls.serialize(param.schema),
        }
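
    # Sketch of the serialized form of a Param (values illustrative):
    #
    #   _serialize_param(Param(5, description="retries", type="integer"))
    #   # -> {"__class": "airflow.models.param.Param", "default": 5,
    #   #     "description": "retries",
    #   #     "schema": {"__var": {"type": "integer"}, "__type": "dict"}}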

    @classmethod
    def _deserialize_param(cls, param_dict: dict):
        """
        Workaround to serialize Param on older versions.

        In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
        this class's ``serialize`` method. So before running through ``deserialize``,
        we first verify that it's necessary to do.
        """
        class_name = param_dict["__class"]
        class_: type[Param] = import_string(class_name)
        attrs = ("default", "description", "schema")
        kwargs = {}

        def is_serialized(val):
            if isinstance(val, dict):
                return Encoding.TYPE in val
            if isinstance(val, list):
                return all(isinstance(item, dict) and Encoding.TYPE in item for item in val)
            return False

        for attr in attrs:
            if attr in param_dict:
                val = param_dict[attr]
                if is_serialized(val):
                    val = cls.deserialize(val)
                kwargs[attr] = val
        return class_(**kwargs)

    @classmethod
    def _serialize_params_dict(cls, params: ParamsDict | dict) -> list[tuple[str, dict]]:
        """Serialize Params dict for a DAG or task as a list of tuples to ensure ordering."""
        serialized_params = []
        for k, v in params.items():
            if isinstance(params, ParamsDict):
                # Use native param object, not resolved value if possible
                v = params.get_param(k)
            try:
                class_identity = f"{v.__module__}.{v.__class__.__name__}"
            except AttributeError:
                class_identity = ""
            if class_identity == "airflow.models.param.Param":
                serialized_params.append((k, cls._serialize_param(v)))
            else:
                # Auto-box other values into Params object like it is done by DAG parsing as well
                serialized_params.append((k, cls._serialize_param(Param(v))))
        return serialized_params

    @classmethod
    def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> ParamsDict:
        """Deserialize a DAG's Params dict."""
        if isinstance(encoded_params, collections.abc.Mapping):
            # in 2.9.2 or earlier params were serialized as JSON objects
            encoded_param_pairs: Iterable[tuple[str, dict]] = encoded_params.items()
        else:
            encoded_param_pairs = encoded_params
        op_params = {}
        for k, v in encoded_param_pairs:
            if isinstance(v, dict) and "__class" in v:
                op_params[k] = cls._deserialize_param(v)
            else:
                # Old style params, convert it
                op_params[k] = Param(v)
        return ParamsDict(op_params)
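
    # Round-trip sketch: plain values are auto-boxed into Param on the way out,
    # so both entries below come back as Param objects inside a ParamsDict:
    #
    #   encoded = BaseSerialization._serialize_params_dict({"lr": 0.1, "p": Param(1)})
    #   BaseSerialization._deserialize_params_dict(encoded)
    #   # -> ParamsDict({"lr": Param(0.1), "p": Param(1)})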


class DependencyDetector:
    """
    Detects dependencies between DAGs.

    :meta private:
    """

    @staticmethod
    def detect_task_dependencies(task: Operator) -> list[DagDependency]:
        """Detect dependencies caused by tasks."""
        from airflow.operators.trigger_dagrun import TriggerDagRunOperator
        from airflow.sensors.external_task import ExternalTaskSensor

        deps = []
        if isinstance(task, TriggerDagRunOperator):
            deps.append(
                DagDependency(
                    source=task.dag_id,
                    target=getattr(task, "trigger_dag_id"),
                    dependency_type="trigger",
                    dependency_id=task.task_id,
                )
            )
        elif isinstance(task, ExternalTaskSensor):
            deps.append(
                DagDependency(
                    source=getattr(task, "external_dag_id"),
                    target=task.dag_id,
                    dependency_type="sensor",
                    dependency_id=task.task_id,
                )
            )
        for obj in task.outlets or []:
            if isinstance(obj, Dataset):
                deps.append(
                    DagDependency(
                        source=task.dag_id,
                        target="dataset",
                        dependency_type="dataset",
                        dependency_id=obj.uri,
                    )
                )
            elif isinstance(obj, DatasetAlias):
                cond = _DatasetAliasCondition(obj.name)
                deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
        return deps
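
    # For instance (a sketch): a TriggerDagRunOperator with trigger_dag_id="child"
    # inside DAG "parent" yields
    # DagDependency(source="parent", target="child", dependency_type="trigger",
    #               dependency_id=<its task_id>).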

    @staticmethod
    def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
        """Detect dependencies set directly on the DAG object."""
        if not dag:
            return
        yield from dag.timetable.dataset_condition.iter_dag_dependencies(source="", target=dag.dag_id)


class SerializedBaseOperator(BaseOperator, BaseSerialization):
    """
    A JSON serializable representation of operator.

    All operators are cast to SerializedBaseOperator after deserialization.
    Class specific attributes used by UI are moved to object attributes.

    Creating a SerializedBaseOperator is a three-step process:

    1. Instantiate a :class:`SerializedBaseOperator` object.
    2. Populate attributes with :func:`SerializedBaseOperator.populate_operator`.
    3. When the task's containing DAG is available, fix references to the DAG
       with :func:`SerializedBaseOperator.set_task_dag_references`.
    """

    _decorated_fields = {"executor_config"}

    _CONSTRUCTOR_PARAMS = {
        k: v.default
        for k, v in signature(BaseOperator.__init__).parameters.items()
        if v.default is not v.empty
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # task_type is used by UI to display the correct class type, because UI only
        # receives BaseOperator from deserialized DAGs.
        self._task_type = "BaseOperator"
        # Move class attributes into object attributes.
        self.ui_color = BaseOperator.ui_color
        self.ui_fgcolor = BaseOperator.ui_fgcolor
        self.template_ext = BaseOperator.template_ext
        self.template_fields = BaseOperator.template_fields
        self.operator_extra_links = BaseOperator.operator_extra_links

    @property
    def task_type(self) -> str:
        # Overwrites task_type of BaseOperator to use _task_type instead of
        # __class__.__name__.
        return self._task_type

    @task_type.setter
    def task_type(self, task_type: str):
        self._task_type = task_type

    @property
    def operator_name(self) -> str:
        # Overwrites operator_name of BaseOperator to use _operator_name instead of
        # __class__.operator_name.
        return self._operator_name

    @operator_name.setter
    def operator_name(self, operator_name: str):
        self._operator_name = operator_name

    @classmethod
    def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
        serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
        # Handle expand_input and op_kwargs_expand_input.
        expansion_kwargs = op._get_specified_expand_input()
        if TYPE_CHECKING:  # Let Mypy check the input type for us!
            _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
        serialized_op[op._expand_input_attr] = {
            "type": get_map_type_key(expansion_kwargs),
            "value": cls.serialize(expansion_kwargs.value),
        }

        # Simplify partial_kwargs by comparing it to the most barebone object.
        # Remove all entries that are simply default values.
        serialized_partial = serialized_op["partial_kwargs"]
        for k, default in _get_default_mapped_partial().items():
            try:
                v = serialized_partial[k]
            except KeyError:
                continue
            if v == default:
                del serialized_partial[k]

        serialized_op["_is_mapped"] = True
        return serialized_op

    @classmethod
    def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]:
        return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)

    @classmethod
    def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
        """Serialize operator into a JSON object."""
        serialize_op = cls.serialize_to_json(op, cls._decorated_fields)

        serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
        serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
        if op.operator_name != serialize_op["_task_type"]:
            serialize_op["_operator_name"] = op.operator_name

        # Used to determine if an Operator is inherited from EmptyOperator
        serialize_op["_is_empty"] = op.inherits_from_empty_operator

        serialize_op["start_trigger_args"] = (
            encode_start_trigger_args(op.start_trigger_args) if op.start_trigger_args else None
        )
        serialize_op["start_from_trigger"] = op.start_from_trigger

        if op.operator_extra_links:
            serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
                op.operator_extra_links.__get__(op)
                if isinstance(op.operator_extra_links, property)
                else op.operator_extra_links
            )

        if include_deps:
            serialize_op["deps"] = cls._serialize_deps(op.deps)

        # Store all template_fields as they are if they are JSON-serializable;
        # if not, store them as strings. Raise an exception if the field is
        # not templateable.
        forbidden_fields = set(inspect.signature(BaseOperator.__init__).parameters.keys())
        # Though allow some of the BaseOperator fields to be templated anyway
        forbidden_fields.difference_update({"email"})
        if op.template_fields:
            for template_field in op.template_fields:
                if template_field in forbidden_fields:
                    raise AirflowException(
                        dedent(
                            f"""Cannot template BaseOperator field:
                            {template_field!r} {op.__class__.__name__=} {op.template_fields=}"""
                        )
                    )
                value = getattr(op, template_field, None)
                if not cls._is_excluded(value, template_field, op):
                    serialize_op[template_field] = serialize_template_field(value, template_field)

        if op.params:
            serialize_op["params"] = cls._serialize_params_dict(op.params)

        return serialize_op

    @classmethod
    def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]:
        from airflow import plugins_manager

        plugins_manager.initialize_ti_deps_plugins()
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Can not load plugins")

        deps = []
        for dep in op_deps:
            klass = type(dep)
            module_name = klass.__module__
            qualname = f"{module_name}.{klass.__name__}"
            if (
                not qualname.startswith("airflow.ti_deps.deps.")
                and qualname not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qualname} not serialized, please register it through plugins."
                )
            deps.append(qualname)
        # deps needs to be sorted here because op_deps is a set, whose traversal
        # order is unstable, so the same call may get different results.
        # Without sorting, the dag_hash generated via json.dumps(self.data, sort_keys=True)
        # could change between runs for the same DAG.
        return sorted(deps)

    @classmethod
    def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
        """
        Populate operator attributes with serialized values.

        This covers simple attributes that don't reference other things in the
        DAG. Setting references (such as ``op.dag`` and task dependencies) is
        done in ``set_task_dag_references`` instead, which is called after the
        DAG is hydrated.
        """
        if "label" not in encoded_op:
            # Handle deserialization of old data before the introduction of TaskGroup
            encoded_op["label"] = encoded_op["task_id"]

        # Extra Operator Links defined in Plugins
        op_extra_links_from_plugin = {}

        if "_operator_name" not in encoded_op:
            encoded_op["_operator_name"] = encoded_op["_task_type"]

        # We don't want to load Extra Operator links in Scheduler
        if cls._load_operator_extra_links:
            from airflow import plugins_manager

            plugins_manager.initialize_extra_operators_links_plugins()

            if plugins_manager.operator_extra_links is None:
                raise AirflowException("Cannot load plugins")

            for ope in plugins_manager.operator_extra_links:
                for operator in ope.operators:
                    if (
                        operator.__name__ == encoded_op["_task_type"]
                        and operator.__module__ == encoded_op["_task_module"]
                    ):
                        op_extra_links_from_plugin.update({ope.name: ope})

            # If OperatorLinks are defined in Plugins but not in the operator that is
            # being serialized, set the operator links attribute here. The case where
            # OperatorLinks are defined in the operator itself is handled in the
            # deserialization loop below, where it matches k == "_operator_extra_links".
            if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
                setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))

        for k, v in encoded_op.items():
            # TODO: Remove in Airflow 3.0 when dummy operator is removed
            if k == "_is_dummy":
                k = "_is_empty"

            if k in ("_outlets", "_inlets"):
                # `_outlets` -> `outlets`
                k = k[1:]
            if k == "_downstream_task_ids":
                # Upgrade from old format/name
                k = "downstream_task_ids"
            if k == "label":
                # Label shouldn't be set anymore -- it's computed from task_id now
                continue
            elif k == "downstream_task_ids":
                v = set(v)
            elif k == "subdag":
                v = SerializedDAG.deserialize_dag(v)
            elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
                v = cls._deserialize_timedelta(v)
            elif k in encoded_op["template_fields"]:
                pass
            elif k == "resources":
                v = Resources.from_dict(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "_operator_extra_links":
                if cls._load_operator_extra_links:
                    op_predefined_extra_links = cls._deserialize_operator_extra_links(v)

                    # If OperatorLinks with the same name exist, links via Plugin have higher precedence
                    op_predefined_extra_links.update(op_extra_links_from_plugin)
                else:
                    op_predefined_extra_links = {}

                v = list(op_predefined_extra_links.values())
                k = "operator_extra_links"
            elif k == "deps":
                v = cls._deserialize_deps(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
                if op.params:  # Merge existing params if needed.
                    v, new = op.params, v
                    v.update(new)
            elif k == "partial_kwargs":
                v = {arg: cls.deserialize(value) for arg, value in v.items()}
            elif k in {"expand_input", "op_kwargs_expand_input"}:
                v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
            elif k == "operator_class":
                v = {k_: cls.deserialize(v_, use_pydantic_models=True) for k_, v_ in v.items()}
            elif (
                k in cls._decorated_fields
                or k not in op.get_serialized_fields()
                or k in ("outlets", "inlets")
            ):
                v = cls.deserialize(v)
            elif k == "on_failure_fail_dagrun":
                k = "_on_failure_fail_dagrun"
            elif k == "weight_rule":
                v = decode_priority_weight_strategy(v)
            # else use v as it is

            setattr(op, k, v)

        for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out);
            # then this check could go away.
            if not hasattr(op, k):
                setattr(op, k, None)

        # Set all template_fields that were not present in the serialized JSON to None
        for field in op.template_fields:
            if not hasattr(op, field):
                setattr(op, field, None)

        # Used to determine if an Operator is inherited from EmptyOperator
        setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))

        start_trigger_args = None
        encoded_start_trigger_args = encoded_op.get("start_trigger_args", None)
        if encoded_start_trigger_args:
            encoded_start_trigger_args = cast(dict, encoded_start_trigger_args)
            start_trigger_args = decode_start_trigger_args(encoded_start_trigger_args)
        setattr(op, "start_trigger_args", start_trigger_args)
        setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False)))

    @staticmethod
    def set_task_dag_references(task: Operator, dag: DAG) -> None:
        """
        Handle DAG references on an operator.

        The operator should have been mostly populated earlier by calling
        ``populate_operator``. This function further fixes object references
        that could not be set until the task's containing DAG is hydrated.
        """
        task.dag = dag

        for date_attr in ("start_date", "end_date"):
            if getattr(task, date_attr, None) is None:
                setattr(task, date_attr, getattr(dag, date_attr, None))

        if task.subdag is not None:
            task.subdag.parent_dag = dag

        # Dereference expand_input and op_kwargs_expand_input.
        for k in ("expand_input", "op_kwargs_expand_input"):
            if isinstance(kwargs_ref := getattr(task, k, None), _ExpandInputRef):
                setattr(task, k, kwargs_ref.deref(dag))

        for task_id in task.downstream_task_ids:
            # Bypass set_upstream etc. here; it does more than we want.
            dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
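
    # Illustrative sketch of the two-phase hydration implemented above
    # (``encoded_op`` and ``dag`` are assumed inputs): ``populate_operator``
    # (invoked via ``deserialize_operator``) fills simple attributes first, and
    # ``set_task_dag_references`` wires DAG references once the DAG exists:
    #
    #   op = SerializedBaseOperator.deserialize_operator(encoded_op)  # phase 1
    #   SerializedBaseOperator.set_task_dag_references(op, dag)       # phase 2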

    @classmethod
    def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
        """Deserializes an operator from a JSON object."""
        op: Operator
        if encoded_op.get("_is_mapped", False):
            # Most of these will be loaded later; these are just some stand-ins.
            op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
            try:
                operator_name = encoded_op["_operator_name"]
            except KeyError:
                operator_name = encoded_op["_task_type"]

            op = MappedOperator(
                operator_class=op_data,
                expand_input=EXPAND_INPUT_EMPTY,
                partial_kwargs={},
                task_id=encoded_op["task_id"],
                params={},
                deps=MappedOperator.deps_for(BaseOperator),
                operator_extra_links=BaseOperator.operator_extra_links,
                template_ext=BaseOperator.template_ext,
                template_fields=BaseOperator.template_fields,
                template_fields_renderers=BaseOperator.template_fields_renderers,
                ui_color=BaseOperator.ui_color,
                ui_fgcolor=BaseOperator.ui_fgcolor,
                is_empty=False,
                task_module=encoded_op["_task_module"],
                task_type=encoded_op["_task_type"],
                operator_name=operator_name,
                dag=None,
                task_group=None,
                start_date=None,
                end_date=None,
                disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
                expand_input_attr=encoded_op["_expand_input_attr"],
                start_trigger_args=encoded_op.get("start_trigger_args", None),
                start_from_trigger=encoded_op.get("start_from_trigger", False),
            )
        else:
            op = SerializedBaseOperator(task_id=encoded_op["task_id"])
        op.dag = AttributeRemoved("dag")  # type: ignore[assignment]
        cls.populate_operator(op, encoded_op)
        return op
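
    # Illustrative round trip (``task`` is an assumed BaseOperator instance):
    #
    #   encoded = SerializedBaseOperator.serialize_operator(task)
    #   restored = SerializedBaseOperator.deserialize_operator(encoded)
    #
    # Note that ``restored`` is a SerializedBaseOperator (or MappedOperator)
    # stand-in, not an instance of the original operator class.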

    @classmethod
    def detect_dependencies(cls, op: Operator) -> set[DagDependency]:
        """Detect inter-DAG dependencies for the operator."""

        def get_custom_dep() -> list[DagDependency]:
            """
            If a custom dependency detector is configured, use it.

            TODO: Remove this logic in 3.0.
            """
            custom_dependency_detector_cls = conf.getimport("scheduler", "dependency_detector", fallback=None)
            if not (
                custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
            ):
                warnings.warn(
                    "Use of a custom dependency detector is deprecated. "
                    "Support will be removed in a future release.",
                    RemovedInAirflow3Warning,
                    stacklevel=1,
                )
                dep = custom_dependency_detector_cls().detect_task_dependencies(op)
                if type(dep) is DagDependency:
                    return [dep]
            return []

        dependency_detector = DependencyDetector()
        deps = set(dependency_detector.detect_task_dependencies(op))
        deps.update(get_custom_dep())  # TODO: remove in 3.0
        return deps
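
    # Illustrative sketch (assumed operator and field values): a
    # TriggerDagRunOperator would typically yield one DagDependency of
    # dependency_type "trigger" pointing at the triggered DAG:
    #
    #   deps = SerializedBaseOperator.detect_dependencies(trigger_task)
    #   # e.g. {DagDependency(source="my_dag", target="other_dag",
    #   #                     dependency_type="trigger", dependency_id="trigger_task")}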

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
        if (
            var is not None
            and op.has_dag()
            and op.dag.__class__ is not AttributeRemoved
            and attrname.endswith("_date")
        ):
            # If this date is the same as the matching field in the dag, then
            # don't store it again at the task level.
            dag_date = getattr(op.dag, attrname, None)
            if var is dag_date or var == dag_date:
                return True
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]:
        from airflow import plugins_manager

        plugins_manager.initialize_ti_deps_plugins()
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Cannot load plugins")

        instances = set()
        for qn in set(deps):
            if (
                not qn.startswith("airflow.ti_deps.deps.")
                and qn not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qn} not deserialized, please register it through plugins."
                )

            try:
                instances.add(import_string(qn)())
            except ImportError:
                log.warning("Error importing dep %r", qn, exc_info=True)
        return instances

    @classmethod
    def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
        """
        Deserialize Operator Links if the classes are registered in Airflow Plugins.

        An error is logged (and an empty dict returned) if an OperatorLink class
        is not registered in Plugins either.

        :param encoded_op_links: Serialized Operator Link
        :return: De-Serialized Operator Link
        """
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()

        if plugins_manager.registered_operator_link_classes is None:
            raise AirflowException("Cannot load plugins")
        op_predefined_extra_links = {}

        for _operator_links_source in encoded_op_links:
            # Get the key-value pair, where the key is the OperatorLink class name
            # and the value is the dictionary of arguments passed to that OperatorLink.
            #
            # Example of a single iteration:
            #
            # _operator_links_source =
            #   {
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
            #           'index': 0
            #       }
            #   },
            #
            # list(_operator_links_source.items()) =
            #   [
            #       (
            #           'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #           {'index': 0}
            #       )
            #   ]
            #
            # list(_operator_links_source.items())[0] =
            #   (
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #       {'index': 0}
            #   )

            _operator_link_class_path, data = next(iter(_operator_links_source.items()))
            if _operator_link_class_path in get_operator_extra_links():
                single_op_link_class = import_string(_operator_link_class_path)
            elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
                single_op_link_class = plugins_manager.registered_operator_link_classes[
                    _operator_link_class_path
                ]
            else:
                log.error("Operator Link class %r not registered", _operator_link_class_path)
                return {}

            op_link_parameters = {param: cls.deserialize(value) for param, value in data.items()}
            op_predefined_extra_link: BaseOperatorLink = single_op_link_class(**op_link_parameters)

            op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})

        return op_predefined_extra_links

    @classmethod
    def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
        """
        Serialize Operator Links.

        Store the import path of the OperatorLink and the arguments passed to it.
        For example:
        ``[{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]``

        :param operator_extra_links: Operator Link
        :return: Serialized Operator Link
        """
        serialize_operator_extra_links = []
        for operator_extra_link in operator_extra_links:
            op_link_arguments = {
                param: cls.serialize(value) for param, value in attrs.asdict(operator_extra_link).items()
            }

            module_path = (
                f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
            )
            serialize_operator_extra_links.append({module_path: op_link_arguments})

        return serialize_operator_extra_links

    @classmethod
    def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False) -> Any:
        # The wonders of multiple inheritance: BaseOperator defines an instance
        # method of the same name, so delegate explicitly to BaseSerialization.
        return BaseSerialization.serialize(var=var, strict=strict, use_pydantic_models=use_pydantic_models)

    @classmethod
    def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False) -> Any:
        return BaseSerialization.deserialize(encoded_var=encoded_var, use_pydantic_models=use_pydantic_models)


class SerializedDAG(DAG, BaseSerialization):
    """
    A JSON serializable representation of a DAG.

    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
    that are not serializable, such as functions and user-defined classes, are cast to
    strings.

    Compared with SimpleDAG: SerializedDAG contains all information for the webserver.
    Compared with DagPickle: DagPickle contains all information for the worker, but some DAGs are
    not pickle-able. SerializedDAG works for all DAGs.
    """

    _decorated_fields = {"schedule_interval", "default_args", "_access_control"}

    @staticmethod
    def __get_constructor_defaults():
        param_to_attr = {
            "max_active_tasks": "_max_active_tasks",
            "dag_display_name": "_dag_display_property_value",
            "description": "_description",
            "default_view": "_default_view",
            "access_control": "_access_control",
        }
        return {
            param_to_attr.get(k, k): v.default
            for k, v in signature(DAG.__init__).parameters.items()
            if v.default is not v.empty
        }

    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
    del __get_constructor_defaults

    _json_schema = lazy_object_proxy.Proxy(load_dag_schema)

    @classmethod
    def serialize_dag(cls, dag: DAG) -> dict:
        """Serialize a DAG into a JSON object."""
        try:
            serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

            serialized_dag["_processor_dags_folder"] = DAGS_FOLDER

            # If schedule_interval is backed by timetable, serialize only
            # timetable; vice versa for a timetable backed by schedule_interval.
            if dag.timetable.summary == dag.schedule_interval:
                del serialized_dag["schedule_interval"]
            else:
                del serialized_dag["timetable"]

            serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()]

            dag_deps = [
                dep
                for task in dag.task_dict.values()
                for dep in SerializedBaseOperator.detect_dependencies(task)
            ]
            dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag))
            serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)]
            serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group)

            # Edge info in the JSON exactly matches our internal structure
            serialized_dag["edge_info"] = dag.edge_info
            serialized_dag["params"] = cls._serialize_params_dict(dag.params)

            # has_on_*_callback are only stored if the value is True, as the default is False
            if dag.has_on_success_callback:
                serialized_dag["has_on_success_callback"] = True
            if dag.has_on_failure_callback:
                serialized_dag["has_on_failure_callback"] = True
            return serialized_dag
        except SerializationError:
            raise
        except Exception as e:
            raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}")
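
    # Illustrative sketch (assumed values): the serialized DAG stores either
    # "schedule_interval" or "timetable", never both; ``deserialize_dag`` below
    # back-populates whichever one was dropped:
    #
    #   {"_dag_id": "my_dag", "timetable": {...}, "tasks": [...],
    #    "dag_dependencies": [...], "_task_group": {...}, "edge_info": {}, ...}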

    @classmethod
    def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
        """Deserializes a DAG from a JSON object."""
        dag = SerializedDAG(dag_id=encoded_dag["_dag_id"], schedule=None)

        for k, v in encoded_dag.items():
            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "tasks":
                SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
                tasks = {}
                for obj in v:
                    if obj.get(Encoding.TYPE) == DAT.OP:
                        deser = SerializedBaseOperator.deserialize_operator(obj[Encoding.VAR])
                        tasks[deser.task_id] = deser
                    else:  # TODO: remove in Airflow 3.0 (backcompat for pre-2.10)
                        tasks[obj["task_id"]] = SerializedBaseOperator.deserialize_operator(obj)
                k = "task_dict"
                v = tasks
            elif k == "timezone":
                v = cls._deserialize_timezone(v)
            elif k == "dagrun_timeout":
                v = cls._deserialize_timedelta(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "edge_info":
                # Value structure matches exactly
                pass
            elif k == "timetable":
                v = decode_timetable(v)
            elif k == "weight_rule":
                v = decode_priority_weight_strategy(v)
            elif k in cls._decorated_fields:
                v = cls.deserialize(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
            # else use v as it is

            setattr(dag, k, v)

        # A DAG is always serialized with only one of schedule_interval and
        # timetable. This back-populates the other to ensure the two attributes
        # line up correctly on the DAG instance.
        if "timetable" in encoded_dag:
            dag.schedule_interval = dag.timetable.summary
        else:
            dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)

        # Set _task_group
        if "_task_group" in encoded_dag:
            dag._task_group = TaskGroupSerialization.deserialize_task_group(
                encoded_dag["_task_group"],
                None,
                dag.task_dict,
                dag,
            )
        else:
            # This must be old data that had no task_group. Create a root TaskGroup and add
            # all tasks to it.
            dag._task_group = TaskGroup.create_root(dag)
            for task in dag.tasks:
                dag.task_group.add(task)

        # Set has_on_*_callback to True if present in the serialized blob, as False is the default
        if "has_on_success_callback" in encoded_dag:
            dag.has_on_success_callback = True
        if "has_on_failure_callback" in encoded_dag:
            dag.has_on_failure_callback = True

        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
        for k in keys_to_set_none:
            setattr(dag, k, None)

        for task in dag.task_dict.values():
            SerializedBaseOperator.set_task_dag_references(task, dag)

        return dag

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
        # {} is explicitly different from None in the case of DAG-level access control,
        # and as a result we need to preserve empty dicts through serialization for this field.
        if attrname == "_access_control" and var is not None:
            return False
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var."""
        json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}

        # Validate the serialized DAG with the JSON Schema; raises an error on mismatch.
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> SerializedDAG:
        """Deserializes a python dict into the DAG and operators it contains."""
        ver = serialized_obj.get("__version", "<not present>")
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError(f"Unsure how to deserialize version {ver!r}")
        return cls.deserialize_dag(serialized_obj["dag"])
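
    # Illustrative round trip (``dag`` is an assumed DAG instance); this pair of
    # calls is the full serialization path for a DAG:
    #
    #   data = SerializedDAG.to_dict(dag)         # {"__version": 1, "dag": {...}}
    #   restored = SerializedDAG.from_dict(data)  # raises ValueError on version mismatch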


class TaskGroupSerialization(BaseSerialization):
    """JSON serializable representation of a task group."""

    @classmethod
    def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
        """Serialize TaskGroup into a JSON object."""
        if not task_group:
            return None

        # task_group.xxx_ids needs to be sorted here because each xxx_ids attribute
        # is a set, and converting a set to a list yields an uncertain order; without
        # sorting, json.dumps(self.data, sort_keys=True) would generate unstable
        # dag_hash values.
        encoded = {
            "_group_id": task_group._group_id,
            "prefix_group_id": task_group.prefix_group_id,
            "tooltip": task_group.tooltip,
            "ui_color": task_group.ui_color,
            "ui_fgcolor": task_group.ui_fgcolor,
            "children": {
                label: child.serialize_for_task_group() for label, child in task_group.children.items()
            },
            "upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)),
            "downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)),
            "upstream_task_ids": cls.serialize(sorted(task_group.upstream_task_ids)),
            "downstream_task_ids": cls.serialize(sorted(task_group.downstream_task_ids)),
        }

        if isinstance(task_group, MappedTaskGroup):
            expand_input = task_group._expand_input
            encoded["expand_input"] = {
                "type": get_map_type_key(expand_input),
                "value": cls.serialize(expand_input.value),
            }
            encoded["is_mapped"] = True

        return encoded
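
    # Illustrative sketch (assumed values) of the encoded shape produced above;
    # each child is stored as a (type, value) pair from serialize_for_task_group:
    #
    #   {"_group_id": "section_1", "prefix_group_id": True, "tooltip": "",
    #    "ui_color": "CornflowerBlue", "ui_fgcolor": "#000",
    #    "children": {"task_1": ("operator", "section_1.task_1")},
    #    "upstream_group_ids": [...], "downstream_group_ids": [...],
    #    "upstream_task_ids": [...], "downstream_task_ids": [...]}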

    @classmethod
    def deserialize_task_group(
        cls,
        encoded_group: dict[str, Any],
        parent_group: TaskGroup | None,
        task_dict: dict[str, Operator],
        dag: SerializedDAG,
    ) -> TaskGroup:
        """Deserializes a TaskGroup from a JSON object."""
        group_id = cls.deserialize(encoded_group["_group_id"])
        kwargs = {
            key: cls.deserialize(encoded_group[key])
            for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
        }

        if not encoded_group.get("is_mapped"):
            group = TaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
        else:
            xi = encoded_group["expand_input"]
            group = MappedTaskGroup(
                group_id=group_id,
                parent_group=parent_group,
                dag=dag,
                expand_input=_ExpandInputRef(xi["type"], cls.deserialize(xi["value"])).deref(dag),
                **kwargs,
            )

        def set_ref(task: Operator) -> Operator:
            task.task_group = weakref.proxy(group)
            return task

        group.children = {
            label: (
                set_ref(task_dict[val])
                if _type == DAT.OP
                else cls.deserialize_task_group(val, group, task_dict, dag=dag)
            )
            for label, (_type, val) in encoded_group["children"].items()
        }
        group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"]))
        group.downstream_group_ids.update(cls.deserialize(encoded_group["downstream_group_ids"]))
        group.upstream_task_ids.update(cls.deserialize(encoded_group["upstream_task_ids"]))
        group.downstream_task_ids.update(cls.deserialize(encoded_group["downstream_task_ids"]))
        return group


def _has_kubernetes() -> bool:
    global HAS_KUBERNETES
    if "HAS_KUBERNETES" in globals():
        return HAS_KUBERNETES

    # Loading kube modules is expensive, so delay it until the last moment
    try:
        from kubernetes.client import models as k8s

        try:
            from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
        except ImportError:
            from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import (  # type: ignore[assignment]
                PodGenerator,
            )

        globals()["k8s"] = k8s
        globals()["PodGenerator"] = PodGenerator

        # isort: on
        HAS_KUBERNETES = True
    except ImportError:
        HAS_KUBERNETES = False
    return HAS_KUBERNETES
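

# Illustrative usage sketch (assumed call site): the result is computed once and
# then cached in HAS_KUBERNETES, so repeated calls are cheap. After a successful
# check, ``k8s`` and ``PodGenerator`` are injected into globals() for the
# k8s-aware (de)serialization paths, e.g.:
#
#   if _has_kubernetes():
#       pod = PodGenerator.deserialize_model_dict(encoded_pod)  # assumed encoded input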