| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- """Provides lineage support functions."""
- from __future__ import annotations
- import logging
- from functools import wraps
- from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
- from airflow.configuration import conf
- from airflow.lineage.backend import LineageBackend
- from airflow.utils.session import create_session
- if TYPE_CHECKING:
- from airflow.utils.context import Context
- PIPELINE_OUTLETS = "pipeline_outlets"
- PIPELINE_INLETS = "pipeline_inlets"
- AUTO = "auto"
- log = logging.getLogger(__name__)
- def get_backend() -> LineageBackend | None:
- """Get the lineage backend if defined in the configs."""
- clazz = conf.getimport("lineage", "backend", fallback=None)
- if clazz:
- if not issubclass(clazz, LineageBackend):
- raise TypeError(
- f"Your custom Lineage class `{clazz.__name__}` "
- f"is not a subclass of `{LineageBackend.__name__}`."
- )
- else:
- return clazz()
- return None
- def _render_object(obj: Any, context: Context) -> dict:
- ti = context["ti"]
- if TYPE_CHECKING:
- assert ti.task
- return ti.task.render_template(obj, context)
- T = TypeVar("T", bound=Callable)
- def apply_lineage(func: T) -> T:
- """
- Conditionally send lineage to the backend.
- Saves the lineage to XCom and if configured to do so sends it
- to the backend.
- """
- _backend = get_backend()
- @wraps(func)
- def wrapper(self, context, *args, **kwargs):
- self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)
- ret_val = func(self, context, *args, **kwargs)
- outlets = list(self.outlets)
- inlets = list(self.inlets)
- if outlets:
- self.xcom_push(context, key=PIPELINE_OUTLETS, value=outlets)
- if inlets:
- self.xcom_push(context, key=PIPELINE_INLETS, value=inlets)
- if _backend:
- _backend.send_lineage(operator=self, inlets=self.inlets, outlets=self.outlets, context=context)
- return ret_val
- return cast(T, wrapper)
- def prepare_lineage(func: T) -> T:
- """
- Prepare the lineage inlets and outlets.
- Inlets can be:
- * "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that
- if A -> B -> C and B does not have outlets but A does, these are provided as inlets.
- * "list of task_ids" -> picks up outlets from the upstream task_ids
- * "list of datasets" -> manually defined list of data
- """
- @wraps(func)
- def wrapper(self, context, *args, **kwargs):
- from airflow.models.abstractoperator import AbstractOperator
- self.log.debug("Preparing lineage inlets and outlets")
- if isinstance(self.inlets, (str, AbstractOperator)):
- self.inlets = [self.inlets]
- if self.inlets and isinstance(self.inlets, list):
- # get task_ids that are specified as parameter and make sure they are upstream
- task_ids = {o for o in self.inlets if isinstance(o, str)}.union(
- op.task_id for op in self.inlets if isinstance(op, AbstractOperator)
- ).intersection(self.get_flat_relative_ids(upstream=True))
- # pick up unique direct upstream task_ids if AUTO is specified
- if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets:
- task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))
- # Remove auto and task_ids
- self.inlets = [i for i in self.inlets if not isinstance(i, str)]
- # We manually create a session here since xcom_pull returns a
- # LazySelectSequence proxy. If we do not pass a session, a new one
- # will be created, but that session will not be properly closed.
- # After we are done iterating, we can safely close this session.
- with create_session() as session:
- _inlets = self.xcom_pull(
- context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
- )
- self.inlets.extend(i for it in _inlets for i in it)
- elif self.inlets:
- raise AttributeError("inlets is not a list, operator, string or attr annotated object")
- if not isinstance(self.outlets, list):
- self.outlets = [self.outlets]
- # render inlets and outlets
- self.inlets = [_render_object(i, context) for i in self.inlets]
- self.outlets = [_render_object(i, context) for i in self.outlets]
- self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
- return func(self, context, *args, **kwargs)
- return cast(T, wrapper)
|