- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- """Serialized DAG and BaseOperator."""
- from __future__ import annotations
- import collections.abc
- import datetime
- import enum
- import inspect
- import logging
- import warnings
- import weakref
- from inspect import signature
- from textwrap import dedent
- from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union, cast
- import attrs
- import lazy_object_proxy
- from dateutil import relativedelta
- from pendulum.tz.timezone import FixedTimezone, Timezone
- from airflow import macros
- from airflow.callbacks.callback_requests import DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest
- from airflow.compat.functools import cache
- from airflow.configuration import conf
- from airflow.datasets import (
- BaseDataset,
- Dataset,
- DatasetAlias,
- DatasetAll,
- DatasetAny,
- _DatasetAliasCondition,
- )
- from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred
- from airflow.jobs.job import Job
- from airflow.models import Trigger
- from airflow.models.baseoperator import BaseOperator
- from airflow.models.connection import Connection
- from airflow.models.dag import DAG, DagModel, create_timetable
- from airflow.models.dagrun import DagRun
- from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
- from airflow.models.mappedoperator import MappedOperator
- from airflow.models.param import Param, ParamsDict
- from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
- from airflow.models.taskinstancekey import TaskInstanceKey
- from airflow.models.tasklog import LogTemplate
- from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
- from airflow.providers_manager import ProvidersManager
- from airflow.serialization.dag_dependency import DagDependency
- from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
- from airflow.serialization.helpers import serialize_template_field
- from airflow.serialization.json_schema import load_dag_schema
- from airflow.serialization.pydantic.dag import DagModelPydantic
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
- from airflow.serialization.pydantic.dataset import DatasetPydantic
- from airflow.serialization.pydantic.job import JobPydantic
- from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
- from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
- from airflow.serialization.pydantic.trigger import TriggerPydantic
- from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
- from airflow.task.priority_strategy import (
- PriorityWeightStrategy,
- airflow_priority_weight_strategies,
- airflow_priority_weight_strategies_classes,
- )
- from airflow.triggers.base import BaseTrigger, StartTriggerArgs
- from airflow.utils.code_utils import get_python_source
- from airflow.utils.context import (
- ConnectionAccessor,
- Context,
- OutletEventAccessor,
- OutletEventAccessors,
- VariableAccessor,
- )
- from airflow.utils.db import LazySelectSequence
- from airflow.utils.docs import get_docs_url
- from airflow.utils.module_loading import import_string, qualname
- from airflow.utils.operator_resources import Resources
- from airflow.utils.task_group import MappedTaskGroup, TaskGroup
- from airflow.utils.timezone import from_timestamp, parse_timezone
- from airflow.utils.types import NOTSET, ArgNotSet, AttributeRemoved
- if TYPE_CHECKING:
- from inspect import Parameter
- from airflow.models.baseoperatorlink import BaseOperatorLink
- from airflow.models.expandinput import ExpandInput
- from airflow.models.operator import Operator
- from airflow.models.taskmixin import DAGNode
- from airflow.serialization.json_schema import Validator
- from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
- from airflow.timetables.base import Timetable
- from airflow.utils.pydantic import BaseModel
- HAS_KUBERNETES: bool
- try:
- from kubernetes.client import models as k8s # noqa: TCH004
- from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TCH004
- except ImportError:
- pass
- log = logging.getLogger(__name__)
- _OPERATOR_EXTRA_LINKS: set[str] = {
- "airflow.operators.trigger_dagrun.TriggerDagRunLink",
- "airflow.sensors.external_task.ExternalDagLink",
- # Deprecated names, so that existing serialized dags load straight away.
- "airflow.sensors.external_task.ExternalTaskSensorLink",
- "airflow.operators.dagrun_operator.TriggerDagRunLink",
- "airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
- }
- @cache
- def get_operator_extra_links() -> set[str]:
- """
- Get the operator extra links.
- This includes both the built-in ones, and those that come from the providers.
- """
- _OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names)
- return _OPERATOR_EXTRA_LINKS
- @cache
- def _get_default_mapped_partial() -> dict[str, Any]:
- """
- Get default partial kwargs in a mapped operator.
- This is used to simplify a serialized mapped operator by excluding default
- values supplied in the implementation from the serialized dict. Since those
- are defaults, they are automatically supplied on de-serialization, so we
- don't need to store them.
- """
- # Use the private _expand() method to avoid the empty kwargs check.
- default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
- return BaseSerialization.serialize(default)[Encoding.VAR]
- def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]:
- """Encode a relativedelta object."""
- encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v}
- if var.weekday and var.weekday.n:
- # Every n'th Friday for example
- encoded["weekday"] = [var.weekday.weekday, var.weekday.n]
- elif var.weekday:
- encoded["weekday"] = [var.weekday.weekday]
- return encoded
- def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta:
- """Dencode a relativedelta object."""
- if "weekday" in var:
- var["weekday"] = relativedelta.weekday(*var["weekday"]) # type: ignore
- return relativedelta.relativedelta(**var)
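- # Illustrative round-trip sketch (comments only, not part of the module):
- # "the 2nd Friday" is stored as a [weekday, n] pair, e.g. for
- #     rd = relativedelta.relativedelta(months=1, weekday=relativedelta.FR(2))
- # encode_relativedelta(rd) yields {"months": 1, "weekday": [4, 2]}, and
- # decode_relativedelta() rebuilds an equal relativedelta from that dict.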
- def encode_timezone(var: Timezone | FixedTimezone) -> str | int:
- """
- Encode a Pendulum Timezone for serialization.
- Airflow only supports timezone objects that implement Pendulum's Timezone
- interface. We try to keep as much information as possible to make conversion
- round-tripping possible (see ``decode_timezone``). We need to special-case
- UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as
- 0 without the special case), but passing 0 into ``pendulum.timezone`` does
- not give us UTC (but ``+00:00``).
- """
- if isinstance(var, FixedTimezone):
- if var.offset == 0:
- return "UTC"
- return var.offset
- if isinstance(var, Timezone):
- return var.name
- raise ValueError(
- f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. "
- f"See {get_docs_url('timezone.html#time-zone-aware-dags')}"
- )
- def decode_timezone(var: str | int) -> Timezone | FixedTimezone:
- """Decode a previously serialized Pendulum Timezone."""
- return parse_timezone(var)
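- # Illustrative examples of the encoding above (comments only, not executed):
- #     encode_timezone(Timezone("Europe/Paris"))  -> "Europe/Paris"
- #     encode_timezone(FixedTimezone(3600))       -> 3600
- #     encode_timezone(FixedTimezone(0))          -> "UTC"   (special-cased)
- # decode_timezone() accepts either form and delegates to parse_timezone().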
- def _get_registered_timetable(importable_string: str) -> type[Timetable] | None:
- from airflow import plugins_manager
- if importable_string.startswith("airflow.timetables."):
- return import_string(importable_string)
- plugins_manager.initialize_timetables_plugins()
- if plugins_manager.timetable_classes:
- return plugins_manager.timetable_classes.get(importable_string)
- else:
- return None
- def _get_registered_priority_weight_strategy(importable_string: str) -> type[PriorityWeightStrategy] | None:
- from airflow import plugins_manager
- if importable_string in airflow_priority_weight_strategies:
- return airflow_priority_weight_strategies[importable_string]
- plugins_manager.initialize_priority_weight_strategy_plugins()
- if plugins_manager.priority_weight_strategy_classes:
- return plugins_manager.priority_weight_strategy_classes.get(importable_string)
- else:
- return None
- class _TimetableNotRegistered(ValueError):
- def __init__(self, type_string: str) -> None:
- self.type_string = type_string
- def __str__(self) -> str:
- return (
- f"Timetable class {self.type_string!r} is not registered or "
- "you have a top level database access that disrupted the session. "
- "Please check the airflow best practices documentation."
- )
- class _PriorityWeightStrategyNotRegistered(AirflowException):
- def __init__(self, type_string: str) -> None:
- self.type_string = type_string
- def __str__(self) -> str:
- return (
- f"Priority weight strategy class {self.type_string!r} is not registered or "
- "you have a top level database access that disrupted the session. "
- "Please check the airflow best practices documentation."
- )
- def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]:
- """
- Encode a dataset condition.
- :meta private:
- """
- if isinstance(var, Dataset):
- return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra}
- if isinstance(var, DatasetAlias):
- return {"__type": DAT.DATASET_ALIAS, "name": var.name}
- if isinstance(var, DatasetAll):
- return {"__type": DAT.DATASET_ALL, "objects": [encode_dataset_condition(x) for x in var.objects]}
- if isinstance(var, DatasetAny):
- return {"__type": DAT.DATASET_ANY, "objects": [encode_dataset_condition(x) for x in var.objects]}
- raise ValueError(f"serialization not implemented for {type(var).__name__!r}")
- def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset:
- """
- Decode a previously serialized dataset condition.
- :meta private:
- """
- dat = var["__type"]
- if dat == DAT.DATASET:
- return Dataset(var["uri"], extra=var["extra"])
- if dat == DAT.DATASET_ALL:
- return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
- if dat == DAT.DATASET_ANY:
- return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
- if dat == DAT.DATASET_ALIAS:
- return DatasetAlias(name=var["name"])
- raise ValueError(f"deserialization not implemented for DAT {dat!r}")
- def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
- raw_key = var.raw_key
- return {
- "extra": var.extra,
- "dataset_alias_events": var.dataset_alias_events,
- "raw_key": BaseSerialization.serialize(raw_key),
- }
- def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
- # This is added for compatibility. The attribute used to be dataset_alias_event and
- # is now dataset_alias_events.
- if dataset_alias_event := var.get("dataset_alias_event", None):
- dataset_alias_events = [dataset_alias_event]
- else:
- dataset_alias_events = var.get("dataset_alias_events", [])
- outlet_event_accessor = OutletEventAccessor(
- extra=var["extra"],
- raw_key=BaseSerialization.deserialize(var["raw_key"]),
- dataset_alias_events=dataset_alias_events,
- )
- return outlet_event_accessor
- def encode_timetable(var: Timetable) -> dict[str, Any]:
- """
- Encode a timetable instance.
- This delegates most of the serialization work to the type, so the behavior
- can be completely controlled by a custom subclass.
- :meta private:
- """
- timetable_class = type(var)
- importable_string = qualname(timetable_class)
- if _get_registered_timetable(importable_string) is None:
- raise _TimetableNotRegistered(importable_string)
- return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()}
- def decode_timetable(var: dict[str, Any]) -> Timetable:
- """
- Decode a previously serialized timetable.
- Most of the deserialization logic is delegated to the actual type, which
- we import from string.
- :meta private:
- """
- importable_string = var[Encoding.TYPE]
- timetable_class = _get_registered_timetable(importable_string)
- if timetable_class is None:
- raise _TimetableNotRegistered(importable_string)
- return timetable_class.deserialize(var[Encoding.VAR])
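- # Illustrative wire format (assuming Encoding.TYPE == "__type" and
- # Encoding.VAR == "__var"): a cron-based timetable might serialize roughly as
- #     {"__type": "airflow.timetables.interval.CronDataIntervalTimetable",
- #      "__var": {"expression": "0 0 * * *", "timezone": "UTC"}}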
- def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
- """
- Encode a priority weight strategy instance.
- In this version, we only store the importable string, so the class should not expect
- any parameters to be passed to it. If you need to store parameters, you
- should store them in the class itself.
- """
- priority_weight_strategy_class = type(var)
- if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
- return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
- importable_string = qualname(priority_weight_strategy_class)
- if _get_registered_priority_weight_strategy(importable_string) is None:
- raise _PriorityWeightStrategyNotRegistered(importable_string)
- return importable_string
- def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
- """
- Decode a previously serialized priority weight strategy.
- In this version, we only store the importable string, so we just need to get the class
- from the dictionary of registered classes and instantiate it with no parameters.
- """
- priority_weight_strategy_class = _get_registered_priority_weight_strategy(var)
- if priority_weight_strategy_class is None:
- raise _PriorityWeightStrategyNotRegistered(var)
- return priority_weight_strategy_class()
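- # Illustrative behaviour (comments only): built-in strategies are stored under
- # their registered short name (e.g. "downstream"), while custom strategies are
- # stored as their importable dotted path and must be registered via plugins.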
- def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
- """
- Encode a StartTriggerArgs.
- :meta private:
- """
- serialize_kwargs = (
- lambda key: BaseSerialization.serialize(getattr(var, key)) if getattr(var, key) is not None else None
- )
- return {
- "__type": "START_TRIGGER_ARGS",
- "trigger_cls": var.trigger_cls,
- "trigger_kwargs": serialize_kwargs("trigger_kwargs"),
- "next_method": var.next_method,
- "next_kwargs": serialize_kwargs("next_kwargs"),
- "timeout": var.timeout.total_seconds() if var.timeout else None,
- }
- def decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
- """
- Decode a StartTriggerArgs.
- :meta private:
- """
- deserialize_kwargs = lambda key: BaseSerialization.deserialize(var[key]) if var[key] is not None else None
- return StartTriggerArgs(
- trigger_cls=var["trigger_cls"],
- trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
- next_method=var["next_method"],
- next_kwargs=deserialize_kwargs("next_kwargs"),
- timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
- )
- class _XComRef(NamedTuple):
- """
- Store info needed to create XComArg.
- We can't turn it into an XComArg until we've loaded _all_ the tasks, so when
- deserializing an operator, we need to create something in its place, and
- post-process it in ``deserialize_dag``.
- """
- data: dict
- def deref(self, dag: DAG) -> XComArg:
- return deserialize_xcom_arg(self.data, dag)
- # These two should be kept in sync. Note that these are intentionally not using
- # the type declarations in expandinput.py so we always remember to update
- # serialization logic when adding new ExpandInput variants. If you add things to
- # the unions, be sure to update _ExpandInputRef to match.
- _ExpandInputOriginalValue = Union[
- # For .expand(**kwargs).
- Mapping[str, Any],
- # For expand_kwargs(arg).
- XComArg,
- Collection[Union[XComArg, Mapping[str, Any]]],
- ]
- _ExpandInputSerializedValue = Union[
- # For .expand(**kwargs).
- Mapping[str, Any],
- # For expand_kwargs(arg).
- _XComRef,
- Collection[Union[_XComRef, Mapping[str, Any]]],
- ]
- class _ExpandInputRef(NamedTuple):
- """
- Store info needed to create a mapped operator's expand input.
- This references an ``ExpandInput`` type, but replaces ``XComArg`` objects
- with ``_XComRef`` (see documentation on the latter type for reasoning).
- """
- key: str
- value: _ExpandInputSerializedValue
- @classmethod
- def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
- """
- Validate we've covered all ``ExpandInput.value`` types.
- This function does not actually do anything, but is called during
- serialization so Mypy will *statically* check we have handled all
- possible ExpandInput cases.
- """
- def deref(self, dag: DAG) -> ExpandInput:
- """
- De-reference into a concrete ExpandInput object.
- If you add more cases here, be sure to update _ExpandInputOriginalValue
- and _ExpandInputSerializedValue to match the logic.
- """
- if isinstance(self.value, _XComRef):
- value: Any = self.value.deref(dag)
- elif isinstance(self.value, collections.abc.Mapping):
- value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
- else:
- value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value]
- return create_expand_input(self.key, value)
- _orm_to_model = {
- Job: JobPydantic,
- TaskInstance: TaskInstancePydantic,
- DagRun: DagRunPydantic,
- DagModel: DagModelPydantic,
- LogTemplate: LogTemplatePydantic,
- Dataset: DatasetPydantic,
- Trigger: TriggerPydantic,
- }
- _type_to_class: dict[DAT | str, list] = {
- DAT.BASE_JOB: [JobPydantic, Job],
- DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance],
- DAT.DAG_RUN: [DagRunPydantic, DagRun],
- DAT.DAG_MODEL: [DagModelPydantic, DagModel],
- DAT.LOG_TEMPLATE: [LogTemplatePydantic, LogTemplate],
- DAT.DATA_SET: [DatasetPydantic, Dataset],
- DAT.TRIGGER: [TriggerPydantic, Trigger],
- }
- _class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes}
- def add_pydantic_class_type_mapping(attribute_type: str, orm_class, pydantic_class):
- _orm_to_model[orm_class] = pydantic_class
- _type_to_class[attribute_type] = [pydantic_class, orm_class]
- _class_to_type[pydantic_class] = attribute_type
- _class_to_type[orm_class] = attribute_type
- class BaseSerialization:
- """BaseSerialization provides utils for serialization."""
- # JSON primitive types.
- _primitive_types = (int, bool, float, str)
- # Time types.
- # datetime.date and datetime.time are converted to strings.
- _datetime_types = (datetime.datetime,)
- # Object types that are always excluded in serialization.
- _excluded_types = (logging.Logger, Connection, type, property)
- _json_schema: Validator | None = None
- # Should the extra operator link be loaded via plugins when
- # de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links
- # are not loaded, to avoid running user code in the Scheduler.
- _load_operator_extra_links = True
- _CONSTRUCTOR_PARAMS: dict[str, Parameter] = {}
- SERIALIZER_VERSION = 1
- @classmethod
- def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str:
- """Stringify DAGs and operators contained by var and returns a JSON string of var."""
- return json.dumps(cls.to_dict(var), ensure_ascii=True)
- @classmethod
- def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict:
- """Stringify DAGs and operators contained by var and returns a dict of var."""
- # Don't call on this class directly - only SerializedDAG or
- # SerializedBaseOperator should be used as the "entrypoint"
- raise NotImplementedError()
- @classmethod
- def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple:
- """Deserialize json_str and reconstructs all DAGs and operators it contains."""
- return cls.from_dict(json.loads(serialized_obj))
- @classmethod
- def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple:
- """Deserialize a dict of type decorators and reconstructs all DAGs and operators it contains."""
- return cls.deserialize(serialized_obj)
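- # Minimal usage sketch (illustrative; assumes a `dag` object is in scope):
- #     data = SerializedDAG.to_dict(dag)
- #     SerializedDAG.validate_schema(data)
- #     dag2 = SerializedDAG.from_dict(data)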
- @classmethod
- def validate_schema(cls, serialized_obj: str | dict) -> None:
- """Validate serialized_obj satisfies JSON schema."""
- if cls._json_schema is None:
- raise AirflowException(f"JSON schema of {cls.__name__:s} is not set.")
- if isinstance(serialized_obj, dict):
- cls._json_schema.validate(serialized_obj)
- elif isinstance(serialized_obj, str):
- cls._json_schema.validate(json.loads(serialized_obj))
- else:
- raise TypeError("Invalid type: Only dict and str are supported.")
- @staticmethod
- def _encode(x: Any, type_: Any) -> dict[Encoding, Any]:
- """Encode data by a JSON dict."""
- return {Encoding.VAR: x, Encoding.TYPE: type_}
- @classmethod
- def _is_primitive(cls, var: Any) -> bool:
- """Primitive types."""
- return var is None or isinstance(var, cls._primitive_types)
- @classmethod
- def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
- """Check if type is excluded from serialization."""
- if var is None:
- if not cls._is_constructor_param(attrname, instance):
- # For any instance attribute that is not a constructor argument, exclude None as the default.
- return True
- return cls._value_is_hardcoded_default(attrname, var, instance)
- return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default(
- attrname, var, instance
- )
- @classmethod
- def serialize_to_json(
- cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
- ) -> dict[str, Any]:
- """Serialize an object to JSON."""
- serialized_object: dict[str, Any] = {}
- keys_to_serialize = object_to_serialize.get_serialized_fields()
- for key in keys_to_serialize:
- # None is ignored in serialized form and is added back in deserialization.
- value = getattr(object_to_serialize, key, None)
- if cls._is_excluded(value, key, object_to_serialize):
- continue
- if key == "_operator_name":
- # when operator_name matches task_type, we can remove
- # it to reduce the JSON payload
- task_type = getattr(object_to_serialize, "_task_type", None)
- if value != task_type:
- serialized_object[key] = cls.serialize(value)
- elif key in decorated_fields:
- serialized_object[key] = cls.serialize(value)
- elif key == "timetable" and value is not None:
- serialized_object[key] = encode_timetable(value)
- elif key == "weight_rule" and value is not None:
- serialized_object[key] = encode_priority_weight_strategy(value)
- else:
- value = cls.serialize(value)
- if isinstance(value, dict) and Encoding.TYPE in value:
- value = value[Encoding.VAR]
- serialized_object[key] = value
- return serialized_object
- @classmethod
- def serialize(
- cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False
- ) -> Any: # Unfortunately there is no support for recursive types in mypy
- """
- Serialize an object; helper function for the depth-first search used in serialization.
- The serialization protocol is:
- (1) keeping JSON supported types: primitives, dict, list;
- (2) encoding other types as ``{TYPE: 'foo', VAR: 'bar'}``, the deserialization
- step decodes VAR according to TYPE;
- (3) Operator has a special field CLASS to record the original class
- name for displaying in UI.
- :meta private:
- """
- if use_pydantic_models and not _ENABLE_AIP_44:
- raise RuntimeError(
- "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
- "This parameter will be removed eventually when new serialization is used by AIP-44"
- )
- if cls._is_primitive(var):
- # enum.IntEnum is an int instance; it causes a json.dumps error, so we use its value.
- if isinstance(var, enum.Enum):
- return var.value
- return var
- elif isinstance(var, dict):
- return cls._encode(
- {
- str(k): cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
- for k, v in var.items()
- },
- type_=DAT.DICT,
- )
- elif isinstance(var, list):
- return [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var]
- elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
- json_pod = PodGenerator.serialize_pod(var)
- return cls._encode(json_pod, type_=DAT.POD)
- elif isinstance(var, OutletEventAccessors):
- return cls._encode(
- cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined]
- type_=DAT.DATASET_EVENT_ACCESSORS,
- )
- elif isinstance(var, OutletEventAccessor):
- return cls._encode(
- encode_outlet_event_accessor(var),
- type_=DAT.DATASET_EVENT_ACCESSOR,
- )
- elif isinstance(var, DAG):
- return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
- elif isinstance(var, Resources):
- return var.to_dict()
- elif isinstance(var, MappedOperator):
- return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP)
- elif isinstance(var, BaseOperator):
- var._needs_expansion = var.get_needs_expansion()
- return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP)
- elif isinstance(var, cls._datetime_types):
- return cls._encode(var.timestamp(), type_=DAT.DATETIME)
- elif isinstance(var, datetime.timedelta):
- return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA)
- elif isinstance(var, (Timezone, FixedTimezone)):
- return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE)
- elif isinstance(var, relativedelta.relativedelta):
- return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA)
- elif isinstance(var, TaskInstanceKey):
- return cls._encode(
- var._asdict(),
- type_=DAT.TASK_INSTANCE_KEY,
- )
- elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"):
- exc_cls_name, args, kwargs = var.serialize()
- return cls._encode(
- cls.serialize(
- {"exc_cls_name": exc_cls_name, "args": args, "kwargs": kwargs},
- use_pydantic_models=use_pydantic_models,
- strict=strict,
- ),
- type_=DAT.AIRFLOW_EXC_SER,
- )
- elif isinstance(var, (KeyError, AttributeError)):
- return cls._encode(
- cls.serialize(
- {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
- use_pydantic_models=use_pydantic_models,
- strict=strict,
- ),
- type_=DAT.BASE_EXC_SER,
- )
- elif isinstance(var, BaseTrigger):
- return cls._encode(
- cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
- type_=DAT.BASE_TRIGGER,
- )
- elif callable(var):
- return str(get_python_source(var))
- elif isinstance(var, set):
- # FIXME: cast set to list in customized serialization in the future.
- try:
- return cls._encode(
- sorted(
- cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var
- ),
- type_=DAT.SET,
- )
- except TypeError:
- return cls._encode(
- [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
- type_=DAT.SET,
- )
- elif isinstance(var, tuple):
- # FIXME: cast tuple to list in customized serialization in the future.
- return cls._encode(
- [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
- type_=DAT.TUPLE,
- )
- elif isinstance(var, TaskGroup):
- return TaskGroupSerialization.serialize_task_group(var)
- elif isinstance(var, Param):
- return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
- elif isinstance(var, XComArg):
- return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
- elif isinstance(var, LazySelectSequence):
- return cls.serialize(list(var))
- elif isinstance(var, BaseDataset):
- serialized_dataset = encode_dataset_condition(var)
- return cls._encode(serialized_dataset, type_=serialized_dataset.pop("__type"))
- elif isinstance(var, SimpleTaskInstance):
- return cls._encode(
- cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
- type_=DAT.SIMPLE_TASK_INSTANCE,
- )
- elif isinstance(var, Connection):
- return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
- elif isinstance(var, TaskCallbackRequest):
- return cls._encode(var.to_json(), type_=DAT.TASK_CALLBACK_REQUEST)
- elif isinstance(var, DagCallbackRequest):
- return cls._encode(var.to_json(), type_=DAT.DAG_CALLBACK_REQUEST)
- elif isinstance(var, SlaCallbackRequest):
- return cls._encode(var.to_json(), type_=DAT.SLA_CALLBACK_REQUEST)
- elif var.__class__ == Context:
- d = {}
- for k, v in var._context.items():
- obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
- d[str(k)] = obj
- return cls._encode(d, type_=DAT.TASK_CONTEXT)
- elif use_pydantic_models and _ENABLE_AIP_44:
- def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
- return model_cls.model_validate(var).model_dump(mode="json") # type: ignore[attr-defined]
- if var.__class__ in _class_to_type:
- pyd_mod = _orm_to_model.get(var.__class__, var)
- mod = _pydantic_model_dump(pyd_mod, var)
- type_ = _class_to_type[var.__class__]
- return cls._encode(mod, type_=type_)
- else:
- return cls.default_serialization(strict, var)
- elif isinstance(var, ArgNotSet):
- return cls._encode(None, type_=DAT.ARG_NOT_SET)
- else:
- return cls.default_serialization(strict, var)
- @classmethod
- def default_serialization(cls, strict, var) -> str:
- log.debug("Cast type %s to str in serialization.", type(var))
- if strict:
- raise SerializationError("Encountered unexpected type")
- return str(var)
- @classmethod
- def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
- """
- Deserialize an object; helper function for the depth-first search used in deserialization.
- :meta private:
- """
- # JSON primitives (except for dict) are not encoded.
- if use_pydantic_models and not _ENABLE_AIP_44:
- raise RuntimeError(
- "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
- "This parameter will be removed eventually when new serialization is used by AIP-44"
- )
- if cls._is_primitive(encoded_var):
- return encoded_var
- elif isinstance(encoded_var, list):
- return [cls.deserialize(v, use_pydantic_models) for v in encoded_var]
- if not isinstance(encoded_var, dict):
- raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
- var = encoded_var[Encoding.VAR]
- type_ = encoded_var[Encoding.TYPE]
- if type_ == DAT.TASK_CONTEXT:
- d = {}
- for k, v in var.items():
- if k == "task": # todo: add `_encode` of Operator so we don't need this
- continue
- d[k] = cls.deserialize(v, use_pydantic_models=True)
- d["task"] = d["task_instance"].task # todo: add `_encode` of Operator so we don't need this
- d["macros"] = macros
- d["var"] = {
- "json": VariableAccessor(deserialize_json=True),
- "value": VariableAccessor(deserialize_json=False),
- }
- d["conn"] = ConnectionAccessor()
- return Context(**d)
- elif type_ == DAT.DICT:
- return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
- elif type_ == DAT.DATASET_EVENT_ACCESSORS:
- d = OutletEventAccessors() # type: ignore[assignment]
- d._dict = cls.deserialize(var) # type: ignore[attr-defined]
- return d
- elif type_ == DAT.DATASET_EVENT_ACCESSOR:
- return decode_outlet_event_accessor(var)
- elif type_ == DAT.DAG:
- return SerializedDAG.deserialize_dag(var)
- elif type_ == DAT.OP:
- return SerializedBaseOperator.deserialize_operator(var)
- elif type_ == DAT.DATETIME:
- return from_timestamp(var)
- elif type_ == DAT.POD:
- if not _has_kubernetes():
- raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
- pod = PodGenerator.deserialize_model_dict(var)
- return pod
- elif type_ == DAT.TIMEDELTA:
- return datetime.timedelta(seconds=var)
- elif type_ == DAT.TIMEZONE:
- return decode_timezone(var)
- elif type_ == DAT.RELATIVEDELTA:
- return decode_relativedelta(var)
- elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
- deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
- exc_cls_name = deser["exc_cls_name"]
- args = deser["args"]
- kwargs = deser["kwargs"]
- del deser
- if type_ == DAT.AIRFLOW_EXC_SER:
- exc_cls = import_string(exc_cls_name)
- else:
- exc_cls = import_string(f"builtins.{exc_cls_name}")
- return exc_cls(*args, **kwargs)
- elif type_ == DAT.BASE_TRIGGER:
- tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
- tr_cls = import_string(tr_cls_name)
- return tr_cls(**kwargs)
- elif type_ == DAT.SET:
- return {cls.deserialize(v, use_pydantic_models) for v in var}
- elif type_ == DAT.TUPLE:
- return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
- elif type_ == DAT.PARAM:
- return cls._deserialize_param(var)
- elif type_ == DAT.XCOM_REF:
- return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
- elif type_ == DAT.DATASET:
- return Dataset(**var)
- elif type_ == DAT.DATASET_ALIAS:
- return DatasetAlias(**var)
- elif type_ == DAT.DATASET_ANY:
- return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
- elif type_ == DAT.DATASET_ALL:
- return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
- elif type_ == DAT.SIMPLE_TASK_INSTANCE:
- return SimpleTaskInstance(**cls.deserialize(var))
- elif type_ == DAT.CONNECTION:
- return Connection(**var)
- elif type_ == DAT.TASK_CALLBACK_REQUEST:
- return TaskCallbackRequest.from_json(var)
- elif type_ == DAT.DAG_CALLBACK_REQUEST:
- return DagCallbackRequest.from_json(var)
- elif type_ == DAT.SLA_CALLBACK_REQUEST:
- return SlaCallbackRequest.from_json(var)
- elif type_ == DAT.TASK_INSTANCE_KEY:
- return TaskInstanceKey(**var)
- elif use_pydantic_models and _ENABLE_AIP_44:
- return _type_to_class[type_][0].model_validate(var)
- elif type_ == DAT.ARG_NOT_SET:
- return NOTSET
- else:
- raise TypeError(f"Invalid type {type_!s} in deserialization.")
- _deserialize_datetime = from_timestamp
- _deserialize_timezone = parse_timezone
- @classmethod
- def _deserialize_timedelta(cls, seconds: int) -> datetime.timedelta:
- return datetime.timedelta(seconds=seconds)
- @classmethod
- def _is_constructor_param(cls, attrname: str, instance: Any) -> bool:
- return attrname in cls._CONSTRUCTOR_PARAMS
- @classmethod
- def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -> bool:
- """
- Return True if ``value`` is the hard-coded default for the given attribute.
- This takes into account cases where the ``max_active_tasks`` parameter is
- stored in the ``_max_active_tasks`` attribute.
- By using ``is`` here, and not ``==``, this copes with the case where a
- user explicitly specifies an attribute with the same "value" as the
- default. (This is because ``"default" is "default"`` can be False when
- they are different string objects with the same characters.)
- Also returns True if the value is an empty list or empty dict. This accounts
- for the case where the default value of the field is None but the code sets
- ``field = field or {}``.
- """
- if attrname in cls._CONSTRUCTOR_PARAMS and (
- cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
- ):
- return True
- return False
- @classmethod
- def _serialize_param(cls, param: Param):
- return {
- "__class": f"{param.__module__}.{param.__class__.__name__}",
- "default": cls.serialize(param.value),
- "description": cls.serialize(param.description),
- "schema": cls.serialize(param.schema),
- }
- @classmethod
- def _deserialize_param(cls, param_dict: dict):
- """
- Workaround to deserialize Param objects serialized by older versions.
- In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
- this class's ``serialize`` method. So before running through ``deserialize``,
- we first verify that it's necessary to do so.
- """
- class_name = param_dict["__class"]
- class_: type[Param] = import_string(class_name)
- attrs = ("default", "description", "schema")
- kwargs = {}
- def is_serialized(val):
- if isinstance(val, dict):
- return Encoding.TYPE in val
- if isinstance(val, list):
- return all(isinstance(item, dict) and Encoding.TYPE in item for item in val)
- return False
- for attr in attrs:
- if attr in param_dict:
- val = param_dict[attr]
- if is_serialized(val):
- val = cls.deserialize(val)
- kwargs[attr] = val
- return class_(**kwargs)
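- # Illustrative shape (not executed): _serialize_param(Param(5)) produces roughly
- #     {"__class": "airflow.models.param.Param", "default": serialize(5),
- #      "description": serialize(None), "schema": serialize({})}
- # and _deserialize_param() only runs a value through deserialize() when it
- # carries the Encoding.TYPE marker.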
- @classmethod
- def _serialize_params_dict(cls, params: ParamsDict | dict) -> list[tuple[str, dict]]:
- """Serialize Params dict for a DAG or task as a list of tuples to ensure ordering."""
- serialized_params = []
- for k, v in params.items():
- if isinstance(params, ParamsDict):
- # Use native param object, not resolved value if possible
- v = params.get_param(k)
- try:
- class_identity = f"{v.__module__}.{v.__class__.__name__}"
- except AttributeError:
- class_identity = ""
- if class_identity == "airflow.models.param.Param":
- serialized_params.append((k, cls._serialize_param(v)))
- else:
- # Auto-box other values into a Param object, like DAG parsing does as well
- serialized_params.append((k, cls._serialize_param(Param(v))))
- return serialized_params
- @classmethod
- def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> ParamsDict:
- """Deserialize a DAG's Params dict."""
- if isinstance(encoded_params, collections.abc.Mapping):
- # in 2.9.2 or earlier params were serialized as JSON objects
- encoded_param_pairs: Iterable[tuple[str, dict]] = encoded_params.items()
- else:
- encoded_param_pairs = encoded_params
- op_params = {}
- for k, v in encoded_param_pairs:
- if isinstance(v, dict) and "__class" in v:
- op_params[k] = cls._deserialize_param(v)
- else:
- # Old style params, convert it
- op_params[k] = Param(v)
- return ParamsDict(op_params)
- class DependencyDetector:
- """
- Detects dependencies between DAGs.
- :meta private:
- """
- @staticmethod
- def detect_task_dependencies(task: Operator) -> list[DagDependency]:
- """Detect dependencies caused by tasks."""
- from airflow.operators.trigger_dagrun import TriggerDagRunOperator
- from airflow.sensors.external_task import ExternalTaskSensor
- deps = []
- if isinstance(task, TriggerDagRunOperator):
- deps.append(
- DagDependency(
- source=task.dag_id,
- target=getattr(task, "trigger_dag_id"),
- dependency_type="trigger",
- dependency_id=task.task_id,
- )
- )
- elif isinstance(task, ExternalTaskSensor):
- deps.append(
- DagDependency(
- source=getattr(task, "external_dag_id"),
- target=task.dag_id,
- dependency_type="sensor",
- dependency_id=task.task_id,
- )
- )
- for obj in task.outlets or []:
- if isinstance(obj, Dataset):
- deps.append(
- DagDependency(
- source=task.dag_id,
- target="dataset",
- dependency_type="dataset",
- dependency_id=obj.uri,
- )
- )
- elif isinstance(obj, DatasetAlias):
- cond = _DatasetAliasCondition(obj.name)
- deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
- return deps
- @staticmethod
- def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
- """Detect dependencies set directly on the DAG object."""
- if not dag:
- return
- yield from dag.timetable.dataset_condition.iter_dag_dependencies(source="", target=dag.dag_id)
- class SerializedBaseOperator(BaseOperator, BaseSerialization):
- """
- A JSON serializable representation of an operator.
- All operators are cast to SerializedBaseOperator after deserialization.
- Class-specific attributes used by the UI are moved to object attributes.
- Creating a SerializedBaseOperator is a three-step process:
- 1. Instantiate a :class:`SerializedBaseOperator` object.
- 2. Populate attributes with :func:`SerializedBaseOperator.populate_operator`.
- 3. When the task's containing DAG is available, fix references to the DAG
- with :func:`SerializedBaseOperator.set_task_dag_references`.
- """
- _decorated_fields = {"executor_config"}
- _CONSTRUCTOR_PARAMS = {
- k: v.default
- for k, v in signature(BaseOperator.__init__).parameters.items()
- if v.default is not v.empty
- }
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- # task_type is used by the UI to display the correct class type, because the UI only
- # receives BaseOperator from deserialized DAGs.
- self._task_type = "BaseOperator"
- # Move class attributes into object attributes.
- self.ui_color = BaseOperator.ui_color
- self.ui_fgcolor = BaseOperator.ui_fgcolor
- self.template_ext = BaseOperator.template_ext
- self.template_fields = BaseOperator.template_fields
- self.operator_extra_links = BaseOperator.operator_extra_links
- @property
- def task_type(self) -> str:
- # Overwrites task_type of BaseOperator to use _task_type instead of
- # __class__.__name__.
- return self._task_type
- @task_type.setter
- def task_type(self, task_type: str):
- self._task_type = task_type
- @property
- def operator_name(self) -> str:
- # Overwrites operator_name of BaseOperator to use _operator_name instead of
- # __class__.operator_name.
- return self._operator_name
- @operator_name.setter
- def operator_name(self, operator_name: str):
- self._operator_name = operator_name
- @classmethod
- def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
- serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
- # Handle expand_input and op_kwargs_expand_input.
- expansion_kwargs = op._get_specified_expand_input()
- if TYPE_CHECKING: # Let Mypy check the input type for us!
- _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
- serialized_op[op._expand_input_attr] = {
- "type": get_map_type_key(expansion_kwargs),
- "value": cls.serialize(expansion_kwargs.value),
- }
- # Simplify partial_kwargs by comparing it to the most barebone object.
- # Remove all entries that are simply default values.
- serialized_partial = serialized_op["partial_kwargs"]
- for k, default in _get_default_mapped_partial().items():
- try:
- v = serialized_partial[k]
- except KeyError:
- continue
- if v == default:
- del serialized_partial[k]
- serialized_op["_is_mapped"] = True
- return serialized_op
- @classmethod
- def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]:
- return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)
- @classmethod
- def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
- """Serialize operator into a JSON object."""
- serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
- serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
- serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
- if op.operator_name != serialize_op["_task_type"]:
- serialize_op["_operator_name"] = op.operator_name
- # Used to determine if an Operator is inherited from EmptyOperator
- serialize_op["_is_empty"] = op.inherits_from_empty_operator
- serialize_op["start_trigger_args"] = (
- encode_start_trigger_args(op.start_trigger_args) if op.start_trigger_args else None
- )
- serialize_op["start_from_trigger"] = op.start_from_trigger
- if op.operator_extra_links:
- serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
- op.operator_extra_links.__get__(op)
- if isinstance(op.operator_extra_links, property)
- else op.operator_extra_links
- )
- if include_deps:
- serialize_op["deps"] = cls._serialize_deps(op.deps)
- # Store all template_fields as they are if they are JSON-serializable;
- # if not, store them as strings,
- # and raise an exception if the field is not templateable.
- forbidden_fields = set(inspect.signature(BaseOperator.__init__).parameters.keys())
- # Though allow some of the BaseOperator fields to be templated anyway
- forbidden_fields.difference_update({"email"})
- if op.template_fields:
- for template_field in op.template_fields:
- if template_field in forbidden_fields:
- raise AirflowException(
- dedent(
- f"""Cannot template BaseOperator field:
- {template_field!r} {op.__class__.__name__=} {op.template_fields=}"""
- )
- )
- value = getattr(op, template_field, None)
- if not cls._is_excluded(value, template_field, op):
- serialize_op[template_field] = serialize_template_field(value, template_field)
- if op.params:
- serialize_op["params"] = cls._serialize_params_dict(op.params)
- return serialize_op
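- # Illustrative (abridged) result of _serialize_node for a simple task; the exact
- # keys depend on the operator and on which defaults are excluded:
- #     {"task_id": "hello", "_task_type": "BashOperator",
- #      "_task_module": "airflow.operators.bash", "_is_empty": False,
- #      "start_from_trigger": False, "start_trigger_args": None, ...}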
- @classmethod
- def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]:
- from airflow import plugins_manager
- plugins_manager.initialize_ti_deps_plugins()
- if plugins_manager.registered_ti_dep_classes is None:
- raise AirflowException("Can not load plugins")
- deps = []
- for dep in op_deps:
- klass = type(dep)
- module_name = klass.__module__
- qualname = f"{module_name}.{klass.__name__}"
- if (
- not qualname.startswith("airflow.ti_deps.deps.")
- and qualname not in plugins_manager.registered_ti_dep_classes
- ):
- raise SerializationError(
- f"Custom dep class {qualname} not serialized, please register it through plugins."
- )
- deps.append(qualname)
- # deps needs to be sorted here because op_deps is a set, whose iteration order is unstable,
- # so the same call may return different results.
- # Without sorting, json.dumps(self.data, sort_keys=True) used to generate dag_hash could
- # report spurious differences between identical DAGs.
- return sorted(deps)
- @classmethod
- def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
- """
- Populate operator attributes with serialized values.
- This covers simple attributes that don't reference other things in the
- DAG. Setting references (such as ``op.dag`` and task dependencies) is
- done in ``set_task_dag_references`` instead, which is called after the
- DAG is hydrated.
- """
- if "label" not in encoded_op:
- # Handle deserialization of old data before the introduction of TaskGroup
- encoded_op["label"] = encoded_op["task_id"]
- # Extra Operator Links defined in Plugins
- op_extra_links_from_plugin = {}
- if "_operator_name" not in encoded_op:
- encoded_op["_operator_name"] = encoded_op["_task_type"]
- # We don't want to load Extra Operator links in Scheduler
- if cls._load_operator_extra_links:
- from airflow import plugins_manager
- plugins_manager.initialize_extra_operators_links_plugins()
- if plugins_manager.operator_extra_links is None:
- raise AirflowException("Can not load plugins")
- for ope in plugins_manager.operator_extra_links:
- for operator in ope.operators:
- if (
- operator.__name__ == encoded_op["_task_type"]
- and operator.__module__ == encoded_op["_task_module"]
- ):
- op_extra_links_from_plugin.update({ope.name: ope})
- # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
- # set the Operator links attribute
- # The case for "If OperatorLinks are defined in the operator that is being Serialized"
- # is handled in the deserialization loop where it matches k == "_operator_extra_links"
- if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
- setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
- for k, v in encoded_op.items():
- # TODO: Remove in Airflow 3.0 when the dummy operator is removed
- if k == "_is_dummy":
- k = "_is_empty"
- if k in ("_outlets", "_inlets"):
- # `_outlets` -> `outlets`
- k = k[1:]
- if k == "_downstream_task_ids":
- # Upgrade from old format/name
- k = "downstream_task_ids"
- if k == "label":
- # Label shouldn't be set anymore -- it's computed from task_id now
- continue
- elif k == "downstream_task_ids":
- v = set(v)
- elif k == "subdag":
- v = SerializedDAG.deserialize_dag(v)
- elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
- v = cls._deserialize_timedelta(v)
- elif k in encoded_op["template_fields"]:
- pass
- elif k == "resources":
- v = Resources.from_dict(v)
- elif k.endswith("_date"):
- v = cls._deserialize_datetime(v)
- elif k == "_operator_extra_links":
- if cls._load_operator_extra_links:
- op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
- # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
- op_predefined_extra_links.update(op_extra_links_from_plugin)
- else:
- op_predefined_extra_links = {}
- v = list(op_predefined_extra_links.values())
- k = "operator_extra_links"
- elif k == "deps":
- v = cls._deserialize_deps(v)
- elif k == "params":
- v = cls._deserialize_params_dict(v)
- if op.params: # Merge existing params if needed.
- v, new = op.params, v
- v.update(new)
- elif k == "partial_kwargs":
- v = {arg: cls.deserialize(value) for arg, value in v.items()}
- elif k in {"expand_input", "op_kwargs_expand_input"}:
- v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
- elif k == "operator_class":
- v = {k_: cls.deserialize(v_, use_pydantic_models=True) for k_, v_ in v.items()}
- elif (
- k in cls._decorated_fields
- or k not in op.get_serialized_fields()
- or k in ("outlets", "inlets")
- ):
- v = cls.deserialize(v)
- elif k == "on_failure_fail_dagrun":
- k = "_on_failure_fail_dagrun"
- elif k == "weight_rule":
- v = decode_priority_weight_strategy(v)
- # else use v as it is
- setattr(op, k, v)
- for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
- # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then this check
- # could go away.
- if not hasattr(op, k):
- setattr(op, k, None)
- # Set all template_fields that were not present in the serialized JSON to None
- for field in op.template_fields:
- if not hasattr(op, field):
- setattr(op, field, None)
- # Used to determine if an Operator is inherited from EmptyOperator
- setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
- start_trigger_args = None
- encoded_start_trigger_args = encoded_op.get("start_trigger_args", None)
- if encoded_start_trigger_args:
- encoded_start_trigger_args = cast(dict, encoded_start_trigger_args)
- start_trigger_args = decode_start_trigger_args(encoded_start_trigger_args)
- setattr(op, "start_trigger_args", start_trigger_args)
- setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False)))
- @staticmethod
- def set_task_dag_references(task: Operator, dag: DAG) -> None:
- """
- Handle DAG references on an operator.
- The operator should have been mostly populated earlier by calling
- ``populate_operator``. This function further fixes object references
- that could not be set before the task's containing DAG is hydrated.
- """
- task.dag = dag
- for date_attr in ("start_date", "end_date"):
- if getattr(task, date_attr, None) is None:
- setattr(task, date_attr, getattr(dag, date_attr, None))
- if task.subdag is not None:
- task.subdag.parent_dag = dag
- # Dereference expand_input and op_kwargs_expand_input.
- for k in ("expand_input", "op_kwargs_expand_input"):
- if isinstance(kwargs_ref := getattr(task, k, None), _ExpandInputRef):
- setattr(task, k, kwargs_ref.deref(dag))
- for task_id in task.downstream_task_ids:
- # Bypass set_upstream etc here - it does more than we want
- dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
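
    # Usage note (comment only): within this module, ``set_task_dag_references`` is invoked by
    # ``SerializedDAG.deserialize_dag`` once for every task after the DAG object exists, e.g.:
    #
    #     for task in dag.task_dict.values():
    #         SerializedBaseOperator.set_task_dag_references(task, dag)
    #
    # Code deserializing a single operator outside a DAG would need to do the same itself.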

    @classmethod
    def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
        """Deserializes an operator from a JSON object."""
        op: Operator
        if encoded_op.get("_is_mapped", False):
            # Most of these will be loaded later; these are just some stand-ins.
            op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
            try:
                operator_name = encoded_op["_operator_name"]
            except KeyError:
                operator_name = encoded_op["_task_type"]

            op = MappedOperator(
                operator_class=op_data,
                expand_input=EXPAND_INPUT_EMPTY,
                partial_kwargs={},
                task_id=encoded_op["task_id"],
                params={},
                deps=MappedOperator.deps_for(BaseOperator),
                operator_extra_links=BaseOperator.operator_extra_links,
                template_ext=BaseOperator.template_ext,
                template_fields=BaseOperator.template_fields,
                template_fields_renderers=BaseOperator.template_fields_renderers,
                ui_color=BaseOperator.ui_color,
                ui_fgcolor=BaseOperator.ui_fgcolor,
                is_empty=False,
                task_module=encoded_op["_task_module"],
                task_type=encoded_op["_task_type"],
                operator_name=operator_name,
                dag=None,
                task_group=None,
                start_date=None,
                end_date=None,
                disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
                expand_input_attr=encoded_op["_expand_input_attr"],
                start_trigger_args=encoded_op.get("start_trigger_args", None),
                start_from_trigger=encoded_op.get("start_from_trigger", False),
            )
        else:
            op = SerializedBaseOperator(task_id=encoded_op["task_id"])
            op.dag = AttributeRemoved("dag")  # type: ignore[assignment]
        cls.populate_operator(op, encoded_op)
        return op
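
    # Illustrative note (comment only): the ``_is_mapped`` flag decides which concrete class is
    # rebuilt. A hypothetical, abbreviated payload such as
    #
    #     {"task_id": "t1", "_task_type": "BashOperator", "_task_module": "airflow.operators.bash",
    #      "_is_mapped": True, "_disallow_kwargs_override": False,
    #      "_expand_input_attr": "expand_input", ...}
    #
    # yields a ``MappedOperator`` stand-in that ``populate_operator`` then fills in; without the
    # flag, a plain ``SerializedBaseOperator`` is created instead.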

    @classmethod
    def detect_dependencies(cls, op: Operator) -> set[DagDependency]:
        """Detect inter-DAG dependencies for the operator."""

        def get_custom_dep() -> list[DagDependency]:
            """
            If a custom dependency detector is configured, use it.

            TODO: Remove this logic in 3.0.
            """
            custom_dependency_detector_cls = conf.getimport("scheduler", "dependency_detector", fallback=None)
            if not (
                custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
            ):
                warnings.warn(
                    "Use of a custom dependency detector is deprecated. "
                    "Support will be removed in a future release.",
                    RemovedInAirflow3Warning,
                    stacklevel=1,
                )

                dep = custom_dependency_detector_cls().detect_task_dependencies(op)
                if type(dep) is DagDependency:
                    return [dep]
            return []

        dependency_detector = DependencyDetector()
        deps = set(dependency_detector.detect_task_dependencies(op))
        deps.update(get_custom_dep())  # TODO: remove in 3.0
        return deps
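
    # Illustrative note (comment only, hypothetical operator): for a task that targets another DAG
    # (e.g. a trigger- or sensor-style operator), ``detect_dependencies`` returns the
    # ``DagDependency`` entries reported by ``DependencyDetector``, plus anything a deprecated
    # custom detector configured under ``[scheduler] dependency_detector`` contributes:
    #
    #     deps = SerializedBaseOperator.detect_dependencies(op)  # -> set[DagDependency]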

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
        if (
            var is not None
            and op.has_dag()
            and op.dag.__class__ is not AttributeRemoved
            and attrname.endswith("_date")
        ):
            # If this date is the same as the matching field in the dag, then
            # don't store it again at the task level.
            dag_date = getattr(op.dag, attrname, None)
            if var is dag_date or var == dag_date:
                return True
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]:
        from airflow import plugins_manager

        plugins_manager.initialize_ti_deps_plugins()
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Cannot load plugins")

        instances = set()
        for qn in set(deps):
            if (
                not qn.startswith("airflow.ti_deps.deps.")
                and qn not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qn} not deserialized, please register it through plugins."
                )

            try:
                instances.add(import_string(qn)())
            except ImportError:
                log.warning("Error importing dep %r", qn, exc_info=True)
        return instances
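
    # Illustrative sketch (comment only): built-in dep classes round-trip by qualified name, e.g.
    #
    #     SerializedBaseOperator._deserialize_deps(
    #         ["airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep"]
    #     )
    #
    # returns a set containing one ``TriggerRuleDep`` instance, while any path outside
    # ``airflow.ti_deps.deps.`` must be registered through a plugin or a ``SerializationError``
    # is raised.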

    @classmethod
    def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
        """
        Deserialize Operator Links if the classes are registered in Airflow Plugins.

        If an OperatorLink class is neither built in nor registered via a plugin, an error is
        logged and no links are returned.

        :param encoded_op_links: Serialized Operator Links
        :return: Deserialized Operator Links
        """
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.registered_operator_link_classes is None:
            raise AirflowException("Cannot load plugins")
        op_predefined_extra_links = {}

        for _operator_links_source in encoded_op_links:
            # Get the key, value pair as a tuple where the key is the OperatorLink class name
            # and the value is the dictionary containing the arguments passed to the OperatorLink.
            #
            # Example of a single iteration:
            #
            # _operator_links_source =
            #   {
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
            #           'index': 0
            #       }
            #   },
            #
            # list(_operator_links_source.items()) =
            #   [
            #       (
            #           'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #           {'index': 0}
            #       )
            #   ]
            #
            # list(_operator_links_source.items())[0] =
            #   (
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #       {
            #           'index': 0
            #       }
            #   )

            _operator_link_class_path, data = next(iter(_operator_links_source.items()))
            if _operator_link_class_path in get_operator_extra_links():
                single_op_link_class = import_string(_operator_link_class_path)
            elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
                single_op_link_class = plugins_manager.registered_operator_link_classes[
                    _operator_link_class_path
                ]
            else:
                log.error("Operator Link class %r not registered", _operator_link_class_path)
                return {}

            op_link_parameters = {param: cls.deserialize(value) for param, value in data.items()}
            op_predefined_extra_link: BaseOperatorLink = single_op_link_class(**op_link_parameters)

            op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})

        return op_predefined_extra_links

    @classmethod
    def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
        """
        Serialize Operator Links.

        Store the import path of the OperatorLink and the arguments passed to it.
        For example:
        ``[{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]``

        :param operator_extra_links: Operator Links
        :return: Serialized Operator Links
        """
        serialize_operator_extra_links = []
        for operator_extra_link in operator_extra_links:
            op_link_arguments = {
                param: cls.serialize(value) for param, value in attrs.asdict(operator_extra_link).items()
            }

            module_path = (
                f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
            )
            serialize_operator_extra_links.append({module_path: op_link_arguments})

        return serialize_operator_extra_links

    @classmethod
    def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False) -> Any:
        # The wonders of multiple inheritance: BaseOperator defines an instance method with the
        # same name, so delegate explicitly to BaseSerialization here.
        return BaseSerialization.serialize(var=var, strict=strict, use_pydantic_models=use_pydantic_models)

    @classmethod
    def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False) -> Any:
        return BaseSerialization.deserialize(encoded_var=encoded_var, use_pydantic_models=use_pydantic_models)
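

# Illustrative note (comment only): ``SerializedBaseOperator`` is rarely used on its own;
# ``SerializedDAG`` below serializes every task inside ``serialize_dag`` and rebuilds each one
# through ``deserialize_operator`` inside ``deserialize_dag``. A hypothetical, minimal round trip
# for a single task could look like:
#
#     encoded = SerializedBaseOperator.serialize(task)  # {Encoding.TYPE: DAT.OP, Encoding.VAR: {...}}
#     rebuilt = SerializedBaseOperator.deserialize_operator(encoded[Encoding.VAR])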


class SerializedDAG(DAG, BaseSerialization):
    """
    A JSON serializable representation of DAG.

    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
    that are not serializable, such as functions and user-defined classes, are cast to strings.

    Compared with SimpleDAG: SerializedDAG contains all information for the webserver.
    Compared with DagPickle: DagPickle contains all information for the worker, but some DAGs are
    not pickle-able. SerializedDAG works for all DAGs.
    """

    _decorated_fields = {"schedule_interval", "default_args", "_access_control"}

    @staticmethod
    def __get_constructor_defaults():
        param_to_attr = {
            "max_active_tasks": "_max_active_tasks",
            "dag_display_name": "_dag_display_property_value",
            "description": "_description",
            "default_view": "_default_view",
            "access_control": "_access_control",
        }
        return {
            param_to_attr.get(k, k): v.default
            for k, v in signature(DAG.__init__).parameters.items()
            if v.default is not v.empty
        }

    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
    del __get_constructor_defaults

    _json_schema = lazy_object_proxy.Proxy(load_dag_schema)

    @classmethod
    def serialize_dag(cls, dag: DAG) -> dict:
        """Serialize a DAG into a JSON object."""
        try:
            serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

            serialized_dag["_processor_dags_folder"] = DAGS_FOLDER

            # If schedule_interval is backed by timetable, serialize only
            # timetable; vice versa for a timetable backed by schedule_interval.
            if dag.timetable.summary == dag.schedule_interval:
                del serialized_dag["schedule_interval"]
            else:
                del serialized_dag["timetable"]

            serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()]

            dag_deps = [
                dep
                for task in dag.task_dict.values()
                for dep in SerializedBaseOperator.detect_dependencies(task)
            ]
            dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag))
            serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)]
            serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group)

            # Edge info in the JSON exactly matches our internal structure
            serialized_dag["edge_info"] = dag.edge_info
            serialized_dag["params"] = cls._serialize_params_dict(dag.params)

            # has_on_*_callback are only stored if the value is True, as the default is False
            if dag.has_on_success_callback:
                serialized_dag["has_on_success_callback"] = True
            if dag.has_on_failure_callback:
                serialized_dag["has_on_failure_callback"] = True
            return serialized_dag
        except SerializationError:
            raise
        except Exception as e:
            raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}")

    @classmethod
    def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
        """Deserializes a DAG from a JSON object."""
        dag = SerializedDAG(dag_id=encoded_dag["_dag_id"], schedule=None)

        for k, v in encoded_dag.items():
            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "tasks":
                SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
                tasks = {}
                for obj in v:
                    if obj.get(Encoding.TYPE) == DAT.OP:
                        deser = SerializedBaseOperator.deserialize_operator(obj[Encoding.VAR])
                        tasks[deser.task_id] = deser
                    else:  # TODO: remove in Airflow 3.0 (backcompat for pre-2.10)
                        tasks[obj["task_id"]] = SerializedBaseOperator.deserialize_operator(obj)
                k = "task_dict"
                v = tasks
            elif k == "timezone":
                v = cls._deserialize_timezone(v)
            elif k == "dagrun_timeout":
                v = cls._deserialize_timedelta(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "edge_info":
                # Value structure matches exactly
                pass
            elif k == "timetable":
                v = decode_timetable(v)
            elif k == "weight_rule":
                v = decode_priority_weight_strategy(v)
            elif k in cls._decorated_fields:
                v = cls.deserialize(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
            # else use v as-is
            setattr(dag, k, v)

        # A DAG is always serialized with only one of schedule_interval and
        # timetable. This back-populates the other to ensure the two attributes
        # line up correctly on the DAG instance.
        if "timetable" in encoded_dag:
            dag.schedule_interval = dag.timetable.summary
        else:
            dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)

        # Set _task_group
        if "_task_group" in encoded_dag:
            dag._task_group = TaskGroupSerialization.deserialize_task_group(
                encoded_dag["_task_group"],
                None,
                dag.task_dict,
                dag,
            )
        else:
            # This must be old data that had no task_group. Create a root TaskGroup and add
            # all tasks to it.
            dag._task_group = TaskGroup.create_root(dag)
            for task in dag.tasks:
                dag.task_group.add(task)

        # Set has_on_*_callbacks to True if they exist in the serialized blob, as False is the default
        if "has_on_success_callback" in encoded_dag:
            dag.has_on_success_callback = True
        if "has_on_failure_callback" in encoded_dag:
            dag.has_on_failure_callback = True

        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
        for k in keys_to_set_none:
            setattr(dag, k, None)

        for task in dag.task_dict.values():
            SerializedBaseOperator.set_task_dag_references(task, dag)

        return dag

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
        # {} is explicitly different from None in the case of DAG-level access control,
        # and as a result we need to preserve empty dicts through serialization for this field.
        if attrname == "_access_control" and var is not None:
            return False
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringify DAGs and operators contained by var and return a dict of var."""
        json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}

        # Validate the serialized DAG with the JSON Schema. Raises an error on mismatch.
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> SerializedDAG:
        """Deserialize a python dict into a DAG and the operators it contains."""
        ver = serialized_obj.get("__version", "<not present>")
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError(f"Unsure how to deserialize version {ver!r}")
        return cls.deserialize_dag(serialized_obj["dag"])
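

# Illustrative round trip (comment only), assuming ``dag`` is an already-parsed ``DAG`` object and
# ``json`` is the standard-library module:
#
#     blob = SerializedDAG.to_dict(dag)                 # schema-validated {"__version": ..., "dag": {...}}
#     data = json.dumps(blob)                           # the string form typically persisted
#     restored = SerializedDAG.from_dict(json.loads(data))
#
# ``from_dict`` refuses payloads whose "__version" does not match SERIALIZER_VERSION.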


class TaskGroupSerialization(BaseSerialization):
    """JSON serializable representation of a task group."""

    @classmethod
    def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
        """Serialize TaskGroup into a JSON object."""
        if not task_group:
            return None

        # task_group.xxx_ids need to be sorted here because they are sets, and the order produced
        # when converting a set to a list is not deterministic. An unstable order would change the
        # dag_hash generated by json.dumps(self.data, sort_keys=True).
        encoded = {
            "_group_id": task_group._group_id,
            "prefix_group_id": task_group.prefix_group_id,
            "tooltip": task_group.tooltip,
            "ui_color": task_group.ui_color,
            "ui_fgcolor": task_group.ui_fgcolor,
            "children": {
                label: child.serialize_for_task_group() for label, child in task_group.children.items()
            },
            "upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)),
            "downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)),
            "upstream_task_ids": cls.serialize(sorted(task_group.upstream_task_ids)),
            "downstream_task_ids": cls.serialize(sorted(task_group.downstream_task_ids)),
        }

        if isinstance(task_group, MappedTaskGroup):
            expand_input = task_group._expand_input
            encoded["expand_input"] = {
                "type": get_map_type_key(expand_input),
                "value": cls.serialize(expand_input.value),
            }
            encoded["is_mapped"] = True

        return encoded
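
    # Illustrative note (comment only, hypothetical group): the encoded form is a nested dict whose
    # ``children`` values are the ``(type, value)`` pairs produced by ``serialize_for_task_group`` --
    # roughly ``{label: (DAT.OP, task_id)}`` for tasks and a recursively encoded dict for nested
    # groups -- which ``deserialize_task_group`` below unpacks as ``(_type, val)``.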

    @classmethod
    def deserialize_task_group(
        cls,
        encoded_group: dict[str, Any],
        parent_group: TaskGroup | None,
        task_dict: dict[str, Operator],
        dag: SerializedDAG,
    ) -> TaskGroup:
        """Deserializes a TaskGroup from a JSON object."""
        group_id = cls.deserialize(encoded_group["_group_id"])
        kwargs = {
            key: cls.deserialize(encoded_group[key])
            for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
        }

        if not encoded_group.get("is_mapped"):
            group = TaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
        else:
            xi = encoded_group["expand_input"]
            group = MappedTaskGroup(
                group_id=group_id,
                parent_group=parent_group,
                dag=dag,
                expand_input=_ExpandInputRef(xi["type"], cls.deserialize(xi["value"])).deref(dag),
                **kwargs,
            )

        def set_ref(task: Operator) -> Operator:
            task.task_group = weakref.proxy(group)
            return task

        group.children = {
            label: (
                set_ref(task_dict[val])
                if _type == DAT.OP
                else cls.deserialize_task_group(val, group, task_dict, dag=dag)
            )
            for label, (_type, val) in encoded_group["children"].items()
        }
        group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"]))
        group.downstream_group_ids.update(cls.deserialize(encoded_group["downstream_group_ids"]))
        group.upstream_task_ids.update(cls.deserialize(encoded_group["upstream_task_ids"]))
        group.downstream_task_ids.update(cls.deserialize(encoded_group["downstream_task_ids"]))
        return group


def _has_kubernetes() -> bool:
    global HAS_KUBERNETES
    if "HAS_KUBERNETES" in globals():
        return HAS_KUBERNETES

    # Loading kube modules is expensive, so delay it until the last moment.
    try:
        from kubernetes.client import models as k8s

        try:
            from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
        except ImportError:
            from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import (  # type: ignore[assignment]
                PodGenerator,
            )

        globals()["k8s"] = k8s
        globals()["PodGenerator"] = PodGenerator

        # isort: on
        HAS_KUBERNETES = True
    except ImportError:
        HAS_KUBERNETES = False
    return HAS_KUBERNETES
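

# Illustrative note (comment only): callers can guard Kubernetes-specific handling with this
# helper so the expensive client import happens at most once per process, e.g.:
#
#     if _has_kubernetes():
#         ...  # k8s and PodGenerator are now available via this module's globals()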