operator_helpers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import logging
  20. from datetime import datetime
  21. from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, TypeVar
  22. from airflow import settings
  23. from airflow.utils.context import Context, lazy_mapping_from_context
  24. if TYPE_CHECKING:
  25. from airflow.utils.context import OutletEventAccessors
# Type variable used by ``make_kwargs_callable`` below to preserve the wrapped
# callable's return type.
R = TypeVar("R")

# Prefix for context values exposed in dotted "airflow.ctx.*" form.
DEFAULT_FORMAT_PREFIX = "airflow.ctx."
# Prefix for context values exposed as "AIRFLOW_CTX_*" environment variables.
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

# Maps each internal context-variable key to its two external spellings:
# the dotted "default" form and the upper-case "env_var_format" form.
# Consumed by ``context_to_airflow_vars`` to pick the output key names.
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_EXECUTION_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}execution_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}EXECUTION_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}
  59. def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
  60. """
  61. Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.
  62. Given a context, this function provides a dictionary of values that can be used to
  63. externally reconstruct relations between dags, dag_runs, tasks and task_instances.
  64. Default to abc.def.ghi format and can be made to ABC_DEF_GHI format if
  65. in_env_var_format is set to True.
  66. :param context: The context for the task_instance of interest.
  67. :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
  68. :return: task_instance context as dict.
  69. """
  70. params = {}
  71. if in_env_var_format:
  72. name_format = "env_var_format"
  73. else:
  74. name_format = "default"
  75. task = context.get("task")
  76. task_instance = context.get("task_instance")
  77. dag_run = context.get("dag_run")
  78. ops = [
  79. (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
  80. (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
  81. (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
  82. (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
  83. (task_instance, "execution_date", "AIRFLOW_CONTEXT_EXECUTION_DATE"),
  84. (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
  85. (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
  86. ]
  87. context_params = settings.get_airflow_context_vars(context)
  88. for key, value in context_params.items():
  89. if not isinstance(key, str):
  90. raise TypeError(f"key <{key}> must be string")
  91. if not isinstance(value, str):
  92. raise TypeError(f"value of key <{key}> must be string, not {type(value)}")
  93. if in_env_var_format:
  94. if not key.startswith(ENV_VAR_FORMAT_PREFIX):
  95. key = ENV_VAR_FORMAT_PREFIX + key.upper()
  96. else:
  97. if not key.startswith(DEFAULT_FORMAT_PREFIX):
  98. key = DEFAULT_FORMAT_PREFIX + key
  99. params[key] = value
  100. for subject, attr, mapping_key in ops:
  101. _attr = getattr(subject, attr, None)
  102. if subject and _attr:
  103. mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
  104. if isinstance(_attr, str):
  105. params[mapping_value] = _attr
  106. elif isinstance(_attr, datetime):
  107. params[mapping_value] = _attr.isoformat()
  108. elif isinstance(_attr, list):
  109. # os env variable value needs to be string
  110. params[mapping_value] = ",".join(_attr)
  111. else:
  112. params[mapping_value] = str(_attr)
  113. return params
  114. class KeywordParameters:
  115. """
  116. Wrapper representing ``**kwargs`` to a callable.
  117. The actual ``kwargs`` can be obtained by calling either ``unpacking()`` or
  118. ``serializing()``. They behave almost the same and are only different if
  119. the containing ``kwargs`` is an Airflow Context object, and the calling
  120. function uses ``**kwargs`` in the argument list.
  121. In this particular case, ``unpacking()`` uses ``lazy-object-proxy`` to
  122. prevent the Context from emitting deprecation warnings too eagerly when it's
  123. unpacked by ``**``. ``serializing()`` does not do this, and will allow the
  124. warnings to be emitted eagerly, which is useful when you want to dump the
  125. content and use it somewhere else without needing ``lazy-object-proxy``.
  126. """
  127. def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None:
  128. self._kwargs = kwargs
  129. self._wildcard = wildcard
  130. @classmethod
  131. def determine(
  132. cls,
  133. func: Callable[..., Any],
  134. args: Collection[Any],
  135. kwargs: Mapping[str, Any],
  136. ) -> KeywordParameters:
  137. import inspect
  138. import itertools
  139. signature = inspect.signature(func)
  140. has_wildcard_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature.parameters.values())
  141. for name in itertools.islice(signature.parameters.keys(), len(args)):
  142. # Check if args conflict with names in kwargs.
  143. if name in kwargs:
  144. raise ValueError(f"The key {name!r} in args is a part of kwargs and therefore reserved.")
  145. if has_wildcard_kwargs:
  146. # If the callable has a **kwargs argument, it's ready to accept all the kwargs.
  147. return cls(kwargs, wildcard=True)
  148. # If the callable has no **kwargs argument, it only wants the arguments it requested.
  149. kwargs = {key: kwargs[key] for key in signature.parameters if key in kwargs}
  150. return cls(kwargs, wildcard=False)
  151. def unpacking(self) -> Mapping[str, Any]:
  152. """Dump the kwargs mapping to unpack with ``**`` in a function call."""
  153. if self._wildcard and isinstance(self._kwargs, Context): # type: ignore[misc]
  154. return lazy_mapping_from_context(self._kwargs)
  155. return self._kwargs
  156. def serializing(self) -> Mapping[str, Any]:
  157. """Dump the kwargs mapping for serialization purposes."""
  158. return self._kwargs
  159. def determine_kwargs(
  160. func: Callable[..., Any],
  161. args: Collection[Any],
  162. kwargs: Mapping[str, Any],
  163. ) -> Mapping[str, Any]:
  164. """
  165. Inspect the signature of a callable to determine which kwargs need to be passed to the callable.
  166. :param func: The callable that you want to invoke
  167. :param args: The positional arguments that need to be passed to the callable, so we know how many to skip.
  168. :param kwargs: The keyword arguments that need to be filtered before passing to the callable.
  169. :return: A dictionary which contains the keyword arguments that are compatible with the callable.
  170. """
  171. return KeywordParameters.determine(func, args, kwargs).unpacking()
  172. def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
  173. """
  174. Create a new callable that only forwards necessary arguments from any provided input.
  175. Make a new callable that can accept any number of positional or keyword arguments
  176. but only forwards those required by the given callable func.
  177. """
  178. import functools
  179. @functools.wraps(func)
  180. def kwargs_func(*args, **kwargs):
  181. kwargs = determine_kwargs(func, args, kwargs)
  182. return func(*args, **kwargs)
  183. return kwargs_func
  184. class ExecutionCallableRunner:
  185. """
  186. Run an execution callable against a task context and given arguments.
  187. If the callable is a simple function, this simply calls it with the supplied
  188. arguments (including the context). If the callable is a generator function,
  189. the generator is exhausted here, with the yielded values getting fed back
  190. into the task context automatically for execution.
  191. :meta private:
  192. """
  193. def __init__(
  194. self,
  195. func: Callable,
  196. outlet_events: OutletEventAccessors,
  197. *,
  198. logger: logging.Logger | None,
  199. ) -> None:
  200. self.func = func
  201. self.outlet_events = outlet_events
  202. self.logger = logger or logging.getLogger(__name__)
  203. def run(self, *args, **kwargs) -> Any:
  204. import inspect
  205. from airflow.datasets.metadata import Metadata
  206. from airflow.utils.types import NOTSET
  207. if not inspect.isgeneratorfunction(self.func):
  208. return self.func(*args, **kwargs)
  209. result: Any = NOTSET
  210. def _run():
  211. nonlocal result
  212. result = yield from self.func(*args, **kwargs)
  213. for metadata in _run():
  214. if isinstance(metadata, Metadata):
  215. self.outlet_events[metadata.uri].extra.update(metadata.extra)
  216. if metadata.alias_name:
  217. self.outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra)
  218. continue
  219. self.logger.warning("Ignoring unknown data of %r received from task", type(metadata))
  220. if self.logger.isEnabledFor(logging.DEBUG):
  221. self.logger.debug("Full yielded value: %r", metadata)
  222. return result