xcom_arg.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. import contextlib
  19. import inspect
  20. import itertools
  21. from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Sequence, Union, overload
  22. from sqlalchemy import func, or_, select
  23. from airflow.api_internal.internal_api_call import internal_api_call
  24. from airflow.exceptions import AirflowException, XComNotFound
  25. from airflow.models import MappedOperator, TaskInstance
  26. from airflow.models.abstractoperator import AbstractOperator
  27. from airflow.models.taskmixin import DependencyMixin
  28. from airflow.utils.db import exists_query
  29. from airflow.utils.mixins import ResolveMixin
  30. from airflow.utils.session import NEW_SESSION, provide_session
  31. from airflow.utils.setup_teardown import SetupTeardownContext
  32. from airflow.utils.state import State
  33. from airflow.utils.trigger_rule import TriggerRule
  34. from airflow.utils.types import NOTSET, ArgNotSet
  35. from airflow.utils.xcom import XCOM_RETURN_KEY
  36. if TYPE_CHECKING:
  37. from sqlalchemy.orm import Session
  38. from airflow.models.baseoperator import BaseOperator
  39. from airflow.models.dag import DAG
  40. from airflow.models.operator import Operator
  41. from airflow.models.taskmixin import DAGNode
  42. from airflow.utils.context import Context
  43. from airflow.utils.edgemodifier import EdgeModifier
  44. # Callable objects contained by MapXComArg. We only accept callables from
  45. # the user, but deserialize them into strings in a serialized XComArg for
  46. # safety (those callables are arbitrary user code).
  47. MapCallables = Sequence[Union[Callable[[Any], Any], str]]
  48. class XComArg(ResolveMixin, DependencyMixin):
  49. """
  50. Reference to an XCom value pushed from another operator.
  51. The implementation supports::
  52. xcomarg >> op
  53. xcomarg << op
  54. op >> xcomarg # By BaseOperator code
  55. op << xcomarg # By BaseOperator code
  56. **Example**: The moment you get a result from any operator (decorated or regular) you can ::
  57. any_op = AnyOperator()
  58. xcomarg = XComArg(any_op)
  59. # or equivalently
  60. xcomarg = any_op.output
  61. my_op = MyOperator()
  62. my_op >> xcomarg
  63. This object can be used in legacy Operators via Jinja.
  64. **Example**: You can make this result to be part of any generated string::
  65. any_op = AnyOperator()
  66. xcomarg = any_op.output
  67. op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
  68. op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")
  69. :param operator: Operator instance to which the XComArg references.
  70. :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
  71. i.e. the referenced operator's return value.
  72. """
  73. @overload
  74. def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
  75. """Execute when the user writes ``XComArg(...)`` directly."""
  76. @overload
  77. def __new__(cls: type[XComArg]) -> XComArg:
  78. """Execute by Python internals from subclasses."""
  79. def __new__(cls, *args, **kwargs) -> XComArg:
  80. if cls is XComArg:
  81. return PlainXComArg(*args, **kwargs)
  82. return super().__new__(cls)
  83. @staticmethod
  84. def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
  85. """
  86. Return XCom references in an arbitrary value.
  87. Recursively traverse ``arg`` and look for XComArg instances in any
  88. collection objects, and instances with ``template_fields`` set.
  89. """
  90. if isinstance(arg, ResolveMixin):
  91. yield from arg.iter_references()
  92. elif isinstance(arg, (tuple, set, list)):
  93. for elem in arg:
  94. yield from XComArg.iter_xcom_references(elem)
  95. elif isinstance(arg, dict):
  96. for elem in arg.values():
  97. yield from XComArg.iter_xcom_references(elem)
  98. elif isinstance(arg, AbstractOperator):
  99. for attr in arg.template_fields:
  100. yield from XComArg.iter_xcom_references(getattr(arg, attr))
  101. @staticmethod
  102. def apply_upstream_relationship(op: Operator, arg: Any):
  103. """
  104. Set dependency for XComArgs.
  105. This looks for XComArg objects in ``arg`` "deeply" (looking inside
  106. collections objects and classes decorated with ``template_fields``), and
  107. sets the relationship to ``op`` on any found.
  108. """
  109. for operator, _ in XComArg.iter_xcom_references(arg):
  110. op.set_upstream(operator)
  111. @property
  112. def roots(self) -> list[DAGNode]:
  113. """Required by TaskMixin."""
  114. return [op for op, _ in self.iter_references()]
  115. @property
  116. def leaves(self) -> list[DAGNode]:
  117. """Required by TaskMixin."""
  118. return [op for op, _ in self.iter_references()]
  119. def set_upstream(
  120. self,
  121. task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
  122. edge_modifier: EdgeModifier | None = None,
  123. ):
  124. """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
  125. for operator, _ in self.iter_references():
  126. operator.set_upstream(task_or_task_list, edge_modifier)
  127. def set_downstream(
  128. self,
  129. task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
  130. edge_modifier: EdgeModifier | None = None,
  131. ):
  132. """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
  133. for operator, _ in self.iter_references():
  134. operator.set_downstream(task_or_task_list, edge_modifier)
  135. def _serialize(self) -> dict[str, Any]:
  136. """
  137. Serialize an XComArg.
  138. The implementation should be the inverse function to ``deserialize``,
  139. returning a data dict converted from this XComArg derivative. DAG
  140. serialization does not call this directly, but ``serialize_xcom_arg``
  141. instead, which adds additional information to dispatch deserialization
  142. to the correct class.
  143. """
  144. raise NotImplementedError()
  145. @classmethod
  146. def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
  147. """
  148. Deserialize an XComArg.
  149. The implementation should be the inverse function to ``serialize``,
  150. implementing given a data dict converted from this XComArg derivative,
  151. how the original XComArg should be created. DAG serialization relies on
  152. additional information added in ``serialize_xcom_arg`` to dispatch data
  153. dicts to the correct ``_deserialize`` information, so this function does
  154. not need to validate whether the incoming data contains correct keys.
  155. """
  156. raise NotImplementedError()
  157. def map(self, f: Callable[[Any], Any]) -> MapXComArg:
  158. return MapXComArg(self, [f])
  159. def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
  160. return ZipXComArg([self, *others], fillvalue=fillvalue)
  161. def concat(self, *others: XComArg) -> ConcatXComArg:
  162. return ConcatXComArg([self, *others])
  163. def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
  164. """
  165. Inspect length of pushed value for task-mapping.
  166. This is used to determine how many task instances the scheduler should
  167. create for a downstream using this XComArg for task-mapping.
  168. *None* may be returned if the depended XCom has not been pushed.
  169. """
  170. raise NotImplementedError()
  171. @provide_session
  172. def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
  173. """
  174. Pull XCom value.
  175. This should only be called during ``op.execute()`` with an appropriate
  176. context (e.g. generated from ``TaskInstance.get_template_context()``).
  177. Although the ``ResolveMixin`` parent mixin also has a ``resolve``
  178. protocol, this adds the optional ``session`` argument that some of the
  179. subclasses need.
  180. :meta private:
  181. """
  182. raise NotImplementedError()
  183. def __enter__(self):
  184. if not self.operator.is_setup and not self.operator.is_teardown:
  185. raise AirflowException("Only setup/teardown tasks can be used as context managers.")
  186. SetupTeardownContext.push_setup_teardown_task(self.operator)
  187. return SetupTeardownContext
  188. def __exit__(self, exc_type, exc_val, exc_tb):
  189. SetupTeardownContext.set_work_task_roots_and_leaves()
  190. @internal_api_call
  191. @provide_session
  192. def _get_task_map_length(
  193. *,
  194. dag_id: str,
  195. task_id: str,
  196. run_id: str,
  197. is_mapped: bool,
  198. session: Session = NEW_SESSION,
  199. ) -> int | None:
  200. from airflow.models.taskinstance import TaskInstance
  201. from airflow.models.taskmap import TaskMap
  202. from airflow.models.xcom import XCom
  203. if is_mapped:
  204. unfinished_ti_exists = exists_query(
  205. TaskInstance.dag_id == dag_id,
  206. TaskInstance.run_id == run_id,
  207. TaskInstance.task_id == task_id,
  208. # Special NULL treatment is needed because 'state' can be NULL.
  209. # The "IN" part would produce "NULL NOT IN ..." and eventually
  210. # "NULl = NULL", which is a big no-no in SQL.
  211. or_(
  212. TaskInstance.state.is_(None),
  213. TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
  214. ),
  215. session=session,
  216. )
  217. if unfinished_ti_exists:
  218. return None # Not all of the expanded tis are done yet.
  219. query = select(func.count(XCom.map_index)).where(
  220. XCom.dag_id == dag_id,
  221. XCom.run_id == run_id,
  222. XCom.task_id == task_id,
  223. XCom.map_index >= 0,
  224. XCom.key == XCOM_RETURN_KEY,
  225. )
  226. else:
  227. query = select(TaskMap.length).where(
  228. TaskMap.dag_id == dag_id,
  229. TaskMap.run_id == run_id,
  230. TaskMap.task_id == task_id,
  231. TaskMap.map_index < 0,
  232. )
  233. return session.scalar(query)
  234. class PlainXComArg(XComArg):
  235. """
  236. Reference to one single XCom without any additional semantics.
  237. This class should not be accessed directly, but only through XComArg. The
  238. class inheritance chain and ``__new__`` is implemented in this slightly
  239. convoluted way because we want to
  240. a. Allow the user to continue using XComArg directly for the simple
  241. semantics (see documentation of the base class for details).
  242. b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
  243. references.
  244. c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
  245. ``__str__``) to exist on other kinds of XComArg implementations since
  246. they don't make sense.
  247. :meta private:
  248. """
  249. def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
  250. self.operator = operator
  251. self.key = key
  252. def __eq__(self, other: Any) -> bool:
  253. if not isinstance(other, PlainXComArg):
  254. return NotImplemented
  255. return self.operator == other.operator and self.key == other.key
  256. def __getitem__(self, item: str) -> XComArg:
  257. """Implement xcomresult['some_result_key']."""
  258. if not isinstance(item, str):
  259. raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
  260. return PlainXComArg(operator=self.operator, key=item)
  261. def __iter__(self):
  262. """
  263. Override iterable protocol to raise error explicitly.
  264. The default ``__iter__`` implementation in Python calls ``__getitem__``
  265. with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work
  266. well with our custom ``__getitem__`` implementation, and results in poor
  267. DAG-writing experience since a misplaced ``*`` expansion would create an
  268. infinite loop consuming the entire DAG parser.
  269. This override catches the error eagerly, so an incorrectly implemented
  270. DAG fails fast and avoids wasting resources on nonsensical iterating.
  271. """
  272. raise TypeError("'XComArg' object is not iterable")
  273. def __repr__(self) -> str:
  274. if self.key == XCOM_RETURN_KEY:
  275. return f"XComArg({self.operator!r})"
  276. return f"XComArg({self.operator!r}, {self.key!r})"
  277. def __str__(self) -> str:
  278. """
  279. Backward compatibility for old-style jinja used in Airflow Operators.
  280. **Example**: to use XComArg at BashOperator::
  281. BashOperator(cmd=f"... { xcomarg } ...")
  282. :return:
  283. """
  284. xcom_pull_kwargs = [
  285. f"task_ids='{self.operator.task_id}'",
  286. f"dag_id='{self.operator.dag_id}'",
  287. ]
  288. if self.key is not None:
  289. xcom_pull_kwargs.append(f"key='{self.key}'")
  290. xcom_pull_str = ", ".join(xcom_pull_kwargs)
  291. # {{{{ are required for escape {{ in f-string
  292. xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
  293. return xcom_pull
  294. def _serialize(self) -> dict[str, Any]:
  295. return {"task_id": self.operator.task_id, "key": self.key}
  296. @classmethod
  297. def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
  298. return cls(dag.get_task(data["task_id"]), data["key"])
  299. @property
  300. def is_setup(self) -> bool:
  301. return self.operator.is_setup
  302. @is_setup.setter
  303. def is_setup(self, val: bool):
  304. self.operator.is_setup = val
  305. @property
  306. def is_teardown(self) -> bool:
  307. return self.operator.is_teardown
  308. @is_teardown.setter
  309. def is_teardown(self, val: bool):
  310. self.operator.is_teardown = val
  311. @property
  312. def on_failure_fail_dagrun(self) -> bool:
  313. return self.operator.on_failure_fail_dagrun
  314. @on_failure_fail_dagrun.setter
  315. def on_failure_fail_dagrun(self, val: bool):
  316. self.operator.on_failure_fail_dagrun = val
  317. def as_setup(self) -> DependencyMixin:
  318. for operator, _ in self.iter_references():
  319. operator.is_setup = True
  320. return self
  321. def as_teardown(
  322. self,
  323. *,
  324. setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
  325. on_failure_fail_dagrun=NOTSET,
  326. ):
  327. for operator, _ in self.iter_references():
  328. operator.is_teardown = True
  329. operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
  330. if on_failure_fail_dagrun is not NOTSET:
  331. operator.on_failure_fail_dagrun = on_failure_fail_dagrun
  332. if not isinstance(setups, ArgNotSet):
  333. setups = [setups] if isinstance(setups, DependencyMixin) else setups
  334. for s in setups:
  335. s.is_setup = True
  336. s >> operator
  337. return self
  338. def iter_references(self) -> Iterator[tuple[Operator, str]]:
  339. yield self.operator, self.key
  340. def map(self, f: Callable[[Any], Any]) -> MapXComArg:
  341. if self.key != XCOM_RETURN_KEY:
  342. raise ValueError("cannot map against non-return XCom")
  343. return super().map(f)
  344. def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
  345. if self.key != XCOM_RETURN_KEY:
  346. raise ValueError("cannot map against non-return XCom")
  347. return super().zip(*others, fillvalue=fillvalue)
  348. def concat(self, *others: XComArg) -> ConcatXComArg:
  349. if self.key != XCOM_RETURN_KEY:
  350. raise ValueError("cannot concatenate non-return XCom")
  351. return super().concat(*others)
  352. def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
  353. return _get_task_map_length(
  354. dag_id=self.operator.dag_id,
  355. task_id=self.operator.task_id,
  356. is_mapped=isinstance(self.operator, MappedOperator),
  357. run_id=run_id,
  358. session=session,
  359. )
  360. @provide_session
  361. def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
  362. ti = context["ti"]
  363. if TYPE_CHECKING:
  364. assert isinstance(ti, TaskInstance)
  365. task_id = self.operator.task_id
  366. map_indexes = ti.get_relevant_upstream_map_indexes(
  367. self.operator,
  368. context["expanded_ti_count"],
  369. session=session,
  370. )
  371. result = ti.xcom_pull(
  372. task_ids=task_id,
  373. map_indexes=map_indexes,
  374. key=self.key,
  375. default=NOTSET,
  376. session=session,
  377. )
  378. if not isinstance(result, ArgNotSet):
  379. return result
  380. if self.key == XCOM_RETURN_KEY:
  381. return None
  382. if getattr(self.operator, "multiple_outputs", False):
  383. # If the operator is set to have multiple outputs and it was not executed,
  384. # we should return "None" instead of showing an error. This is because when
  385. # multiple outputs XComs are created, the XCom keys associated with them will have
  386. # different names than the predefined "XCOM_RETURN_KEY" and won't be found.
  387. # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY.
  388. return None
  389. raise XComNotFound(ti.dag_id, task_id, self.key)
  390. def _get_callable_name(f: Callable | str) -> str:
  391. """Try to "describe" a callable by getting its name."""
  392. if callable(f):
  393. return f.__name__
  394. # Parse the source to find whatever is behind "def". For safety, we don't
  395. # want to evaluate the code in any meaningful way!
  396. with contextlib.suppress(Exception):
  397. kw, name, _ = f.lstrip().split(None, 2)
  398. if kw == "def":
  399. return name
  400. return "<function>"
  401. class _MapResult(Sequence):
  402. def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
  403. self.value = value
  404. self.callables = callables
  405. def __getitem__(self, index: Any) -> Any:
  406. value = self.value[index]
  407. # In the worker, we can access all actual callables. Call them.
  408. callables = [f for f in self.callables if callable(f)]
  409. if len(callables) == len(self.callables):
  410. for f in callables:
  411. value = f(value)
  412. return value
  413. # In the scheduler, we don't have access to the actual callables, nor do
  414. # we want to run it since it's arbitrary code. This builds a string to
  415. # represent the call chain in the UI or logs instead.
  416. for v in self.callables:
  417. value = f"{_get_callable_name(v)}({value})"
  418. return value
  419. def __len__(self) -> int:
  420. return len(self.value)
  421. class MapXComArg(XComArg):
  422. """
  423. An XCom reference with ``map()`` call(s) applied.
  424. This is based on an XComArg, but also applies a series of "transforms" that
  425. convert the pulled XCom value.
  426. :meta private:
  427. """
  428. def __init__(self, arg: XComArg, callables: MapCallables) -> None:
  429. for c in callables:
  430. if getattr(c, "_airflow_is_task_decorator", False):
  431. raise ValueError("map() argument must be a plain function, not a @task operator")
  432. self.arg = arg
  433. self.callables = callables
  434. def __repr__(self) -> str:
  435. map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
  436. return f"{self.arg!r}{map_calls}"
  437. def _serialize(self) -> dict[str, Any]:
  438. return {
  439. "arg": serialize_xcom_arg(self.arg),
  440. "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
  441. }
  442. @classmethod
  443. def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
  444. # We are deliberately NOT deserializing the callables. These are shown
  445. # in the UI, and displaying a function object is useless.
  446. return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
  447. def iter_references(self) -> Iterator[tuple[Operator, str]]:
  448. yield from self.arg.iter_references()
  449. def map(self, f: Callable[[Any], Any]) -> MapXComArg:
  450. # Flatten arg.map(f1).map(f2) into one MapXComArg.
  451. return MapXComArg(self.arg, [*self.callables, f])
  452. def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
  453. return self.arg.get_task_map_length(run_id, session=session)
  454. @provide_session
  455. def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
  456. value = self.arg.resolve(context, session=session, include_xcom=include_xcom)
  457. if not isinstance(value, (Sequence, dict)):
  458. raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
  459. return _MapResult(value, self.callables)
  460. class _ZipResult(Sequence):
  461. def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
  462. self.values = values
  463. self.fillvalue = fillvalue
  464. @staticmethod
  465. def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
  466. try:
  467. return container[index]
  468. except (IndexError, KeyError):
  469. return fillvalue
  470. def __getitem__(self, index: Any) -> Any:
  471. if index >= len(self):
  472. raise IndexError(index)
  473. return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)
  474. def __len__(self) -> int:
  475. lengths = (len(v) for v in self.values)
  476. if isinstance(self.fillvalue, ArgNotSet):
  477. return min(lengths)
  478. return max(lengths)
  479. class ZipXComArg(XComArg):
  480. """
  481. An XCom reference with ``zip()`` applied.
  482. This is constructed from multiple XComArg instances, and presents an
  483. iterable that "zips" them together like the built-in ``zip()`` (and
  484. ``itertools.zip_longest()`` if ``fillvalue`` is provided).
  485. """
  486. def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
  487. if not args:
  488. raise ValueError("At least one input is required")
  489. self.args = args
  490. self.fillvalue = fillvalue
  491. def __repr__(self) -> str:
  492. args_iter = iter(self.args)
  493. first = repr(next(args_iter))
  494. rest = ", ".join(repr(arg) for arg in args_iter)
  495. if isinstance(self.fillvalue, ArgNotSet):
  496. return f"{first}.zip({rest})"
  497. return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
  498. def _serialize(self) -> dict[str, Any]:
  499. args = [serialize_xcom_arg(arg) for arg in self.args]
  500. if isinstance(self.fillvalue, ArgNotSet):
  501. return {"args": args}
  502. return {"args": args, "fillvalue": self.fillvalue}
  503. @classmethod
  504. def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
  505. return cls(
  506. [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
  507. fillvalue=data.get("fillvalue", NOTSET),
  508. )
  509. def iter_references(self) -> Iterator[tuple[Operator, str]]:
  510. for arg in self.args:
  511. yield from arg.iter_references()
  512. def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
  513. all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
  514. ready_lengths = [length for length in all_lengths if length is not None]
  515. if len(ready_lengths) != len(self.args):
  516. return None # If any of the referenced XComs is not ready, we are not ready either.
  517. if isinstance(self.fillvalue, ArgNotSet):
  518. return min(ready_lengths)
  519. return max(ready_lengths)
  520. @provide_session
  521. def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
  522. values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
  523. for value in values:
  524. if not isinstance(value, (Sequence, dict)):
  525. raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
  526. return _ZipResult(values, fillvalue=self.fillvalue)
  527. class _ConcatResult(Sequence):
  528. def __init__(self, values: Sequence[Sequence | dict]) -> None:
  529. self.values = values
  530. def __getitem__(self, index: Any) -> Any:
  531. if index >= 0:
  532. i = index
  533. else:
  534. i = len(self) + index
  535. for value in self.values:
  536. if i < 0:
  537. break
  538. elif i >= (curlen := len(value)):
  539. i -= curlen
  540. elif isinstance(value, Sequence):
  541. return value[i]
  542. else:
  543. return next(itertools.islice(iter(value), i, None))
  544. raise IndexError("list index out of range")
  545. def __len__(self) -> int:
  546. return sum(len(v) for v in self.values)
  547. class ConcatXComArg(XComArg):
  548. """
  549. Concatenating multiple XCom references into one.
  550. This is done by calling ``concat()`` on an XComArg to combine it with
  551. others. The effect is similar to Python's :func:`itertools.chain`, but the
  552. return value also supports index access.
  553. """
  554. def __init__(self, args: Sequence[XComArg]) -> None:
  555. if not args:
  556. raise ValueError("At least one input is required")
  557. self.args = args
  558. def __repr__(self) -> str:
  559. args_iter = iter(self.args)
  560. first = repr(next(args_iter))
  561. rest = ", ".join(repr(arg) for arg in args_iter)
  562. return f"{first}.concat({rest})"
  563. def _serialize(self) -> dict[str, Any]:
  564. return {"args": [serialize_xcom_arg(arg) for arg in self.args]}
  565. @classmethod
  566. def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
  567. return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]])
  568. def iter_references(self) -> Iterator[tuple[Operator, str]]:
  569. for arg in self.args:
  570. yield from arg.iter_references()
  571. def concat(self, *others: XComArg) -> ConcatXComArg:
  572. # Flatten foo.concat(x).concat(y) into one call.
  573. return ConcatXComArg([*self.args, *others])
  574. def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
  575. all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
  576. ready_lengths = [length for length in all_lengths if length is not None]
  577. if len(ready_lengths) != len(self.args):
  578. return None # If any of the referenced XComs is not ready, we are not ready either.
  579. return sum(ready_lengths)
  580. @provide_session
  581. def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
  582. values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
  583. for value in values:
  584. if not isinstance(value, (Sequence, dict)):
  585. raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}")
  586. return _ConcatResult(values)
  587. _XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
  588. "": PlainXComArg,
  589. "concat": ConcatXComArg,
  590. "map": MapXComArg,
  591. "zip": ZipXComArg,
  592. }
  593. def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
  594. """DAG serialization interface."""
  595. key = next(k for k, v in _XCOM_ARG_TYPES.items() if isinstance(value, v))
  596. if key:
  597. return {"type": key, **value._serialize()}
  598. return value._serialize()
  599. def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
  600. """DAG serialization interface."""
  601. klass = _XCOM_ARG_TYPES[data.get("type", "")]
  602. return klass._deserialize(data, dag)