#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, TypeVar

from airflow import settings
from airflow.utils.context import Context, lazy_mapping_from_context

if TYPE_CHECKING:
    from airflow.utils.context import OutletEventAccessors

R = TypeVar("R")

DEFAULT_FORMAT_PREFIX = "airflow.ctx."
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_EXECUTION_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}execution_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}EXECUTION_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}


def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Defaults to the ``abc.def.ghi`` format; the ``ABC_DEF_GHI`` format is used instead if
    ``in_env_var_format`` is set to True.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as dict.
    """
    params = {}
    if in_env_var_format:
        name_format = "env_var_format"
    else:
        name_format = "default"

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    ops = [
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (task_instance, "execution_date", "AIRFLOW_CONTEXT_EXECUTION_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ]

    context_params = settings.get_airflow_context_vars(context)
    for key, value in context_params.items():
        if not isinstance(key, str):
            raise TypeError(f"key <{key}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key}> must be string, not {type(value)}")

        if in_env_var_format:
            if not key.startswith(ENV_VAR_FORMAT_PREFIX):
                key = ENV_VAR_FORMAT_PREFIX + key.upper()
        else:
            if not key.startswith(DEFAULT_FORMAT_PREFIX):
                key = DEFAULT_FORMAT_PREFIX + key
        params[key] = value

    for subject, attr, mapping_key in ops:
        _attr = getattr(subject, attr, None)
        if subject and _attr:
            mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
            if isinstance(_attr, str):
                params[mapping_value] = _attr
            elif isinstance(_attr, datetime):
                params[mapping_value] = _attr.isoformat()
            elif isinstance(_attr, list):
                # os env variable value needs to be string
                params[mapping_value] = ",".join(_attr)
            else:
                params[mapping_value] = str(_attr)

    return params
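

# Illustrative example of what context_to_airflow_vars() returns, assuming the default
# settings.get_airflow_context_vars hook (which adds nothing). The SimpleNamespace
# objects and every value below are made up; they stand in for the real task,
# task-instance, and dag-run objects an actual Airflow context carries.
#
#     from types import SimpleNamespace
#
#     stub_context = {
#         "task": SimpleNamespace(email=["alerts@example.com"], owner="airflow"),
#         "task_instance": SimpleNamespace(dag_id="example_dag", task_id="extract", try_number=1),
#         "dag_run": SimpleNamespace(run_id="manual__2024-01-01T00:00:00"),
#     }
#     context_to_airflow_vars(stub_context)
#     # {'airflow.ctx.dag_email': 'alerts@example.com', 'airflow.ctx.dag_owner': 'airflow',
#     #  'airflow.ctx.dag_id': 'example_dag', 'airflow.ctx.task_id': 'extract',
#     #  'airflow.ctx.try_number': '1', 'airflow.ctx.dag_run_id': 'manual__2024-01-01T00:00:00'}
#
# With in_env_var_format=True the same values come back under AIRFLOW_CTX_* keys,
# suitable for injecting into a subprocess environment.

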
class KeywordParameters:
    """
    Wrapper representing ``**kwargs`` to a callable.

    The actual ``kwargs`` can be obtained by calling either ``unpacking()`` or
    ``serializing()``. They behave almost the same and are only different if
    the containing ``kwargs`` is an Airflow Context object, and the calling
    function uses ``**kwargs`` in the argument list.

    In this particular case, ``unpacking()`` uses ``lazy-object-proxy`` to
    prevent the Context from emitting deprecation warnings too eagerly when it's
    unpacked by ``**``. ``serializing()`` does not do this, and will allow the
    warnings to be emitted eagerly, which is useful when you want to dump the
    content and use it somewhere else without needing ``lazy-object-proxy``.
    """

    def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None:
        self._kwargs = kwargs
        self._wildcard = wildcard

    @classmethod
    def determine(
        cls,
        func: Callable[..., Any],
        args: Collection[Any],
        kwargs: Mapping[str, Any],
    ) -> KeywordParameters:
        import inspect
        import itertools

        signature = inspect.signature(func)
        has_wildcard_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature.parameters.values())

        for name in itertools.islice(signature.parameters.keys(), len(args)):
            # Check if args conflict with names in kwargs.
            if name in kwargs:
                raise ValueError(f"The key {name!r} in args is a part of kwargs and therefore reserved.")

        if has_wildcard_kwargs:
            # If the callable has a **kwargs argument, it's ready to accept all the kwargs.
            return cls(kwargs, wildcard=True)

        # If the callable has no **kwargs argument, it only wants the arguments it requested.
        kwargs = {key: kwargs[key] for key in signature.parameters if key in kwargs}
        return cls(kwargs, wildcard=False)

    def unpacking(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping to unpack with ``**`` in a function call."""
        if self._wildcard and isinstance(self._kwargs, Context):  # type: ignore[misc]
            return lazy_mapping_from_context(self._kwargs)
        return self._kwargs

    def serializing(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping for serialization purposes."""
        return self._kwargs
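

# Sketch of how the filtering behaves for a callable without a **kwargs parameter.
# The callable and argument values below are hypothetical.
#
#     def render(path, mode="w"):
#         ...
#
#     params = KeywordParameters.determine(
#         render, args=(), kwargs={"path": "/tmp/out.txt", "mode": "a", "ts": "2024-01-01"}
#     )
#     params.unpacking()  # {'path': '/tmp/out.txt', 'mode': 'a'} -- 'ts' is dropped
#
# Had render() declared **kwargs, all three keys would be kept (wildcard=True), and
# unpacking() would additionally wrap an Airflow Context lazily before ** expansion.

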
def determine_kwargs(
    func: Callable[..., Any],
    args: Collection[Any],
    kwargs: Mapping[str, Any],
) -> Mapping[str, Any]:
    """
    Inspect the signature of a callable to determine which kwargs need to be passed to the callable.

    :param func: The callable that you want to invoke
    :param args: The positional arguments that need to be passed to the callable, so we know how many to skip.
    :param kwargs: The keyword arguments that need to be filtered before passing to the callable.
    :return: A dictionary which contains the keyword arguments that are compatible with the callable.
    """
    return KeywordParameters.determine(func, args, kwargs).unpacking()


def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
    """
    Create a new callable that only forwards necessary arguments from any provided input.

    Make a new callable that can accept any number of positional or keyword arguments
    but only forwards those required by the given callable func.
    """
    import functools

    @functools.wraps(func)
    def kwargs_func(*args, **kwargs):
        kwargs = determine_kwargs(func, args, kwargs)
        return func(*args, **kwargs)

    return kwargs_func
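

# Hypothetical usage sketch for make_kwargs_callable(): wrapping a plain function so
# that surplus keyword arguments (for example, the many entries of an Airflow context)
# are silently discarded instead of raising TypeError. The function below is made up.
#
#     def notify(message, ds=None):
#         print(f"{ds}: {message}")
#
#     flexible_notify = make_kwargs_callable(notify)
#     flexible_notify(message="done", ds="2024-01-01", ti=object(), run_id="manual__1")
#     # notify() only ever receives 'message' and 'ds'.

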
class ExecutionCallableRunner:
    """
    Run an execution callable against a task context and given arguments.

    If the callable is a simple function, this simply calls it with the supplied
    arguments (including the context). If the callable is a generator function,
    the generator is exhausted here, with the yielded values getting fed back
    into the task context automatically for execution.

    :meta private:
    """

    def __init__(
        self,
        func: Callable,
        outlet_events: OutletEventAccessors,
        *,
        logger: logging.Logger | None,
    ) -> None:
        self.func = func
        self.outlet_events = outlet_events
        self.logger = logger or logging.getLogger(__name__)

    def run(self, *args, **kwargs) -> Any:
        import inspect

        from airflow.datasets.metadata import Metadata
        from airflow.utils.types import NOTSET

        if not inspect.isgeneratorfunction(self.func):
            return self.func(*args, **kwargs)

        result: Any = NOTSET

        def _run():
            nonlocal result
            result = yield from self.func(*args, **kwargs)

        for metadata in _run():
            if isinstance(metadata, Metadata):
                self.outlet_events[metadata.uri].extra.update(metadata.extra)

                if metadata.alias_name:
                    self.outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra)

                continue

            self.logger.warning("Ignoring unknown data of %r received from task", type(metadata))
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug("Full yielded value: %r", metadata)

        return result
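

# Rough sketch of the generator path handled by run() above. The callable, URI, and
# extra payload are invented for illustration; outlet_events, log, and context are
# assumed to be supplied by the task runtime (an OutletEventAccessors instance, a
# logger, and the task's execution context, respectively).
#
#     from airflow.datasets.metadata import Metadata
#
#     def produce(*, context):
#         yield Metadata("s3://bucket/processed/2024-01-01.csv", {"rows": 42})
#         return "done"
#
#     runner = ExecutionCallableRunner(produce, outlet_events, logger=log)
#     result = runner.run(context=context)
#     # result == "done", and {"rows": 42} has been merged into the extras of
#     # outlet_events["s3://bucket/processed/2024-01-01.csv"]; anything yielded that
#     # is not a Metadata instance is logged and ignored.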
|