mappedoperator.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import collections.abc
  20. import contextlib
  21. import copy
  22. import warnings
  23. from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
  24. import attr
  25. import methodtools
  26. from airflow.exceptions import AirflowException, UnmappableOperator
  27. from airflow.models.abstractoperator import (
  28. DEFAULT_EXECUTOR,
  29. DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
  30. DEFAULT_OWNER,
  31. DEFAULT_POOL_SLOTS,
  32. DEFAULT_PRIORITY_WEIGHT,
  33. DEFAULT_QUEUE,
  34. DEFAULT_RETRIES,
  35. DEFAULT_RETRY_DELAY,
  36. DEFAULT_TRIGGER_RULE,
  37. DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
  38. DEFAULT_WEIGHT_RULE,
  39. AbstractOperator,
  40. NotMapped,
  41. )
  42. from airflow.models.expandinput import (
  43. DictOfListsExpandInput,
  44. ListOfDictsExpandInput,
  45. is_mappable,
  46. )
  47. from airflow.models.pool import Pool
  48. from airflow.serialization.enums import DagAttributeTypes
  49. from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
  50. from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
  51. from airflow.triggers.base import StartTriggerArgs
  52. from airflow.typing_compat import Literal
  53. from airflow.utils.context import context_update_for_unmapped
  54. from airflow.utils.helpers import is_container, prevent_duplicates
  55. from airflow.utils.task_instance_session import get_current_task_instance_session
  56. from airflow.utils.types import NOTSET
  57. from airflow.utils.xcom import XCOM_RETURN_KEY
  58. if TYPE_CHECKING:
  59. import datetime
  60. from typing import List
  61. import jinja2 # Slow import.
  62. import pendulum
  63. from sqlalchemy.orm.session import Session
  64. from airflow.models.abstractoperator import (
  65. TaskStateChangeCallback,
  66. )
  67. from airflow.models.baseoperator import BaseOperator
  68. from airflow.models.baseoperatorlink import BaseOperatorLink
  69. from airflow.models.dag import DAG
  70. from airflow.models.expandinput import (
  71. ExpandInput,
  72. OperatorExpandArgument,
  73. OperatorExpandKwargsArgument,
  74. )
  75. from airflow.models.operator import Operator
  76. from airflow.models.param import ParamsDict
  77. from airflow.models.xcom_arg import XComArg
  78. from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
  79. from airflow.utils.context import Context
  80. from airflow.utils.operator_resources import Resources
  81. from airflow.utils.task_group import TaskGroup
  82. from airflow.utils.trigger_rule import TriggerRule
  83. TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]]
  84. ValidationSource = Union[Literal["expand"], Literal["partial"]]
  85. def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
  86. # use a dict so order of args is same as code order
  87. unknown_args = value.copy()
  88. for klass in op.mro():
  89. init = klass.__init__ # type: ignore[misc]
  90. try:
  91. param_names = init._BaseOperatorMeta__param_names
  92. except AttributeError:
  93. continue
  94. for name in param_names:
  95. value = unknown_args.pop(name, NOTSET)
  96. if func != "expand":
  97. continue
  98. if value is NOTSET:
  99. continue
  100. if is_mappable(value):
  101. continue
  102. type_name = type(value).__name__
  103. error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
  104. raise ValueError(error)
  105. if not unknown_args:
  106. return # If we have no args left to check: stop looking at the MRO chain.
  107. if len(unknown_args) == 1:
  108. error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
  109. else:
  110. names = ", ".join(repr(n) for n in unknown_args)
  111. error = f"unexpected keyword arguments {names}"
  112. raise TypeError(f"{op.__name__}.{func}() got {error}")
  113. def ensure_xcomarg_return_value(arg: Any) -> None:
  114. from airflow.models.xcom_arg import XComArg
  115. if isinstance(arg, XComArg):
  116. for operator, key in arg.iter_references():
  117. if key != XCOM_RETURN_KEY:
  118. raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
  119. elif not is_container(arg):
  120. return
  121. elif isinstance(arg, collections.abc.Mapping):
  122. for v in arg.values():
  123. ensure_xcomarg_return_value(v)
  124. elif isinstance(arg, collections.abc.Iterable):
  125. for v in arg:
  126. ensure_xcomarg_return_value(v)
  127. @attr.define(kw_only=True, repr=False)
  128. class OperatorPartial:
  129. """
  130. An "intermediate state" returned by ``BaseOperator.partial()``.
  131. This only exists at DAG-parsing time; the only intended usage is for the
  132. user to call ``.expand()`` on it at some point (usually in a method chain) to
  133. create a ``MappedOperator`` to add into the DAG.
  134. """
  135. operator_class: type[BaseOperator]
  136. kwargs: dict[str, Any]
  137. params: ParamsDict | dict
  138. _expand_called: bool = False # Set when expand() is called to ease user debugging.
  139. def __attrs_post_init__(self):
  140. from airflow.operators.subdag import SubDagOperator
  141. if issubclass(self.operator_class, SubDagOperator):
  142. raise TypeError("Mapping over deprecated SubDagOperator is not supported")
  143. validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)
  144. def __repr__(self) -> str:
  145. args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
  146. return f"{self.operator_class.__name__}.partial({args})"
  147. def __del__(self):
  148. if not self._expand_called:
  149. try:
  150. task_id = repr(self.kwargs["task_id"])
  151. except KeyError:
  152. task_id = f"at {hex(id(self))}"
  153. warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)
  154. def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
  155. if not mapped_kwargs:
  156. raise TypeError("no arguments to expand against")
  157. validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
  158. prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
  159. # Since the input is already checked at parse time, we can set strict
  160. # to False to skip the checks on execution.
  161. return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
  162. def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
  163. from airflow.models.xcom_arg import XComArg
  164. if isinstance(kwargs, collections.abc.Sequence):
  165. for item in kwargs:
  166. if not isinstance(item, (XComArg, collections.abc.Mapping)):
  167. raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
  168. elif not isinstance(kwargs, XComArg):
  169. raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
  170. return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
  171. def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
  172. from airflow.operators.empty import EmptyOperator
  173. self._expand_called = True
  174. ensure_xcomarg_return_value(expand_input.value)
  175. partial_kwargs = self.kwargs.copy()
  176. task_id = partial_kwargs.pop("task_id")
  177. dag = partial_kwargs.pop("dag")
  178. task_group = partial_kwargs.pop("task_group")
  179. start_date = partial_kwargs.pop("start_date")
  180. end_date = partial_kwargs.pop("end_date")
  181. try:
  182. operator_name = self.operator_class.custom_operator_name # type: ignore
  183. except AttributeError:
  184. operator_name = self.operator_class.__name__
  185. op = MappedOperator(
  186. operator_class=self.operator_class,
  187. expand_input=expand_input,
  188. partial_kwargs=partial_kwargs,
  189. task_id=task_id,
  190. params=self.params,
  191. deps=MappedOperator.deps_for(self.operator_class),
  192. operator_extra_links=self.operator_class.operator_extra_links,
  193. template_ext=self.operator_class.template_ext,
  194. template_fields=self.operator_class.template_fields,
  195. template_fields_renderers=self.operator_class.template_fields_renderers,
  196. ui_color=self.operator_class.ui_color,
  197. ui_fgcolor=self.operator_class.ui_fgcolor,
  198. is_empty=issubclass(self.operator_class, EmptyOperator),
  199. task_module=self.operator_class.__module__,
  200. task_type=self.operator_class.__name__,
  201. operator_name=operator_name,
  202. dag=dag,
  203. task_group=task_group,
  204. start_date=start_date,
  205. end_date=end_date,
  206. disallow_kwargs_override=strict,
  207. # For classic operators, this points to expand_input because kwargs
  208. # to BaseOperator.expand() contribute to operator arguments.
  209. expand_input_attr="expand_input",
  210. start_trigger_args=self.operator_class.start_trigger_args,
  211. start_from_trigger=self.operator_class.start_from_trigger,
  212. )
  213. return op
  214. @attr.define(
  215. kw_only=True,
  216. # Disable custom __getstate__ and __setstate__ generation since it interacts
  217. # badly with Airflow's DAG serialization and pickling. When a mapped task is
  218. # deserialized, subclasses are coerced into MappedOperator, but when it goes
  219. # through DAG pickling, all attributes defined in the subclasses are dropped
  220. # by attrs's custom state management. Since attrs does not do anything too
  221. # special here (the logic is only important for slots=True), we use Python's
  222. # built-in implementation, which works (as proven by good old BaseOperator).
  223. getstate_setstate=False,
  224. )
  225. class MappedOperator(AbstractOperator):
  226. """Object representing a mapped operator in a DAG."""
  227. # This attribute serves double purpose. For a "normal" operator instance
  228. # loaded from DAG, this holds the underlying non-mapped operator class that
  229. # can be used to create an unmapped operator for execution. For an operator
  230. # recreated from a serialized DAG, however, this holds the serialized data
  231. # that can be used to unmap this into a SerializedBaseOperator.
  232. operator_class: type[BaseOperator] | dict[str, Any]
  233. expand_input: ExpandInput
  234. partial_kwargs: dict[str, Any]
  235. # Needed for serialization.
  236. task_id: str
  237. params: ParamsDict | dict
  238. deps: frozenset[BaseTIDep]
  239. operator_extra_links: Collection[BaseOperatorLink]
  240. template_ext: Sequence[str]
  241. template_fields: Collection[str]
  242. template_fields_renderers: dict[str, str]
  243. ui_color: str
  244. ui_fgcolor: str
  245. _is_empty: bool
  246. _task_module: str
  247. _task_type: str
  248. _operator_name: str
  249. start_trigger_args: StartTriggerArgs | None
  250. start_from_trigger: bool
  251. _needs_expansion: bool = True
  252. dag: DAG | None
  253. task_group: TaskGroup | None
  254. start_date: pendulum.DateTime | None
  255. end_date: pendulum.DateTime | None
  256. upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
  257. downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
  258. _disallow_kwargs_override: bool
  259. """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.
  260. If *False*, values from ``expand_input`` under duplicate keys override those
  261. under corresponding keys in ``partial_kwargs``.
  262. """
  263. _expand_input_attr: str
  264. """Where to get kwargs to calculate expansion length against.
  265. This should be a name to call ``getattr()`` on.
  266. """
  267. subdag: None = None # Since we don't support SubDagOperator, this is always None.
  268. supports_lineage: bool = False
  269. HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
  270. ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
  271. )
  272. def __hash__(self):
  273. return id(self)
  274. def __repr__(self):
  275. return f"<Mapped({self._task_type}): {self.task_id}>"
  276. def __attrs_post_init__(self):
  277. from airflow.models.xcom_arg import XComArg
  278. if self.get_closest_mapped_task_group() is not None:
  279. raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
  280. if self.task_group:
  281. self.task_group.add(self)
  282. if self.dag:
  283. self.dag.add_task(self)
  284. XComArg.apply_upstream_relationship(self, self.expand_input.value)
  285. for k, v in self.partial_kwargs.items():
  286. if k in self.template_fields:
  287. XComArg.apply_upstream_relationship(self, v)
  288. if self.partial_kwargs.get("sla") is not None:
  289. raise AirflowException(
  290. f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
  291. f"{self.task_id!r}."
  292. )
  293. @methodtools.lru_cache(maxsize=None)
  294. @classmethod
  295. def get_serialized_fields(cls):
  296. # Not using 'cls' here since we only want to serialize base fields.
  297. return frozenset(attr.fields_dict(MappedOperator)) - {
  298. "dag",
  299. "deps",
  300. "expand_input", # This is needed to be able to accept XComArg.
  301. "subdag",
  302. "task_group",
  303. "upstream_task_ids",
  304. "supports_lineage",
  305. "_is_setup",
  306. "_is_teardown",
  307. "_on_failure_fail_dagrun",
  308. }
  309. @methodtools.lru_cache(maxsize=None)
  310. @staticmethod
  311. def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
  312. operator_deps = operator_class.deps
  313. if not isinstance(operator_deps, collections.abc.Set):
  314. raise UnmappableOperator(
  315. f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
  316. f"not a {type(operator_deps).__name__}"
  317. )
  318. return operator_deps | {MappedTaskIsExpanded()}
  319. @property
  320. def task_type(self) -> str:
  321. """Implementing Operator."""
  322. return self._task_type
  323. @property
  324. def operator_name(self) -> str:
  325. return self._operator_name
  326. @property
  327. def inherits_from_empty_operator(self) -> bool:
  328. """Implementing Operator."""
  329. return self._is_empty
  330. @property
  331. def roots(self) -> Sequence[AbstractOperator]:
  332. """Implementing DAGNode."""
  333. return [self]
  334. @property
  335. def leaves(self) -> Sequence[AbstractOperator]:
  336. """Implementing DAGNode."""
  337. return [self]
  338. @property
  339. def task_display_name(self) -> str:
  340. return self.partial_kwargs.get("task_display_name") or self.task_id
  341. @property
  342. def owner(self) -> str: # type: ignore[override]
  343. return self.partial_kwargs.get("owner", DEFAULT_OWNER)
  344. @property
  345. def email(self) -> None | str | Iterable[str]:
  346. return self.partial_kwargs.get("email")
  347. @property
  348. def map_index_template(self) -> None | str:
  349. return self.partial_kwargs.get("map_index_template")
  350. @map_index_template.setter
  351. def map_index_template(self, value: str | None) -> None:
  352. self.partial_kwargs["map_index_template"] = value
  353. @property
  354. def trigger_rule(self) -> TriggerRule:
  355. return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
  356. @trigger_rule.setter
  357. def trigger_rule(self, value):
  358. self.partial_kwargs["trigger_rule"] = value
  359. @property
  360. def is_setup(self) -> bool:
  361. return bool(self.partial_kwargs.get("is_setup"))
  362. @is_setup.setter
  363. def is_setup(self, value: bool) -> None:
  364. self.partial_kwargs["is_setup"] = value
  365. @property
  366. def is_teardown(self) -> bool:
  367. return bool(self.partial_kwargs.get("is_teardown"))
  368. @is_teardown.setter
  369. def is_teardown(self, value: bool) -> None:
  370. self.partial_kwargs["is_teardown"] = value
  371. @property
  372. def depends_on_past(self) -> bool:
  373. return bool(self.partial_kwargs.get("depends_on_past"))
  374. @depends_on_past.setter
  375. def depends_on_past(self, value: bool) -> None:
  376. self.partial_kwargs["depends_on_past"] = value
  377. @property
  378. def ignore_first_depends_on_past(self) -> bool:
  379. value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
  380. return bool(value)
  381. @ignore_first_depends_on_past.setter
  382. def ignore_first_depends_on_past(self, value: bool) -> None:
  383. self.partial_kwargs["ignore_first_depends_on_past"] = value
  384. @property
  385. def wait_for_past_depends_before_skipping(self) -> bool:
  386. value = self.partial_kwargs.get(
  387. "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
  388. )
  389. return bool(value)
  390. @wait_for_past_depends_before_skipping.setter
  391. def wait_for_past_depends_before_skipping(self, value: bool) -> None:
  392. self.partial_kwargs["wait_for_past_depends_before_skipping"] = value
  393. @property
  394. def wait_for_downstream(self) -> bool:
  395. return bool(self.partial_kwargs.get("wait_for_downstream"))
  396. @wait_for_downstream.setter
  397. def wait_for_downstream(self, value: bool) -> None:
  398. self.partial_kwargs["wait_for_downstream"] = value
  399. @property
  400. def retries(self) -> int:
  401. return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
  402. @retries.setter
  403. def retries(self, value: int) -> None:
  404. self.partial_kwargs["retries"] = value
  405. @property
  406. def queue(self) -> str:
  407. return self.partial_kwargs.get("queue", DEFAULT_QUEUE)
  408. @queue.setter
  409. def queue(self, value: str) -> None:
  410. self.partial_kwargs["queue"] = value
  411. @property
  412. def pool(self) -> str:
  413. return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
  414. @pool.setter
  415. def pool(self, value: str) -> None:
  416. self.partial_kwargs["pool"] = value
  417. @property
  418. def pool_slots(self) -> int:
  419. return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
  420. @pool_slots.setter
  421. def pool_slots(self, value: int) -> None:
  422. self.partial_kwargs["pool_slots"] = value
  423. @property
  424. def execution_timeout(self) -> datetime.timedelta | None:
  425. return self.partial_kwargs.get("execution_timeout")
  426. @execution_timeout.setter
  427. def execution_timeout(self, value: datetime.timedelta | None) -> None:
  428. self.partial_kwargs["execution_timeout"] = value
  429. @property
  430. def max_retry_delay(self) -> datetime.timedelta | None:
  431. return self.partial_kwargs.get("max_retry_delay")
  432. @max_retry_delay.setter
  433. def max_retry_delay(self, value: datetime.timedelta | None) -> None:
  434. self.partial_kwargs["max_retry_delay"] = value
  435. @property
  436. def retry_delay(self) -> datetime.timedelta:
  437. return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
  438. @retry_delay.setter
  439. def retry_delay(self, value: datetime.timedelta) -> None:
  440. self.partial_kwargs["retry_delay"] = value
  441. @property
  442. def retry_exponential_backoff(self) -> bool:
  443. return bool(self.partial_kwargs.get("retry_exponential_backoff"))
  444. @retry_exponential_backoff.setter
  445. def retry_exponential_backoff(self, value: bool) -> None:
  446. self.partial_kwargs["retry_exponential_backoff"] = value
  447. @property
  448. def priority_weight(self) -> int: # type: ignore[override]
  449. return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)
  450. @priority_weight.setter
  451. def priority_weight(self, value: int) -> None:
  452. self.partial_kwargs["priority_weight"] = value
  453. @property
  454. def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override]
  455. return validate_and_load_priority_weight_strategy(
  456. self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
  457. )
  458. @weight_rule.setter
  459. def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
  460. self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)
  461. @property
  462. def sla(self) -> datetime.timedelta | None:
  463. return self.partial_kwargs.get("sla")
  464. @sla.setter
  465. def sla(self, value: datetime.timedelta | None) -> None:
  466. self.partial_kwargs["sla"] = value
  467. @property
  468. def max_active_tis_per_dag(self) -> int | None:
  469. return self.partial_kwargs.get("max_active_tis_per_dag")
  470. @max_active_tis_per_dag.setter
  471. def max_active_tis_per_dag(self, value: int | None) -> None:
  472. self.partial_kwargs["max_active_tis_per_dag"] = value
  473. @property
  474. def max_active_tis_per_dagrun(self) -> int | None:
  475. return self.partial_kwargs.get("max_active_tis_per_dagrun")
  476. @max_active_tis_per_dagrun.setter
  477. def max_active_tis_per_dagrun(self, value: int | None) -> None:
  478. self.partial_kwargs["max_active_tis_per_dagrun"] = value
  479. @property
  480. def resources(self) -> Resources | None:
  481. return self.partial_kwargs.get("resources")
  482. @property
  483. def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
  484. return self.partial_kwargs.get("on_execute_callback")
  485. @on_execute_callback.setter
  486. def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
  487. self.partial_kwargs["on_execute_callback"] = value
  488. @property
  489. def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
  490. return self.partial_kwargs.get("on_failure_callback")
  491. @on_failure_callback.setter
  492. def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
  493. self.partial_kwargs["on_failure_callback"] = value
  494. @property
  495. def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
  496. return self.partial_kwargs.get("on_retry_callback")
  497. @on_retry_callback.setter
  498. def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
  499. self.partial_kwargs["on_retry_callback"] = value
  500. @property
  501. def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
  502. return self.partial_kwargs.get("on_success_callback")
  503. @on_success_callback.setter
  504. def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
  505. self.partial_kwargs["on_success_callback"] = value
  506. @property
  507. def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
  508. return self.partial_kwargs.get("on_skipped_callback")
  509. @on_skipped_callback.setter
  510. def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
  511. self.partial_kwargs["on_skipped_callback"] = value
  512. @property
  513. def run_as_user(self) -> str | None:
  514. return self.partial_kwargs.get("run_as_user")
  515. @property
  516. def executor(self) -> str | None:
  517. return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)
  518. @property
  519. def executor_config(self) -> dict:
  520. return self.partial_kwargs.get("executor_config", {})
  521. @property # type: ignore[override]
  522. def inlets(self) -> list[Any]: # type: ignore[override]
  523. return self.partial_kwargs.get("inlets", [])
  524. @inlets.setter
  525. def inlets(self, value: list[Any]) -> None: # type: ignore[override]
  526. self.partial_kwargs["inlets"] = value
  527. @property # type: ignore[override]
  528. def outlets(self) -> list[Any]: # type: ignore[override]
  529. return self.partial_kwargs.get("outlets", [])
  530. @outlets.setter
  531. def outlets(self, value: list[Any]) -> None: # type: ignore[override]
  532. self.partial_kwargs["outlets"] = value
  533. @property
  534. def doc(self) -> str | None:
  535. return self.partial_kwargs.get("doc")
  536. @property
  537. def doc_md(self) -> str | None:
  538. return self.partial_kwargs.get("doc_md")
  539. @property
  540. def doc_json(self) -> str | None:
  541. return self.partial_kwargs.get("doc_json")
  542. @property
  543. def doc_yaml(self) -> str | None:
  544. return self.partial_kwargs.get("doc_yaml")
  545. @property
  546. def doc_rst(self) -> str | None:
  547. return self.partial_kwargs.get("doc_rst")
  548. @property
  549. def allow_nested_operators(self) -> bool:
  550. return bool(self.partial_kwargs.get("allow_nested_operators"))
  551. def get_dag(self) -> DAG | None:
  552. """Implement Operator."""
  553. return self.dag
  554. @property
  555. def output(self) -> XComArg:
  556. """Return reference to XCom pushed by current operator."""
  557. from airflow.models.xcom_arg import XComArg
  558. return XComArg(operator=self)
  559. def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
  560. """Implement DAGNode."""
  561. return DagAttributeTypes.OP, self.task_id
  562. def _expand_mapped_kwargs(
  563. self, context: Context, session: Session, *, include_xcom: bool
  564. ) -> tuple[Mapping[str, Any], set[int]]:
  565. """
  566. Get the kwargs to create the unmapped operator.
  567. This exists because taskflow operators expand against op_kwargs, not the
  568. entire operator kwargs dict.
  569. """
  570. return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom)
  571. def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
  572. """
  573. Get init kwargs to unmap the underlying operator class.
  574. :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
  575. """
  576. if strict:
  577. prevent_duplicates(
  578. self.partial_kwargs,
  579. mapped_kwargs,
  580. fail_reason="unmappable or already specified",
  581. )
  582. # If params appears in the mapped kwargs, we need to merge it into the
  583. # partial params, overriding existing keys.
  584. params = copy.copy(self.params)
  585. with contextlib.suppress(KeyError):
  586. params.update(mapped_kwargs["params"])
  587. # Ordering is significant; mapped kwargs should override partial ones,
  588. # and the specially handled params should be respected.
  589. return {
  590. "task_id": self.task_id,
  591. "dag": self.dag,
  592. "task_group": self.task_group,
  593. "start_date": self.start_date,
  594. "end_date": self.end_date,
  595. **self.partial_kwargs,
  596. **mapped_kwargs,
  597. "params": params,
  598. }
  599. def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
  600. """
  601. Get the start_from_trigger value of the current abstract operator.
  602. MappedOperator uses this to unmap start_from_trigger to decide whether to start the task
  603. execution directly from triggerer.
  604. :meta private:
  605. """
  606. # start_from_trigger only makes sense when start_trigger_args exists.
  607. if not self.start_trigger_args:
  608. return False
  609. mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
  610. if self._disallow_kwargs_override:
  611. prevent_duplicates(
  612. self.partial_kwargs,
  613. mapped_kwargs,
  614. fail_reason="unmappable or already specified",
  615. )
  616. # Ordering is significant; mapped kwargs should override partial ones.
  617. return mapped_kwargs.get(
  618. "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger)
  619. )
  def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
      """
      Build the unmapped ``StartTriggerArgs`` for this mapped operator.

      This exists so a mapped operator can start execution from the triggerer.
      ``trigger_kwargs``, ``next_kwargs`` and the timeout (looked up under the
      ``trigger_timeout`` key) are each taken from the mapped kwargs first,
      then the partial kwargs, and finally from the operator's own
      ``start_trigger_args``. ``trigger_cls`` and ``next_method`` are copied
      from ``start_trigger_args`` unchanged.

      :param context: Context passed through to the mapped-kwargs expansion.
      :param session: Session passed through to the mapped-kwargs expansion.
      :return: None when no ``start_trigger_args`` is set, otherwise a new
          ``StartTriggerArgs`` with the resolved values.
      """
      base_args = self.start_trigger_args
      if not base_args:
          return None

      expanded, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
      if self._disallow_kwargs_override:
          prevent_duplicates(
              self.partial_kwargs,
              expanded,
              fail_reason="unmappable or already specified",
          )

      def _resolve(key, default):
          # Precedence matters: mapped kwargs override partial ones.
          return expanded.get(key, self.partial_kwargs.get(key, default))

      resolved_trigger_kwargs = _resolve("trigger_kwargs", base_args.trigger_kwargs)
      resolved_next_kwargs = _resolve("next_kwargs", base_args.next_kwargs)
      resolved_timeout = _resolve("trigger_timeout", base_args.timeout)

      return StartTriggerArgs(
          trigger_cls=base_args.trigger_cls,
          trigger_kwargs=resolved_trigger_kwargs,
          next_method=base_args.next_method,
          next_kwargs=resolved_next_kwargs,
          timeout=resolved_timeout,
      )
  def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
      """
      Produce the "normal" operator that results from applying the mapping.

      *resolve* only matters when ``operator_class`` is a real class, i.e.
      this operator has not been serialized. When ``operator_class`` is not a
      class (the DAG has been deserialized), a SerializedBaseOperator that
      "looks like" the real unmapping result is returned instead.

      When *resolve* is a ``(context, session)`` two-tuple, it is used to
      resolve the mapped arguments into init arguments. When it is a mapping,
      nothing is resolved; the mapping itself supplies the init arguments
      already resolved from the mapped kwargs.

      :raises RuntimeError: If the operator is not serialized and *resolve*
          is None.

      :meta private:
      """
      if not isinstance(self.operator_class, type):
          # Once a mapped operator has been serialized the underlying operator
          # class is gone, so it cannot truly be unmapped. The best we can do
          # is "forward" every attribute of this mapped operator onto a fresh
          # SerializedBaseOperator instance.
          from airflow.serialization.serialized_objects import SerializedBaseOperator

          serialized_op = SerializedBaseOperator(
              task_id=self.task_id, params=self.params, _airflow_from_mapped=True
          )
          SerializedBaseOperator.populate_operator(serialized_op, self.operator_class)
          # For Mypy; we only serialize tasks in a DAG so the check always satisfies.
          if self.dag is not None:
              SerializedBaseOperator.set_task_dag_references(serialized_op, self.dag)
          return serialized_op

      if resolve is None:
          raise RuntimeError("cannot unmap a non-serialized operator without context")
      if isinstance(resolve, collections.abc.Mapping):
          init_kwargs = resolve
      else:
          init_kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True)

      init_kwargs = self._get_unmap_kwargs(init_kwargs, strict=self._disallow_kwargs_override)
      # These flags are stripped from the init kwargs and set as attributes
      # on the created operator instead.
      flags = {
          name: init_kwargs.pop(name, False)
          for name in ("is_setup", "is_teardown", "on_failure_fail_dagrun")
      }
      unmapped = self.operator_class(**init_kwargs, _airflow_from_mapped=True)
      # task_id has to be overwritten because BaseOperator mangles it further
      # according to the task hierarchy (group_id gets prepended, and '__N'
      # appended for deduplication). Hacky, but better than duplicating the
      # whole mangling logic here.
      unmapped.task_id = self.task_id
      for name, value in flags.items():
          setattr(unmapped, name, value)
      unmapped.downstream_task_ids = self.downstream_task_ids
      unmapped.upstream_task_ids = self.upstream_task_ids
      return unmapped
  699. def _get_specified_expand_input(self) -> ExpandInput:
  700. """Input received from the expand call on the operator."""
  701. return getattr(self, self._expand_input_attr)
  def prepare_for_execution(self) -> MappedOperator:
      """
      Return this mapped operator itself, without copying.

      A mapped operator is never executed directly; an unmapped BaseOperator
      is created later (see render_template_fields), so there is no need to
      make a copy of the MappedOperator here.
      """
      return self
  def iter_mapped_dependencies(self) -> Iterator[Operator]:
      """Yield the upstream operators providing XComs that this task uses for task mapping."""
      from airflow.models.xcom_arg import XComArg

      # Each reference is a pair; only its first element (the operator) is
      # of interest here.
      references = XComArg.iter_xcom_references(self._get_specified_expand_input())
      yield from (upstream for upstream, _ in references)
  @methodtools.lru_cache(maxsize=None)
  def get_parse_time_mapped_ti_count(self) -> int:
      """
      Return the parse-time mapped task instance count for this operator.

      The count reported by this operator's own expand input is multiplied by
      the count reported by the parent implementation. When the parent
      implementation raises ``NotMapped``, this operator's own count is
      returned as-is.
      """
      own_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
      try:
          inherited_count = super().get_parse_time_mapped_ti_count()
      except NotMapped:
          return own_count
      else:
          return inherited_count * own_count
  def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
      """
      Return the mapped task instance count for this operator in a given run.

      The total map length of this operator's own expand input is multiplied
      by the count reported by the parent implementation. When the parent
      implementation raises ``NotMapped``, this operator's own count is
      returned as-is.

      :param run_id: Run ID forwarded to the map-length lookups.
      :param session: Session forwarded to the map-length lookups.
      """
      from airflow.serialization.serialized_objects import _ExpandInputRef

      expand_input = self._get_specified_expand_input()
      # A reference has to be dereferenced against the DAG before it can
      # report a map length.
      if isinstance(expand_input, _ExpandInputRef):
          expand_input = expand_input.deref(self.dag)
      own_count = expand_input.get_total_map_length(run_id, session=session)

      try:
          inherited_count = super().get_mapped_ti_count(run_id, session=session)
      except NotMapped:
          return own_count
      else:
          return inherited_count * own_count
  def render_template_fields(
      self,
      context: Context,
      jinja_env: jinja2.Environment | None = None,
  ) -> None:
      """
      Template all attributes listed in *self.template_fields*.

      The mapped operator itself is left unmodified. Instead, *context* is
      updated to reference the map-expanded task and its related information,
      and that expanded task is then rendered in-place.

      :param context: Context dict with values to apply on content.
      :param jinja_env: Jinja environment to use for rendering.
      """
      env = jinja_env or self.get_template_env()

      # The session is fetched here rather than injected: _run_raw_task stores
      # it through the set_current_task_session context manager. It cannot be
      # passed via @provide_session because the signature of
      # render_template_fields is defined by BaseOperator and already
      # overridden by many subclasses, so changing it is not an option.
      # render_template_fields always runs within "_run_raw_task", which makes
      # sure the session is stored in the current task.
      task_session = get_current_task_instance_session()

      expanded_kwargs, seen_oids = self._expand_mapped_kwargs(context, task_session, include_xcom=True)
      concrete_task = self.unmap(expanded_kwargs)
      context_update_for_unmapped(context, concrete_task)

      # Operators extending `BaseOperator` are not subclasses of
      # `MappedOperator`, so `_do_render_template_fields` has to be invoked on
      # the unmapped task; that way an operator's own override (customizing
      # the parsing of nested fields) is the one that gets called.
      concrete_task._do_render_template_fields(
          parent=concrete_task,
          template_fields=self.template_fields,
          context=context,
          jinja_env=env,
          seen_oids=seen_oids,
      )