expandinput.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import collections.abc
  20. import functools
  21. import operator
  22. from collections.abc import Sized
  23. from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Union
  24. import attr
  25. from airflow.utils.mixins import ResolveMixin
  26. from airflow.utils.session import NEW_SESSION, provide_session
  27. if TYPE_CHECKING:
  28. from sqlalchemy.orm import Session
  29. from airflow.models.operator import Operator
  30. from airflow.models.xcom_arg import XComArg
  31. from airflow.serialization.serialized_objects import _ExpandInputRef
  32. from airflow.typing_compat import TypeGuard
  33. from airflow.utils.context import Context
  34. ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
  35. # Each keyword argument to expand() can be an XComArg, sequence, or dict (not
  36. # any mapping since we need the value to be ordered).
  37. OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, Dict[str, Any]]
  38. # The single argument of expand_kwargs() can be an XComArg, or a list with each
  39. # element being either an XComArg or a dict.
  40. OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
  41. @attr.define(kw_only=True)
  42. class MappedArgument(ResolveMixin):
  43. """
  44. Stand-in stub for task-group-mapping arguments.
  45. This is very similar to an XComArg, but resolved differently. Declared here
  46. (instead of in the task group module) to avoid import cycles.
  47. """
  48. _input: ExpandInput
  49. _key: str
  50. def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
  51. # TODO (AIP-42): Implement run-time task map length inspection. This is
  52. # needed when we implement task mapping inside a mapped task group.
  53. raise NotImplementedError()
  54. def iter_references(self) -> Iterable[tuple[Operator, str]]:
  55. yield from self._input.iter_references()
  56. @provide_session
  57. def resolve(self, context: Context, *, include_xcom: bool = True, session: Session = NEW_SESSION) -> Any:
  58. data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
  59. return data[self._key]
  60. # To replace tedious isinstance() checks.
  61. def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
  62. from airflow.models.xcom_arg import XComArg
  63. return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str)
  64. # To replace tedious isinstance() checks.
  65. def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
  66. from airflow.models.xcom_arg import XComArg
  67. return not isinstance(v, (MappedArgument, XComArg))
  68. # To replace tedious isinstance() checks.
  69. def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
  70. from airflow.models.xcom_arg import XComArg
  71. return isinstance(v, (MappedArgument, XComArg))
  72. class NotFullyPopulated(RuntimeError):
  73. """
  74. Raise when ``get_map_lengths`` cannot populate all mapping metadata.
  75. This is generally due to not all upstream tasks have finished when the
  76. function is called.
  77. """
  78. def __init__(self, missing: set[str]) -> None:
  79. self.missing = missing
  80. def __str__(self) -> str:
  81. keys = ", ".join(repr(k) for k in sorted(self.missing))
  82. return f"Failed to populate all mapping metadata; missing: {keys}"
  83. class DictOfListsExpandInput(NamedTuple):
  84. """
  85. Storage type of a mapped operator's mapped kwargs.
  86. This is created from ``expand(**kwargs)``.
  87. """
  88. value: dict[str, OperatorExpandArgument]
  89. def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
  90. """Generate kwargs with values available on parse-time."""
  91. return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v))
  92. def get_parse_time_mapped_ti_count(self) -> int:
  93. if not self.value:
  94. return 0
  95. literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
  96. if len(literal_values) != len(self.value):
  97. literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs())
  98. raise NotFullyPopulated(set(self.value).difference(literal_keys))
  99. return functools.reduce(operator.mul, literal_values, 1)
  100. def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
  101. """
  102. Return dict of argument name to map length.
  103. If any arguments are not known right now (upstream task not finished),
  104. they will not be present in the dict.
  105. """
  106. # TODO: This initiates one database call for each XComArg. Would it be
  107. # more efficient to do one single db call and unpack the value here?
  108. def _get_length(v: OperatorExpandArgument) -> int | None:
  109. if _needs_run_time_resolution(v):
  110. return v.get_task_map_length(run_id, session=session)
  111. # Unfortunately a user-defined TypeGuard cannot apply negative type
  112. # narrowing. https://github.com/python/typing/discussions/1013
  113. if TYPE_CHECKING:
  114. assert isinstance(v, Sized)
  115. return len(v)
  116. map_lengths_iterator = ((k, _get_length(v)) for k, v in self.value.items())
  117. map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
  118. if len(map_lengths) < len(self.value):
  119. raise NotFullyPopulated(set(self.value).difference(map_lengths))
  120. return map_lengths
  121. def get_total_map_length(self, run_id: str, *, session: Session) -> int:
  122. if not self.value:
  123. return 0
  124. lengths = self._get_map_lengths(run_id, session=session)
  125. return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
  126. def _expand_mapped_field(
  127. self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
  128. ) -> Any:
  129. if _needs_run_time_resolution(value):
  130. value = (
  131. value.resolve(context, session=session, include_xcom=include_xcom)
  132. if include_xcom
  133. else str(value)
  134. )
  135. map_index = context["ti"].map_index
  136. if map_index < 0:
  137. raise RuntimeError("can't resolve task-mapping argument without expanding")
  138. all_lengths = self._get_map_lengths(context["run_id"], session=session)
  139. def _find_index_for_this_field(index: int) -> int:
  140. # Need to use the original user input to retain argument order.
  141. for mapped_key in reversed(self.value):
  142. mapped_length = all_lengths[mapped_key]
  143. if mapped_length < 1:
  144. raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
  145. if mapped_key == key:
  146. return index % mapped_length
  147. index //= mapped_length
  148. return -1
  149. found_index = _find_index_for_this_field(map_index)
  150. if found_index < 0:
  151. return value
  152. if isinstance(value, collections.abc.Sequence):
  153. return value[found_index]
  154. if not isinstance(value, dict):
  155. raise TypeError(f"can't map over value of type {type(value)}")
  156. for i, (k, v) in enumerate(value.items()):
  157. if i == found_index:
  158. return k, v
  159. raise IndexError(f"index {map_index} is over mapped length")
  160. def iter_references(self) -> Iterable[tuple[Operator, str]]:
  161. from airflow.models.xcom_arg import XComArg
  162. for x in self.value.values():
  163. if isinstance(x, XComArg):
  164. yield from x.iter_references()
  165. def resolve(
  166. self, context: Context, session: Session, *, include_xcom: bool = True
  167. ) -> tuple[Mapping[str, Any], set[int]]:
  168. data = {
  169. k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
  170. for k, v in self.value.items()
  171. }
  172. literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
  173. resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
  174. return data, resolved_oids
  175. def _describe_type(value: Any) -> str:
  176. if value is None:
  177. return "None"
  178. return type(value).__name__
  179. class ListOfDictsExpandInput(NamedTuple):
  180. """
  181. Storage type of a mapped operator's mapped kwargs.
  182. This is created from ``expand_kwargs(xcom_arg)``.
  183. """
  184. value: OperatorExpandKwargsArgument
  185. def get_parse_time_mapped_ti_count(self) -> int:
  186. if isinstance(self.value, collections.abc.Sized):
  187. return len(self.value)
  188. raise NotFullyPopulated({"expand_kwargs() argument"})
  189. def get_total_map_length(self, run_id: str, *, session: Session) -> int:
  190. if isinstance(self.value, collections.abc.Sized):
  191. return len(self.value)
  192. length = self.value.get_task_map_length(run_id, session=session)
  193. if length is None:
  194. raise NotFullyPopulated({"expand_kwargs() argument"})
  195. return length
  196. def iter_references(self) -> Iterable[tuple[Operator, str]]:
  197. from airflow.models.xcom_arg import XComArg
  198. if isinstance(self.value, XComArg):
  199. yield from self.value.iter_references()
  200. else:
  201. for x in self.value:
  202. if isinstance(x, XComArg):
  203. yield from x.iter_references()
  204. def resolve(
  205. self, context: Context, session: Session, *, include_xcom: bool = True
  206. ) -> tuple[Mapping[str, Any], set[int]]:
  207. map_index = context["ti"].map_index
  208. if map_index < 0:
  209. raise RuntimeError("can't resolve task-mapping argument without expanding")
  210. mapping: Any
  211. if isinstance(self.value, collections.abc.Sized):
  212. mapping = self.value[map_index]
  213. if not isinstance(mapping, collections.abc.Mapping):
  214. mapping = mapping.resolve(context, session, include_xcom=include_xcom)
  215. elif include_xcom:
  216. mappings = self.value.resolve(context, session, include_xcom=include_xcom)
  217. if not isinstance(mappings, collections.abc.Sequence):
  218. raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
  219. mapping = mappings[map_index]
  220. if not isinstance(mapping, collections.abc.Mapping):
  221. raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")
  222. for key in mapping:
  223. if not isinstance(key, str):
  224. raise ValueError(
  225. f"expand_kwargs() input dict keys must all be str, "
  226. f"but {key!r} is of type {_describe_type(key)}"
  227. )
  228. # filter out parse time resolved values from the resolved_oids
  229. resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)}
  230. return mapping, resolved_oids
  231. EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
  232. _EXPAND_INPUT_TYPES = {
  233. "dict-of-lists": DictOfListsExpandInput,
  234. "list-of-dicts": ListOfDictsExpandInput,
  235. }
  236. def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str:
  237. from airflow.serialization.serialized_objects import _ExpandInputRef
  238. if isinstance(expand_input, _ExpandInputRef):
  239. return expand_input.key
  240. return next(k for k, v in _EXPAND_INPUT_TYPES.items() if isinstance(expand_input, v))
  241. def create_expand_input(kind: str, value: Any) -> ExpandInput:
  242. return _EXPAND_INPUT_TYPES[kind](value)