__init__.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Provides lineage support functions."""
  19. from __future__ import annotations
  20. import logging
  21. from functools import wraps
  22. from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
  23. from airflow.configuration import conf
  24. from airflow.lineage.backend import LineageBackend
  25. from airflow.utils.session import create_session
  26. if TYPE_CHECKING:
  27. from airflow.utils.context import Context
# XCom keys under which apply_lineage pushes a task's outlets and inlets.
PIPELINE_OUTLETS = "pipeline_outlets"
PIPELINE_INLETS = "pipeline_inlets"
# Inlet marker; prepare_lineage looks for its upper- and lower-case spellings.
AUTO = "auto"
# Module-level logger.
log = logging.getLogger(__name__)
  32. def get_backend() -> LineageBackend | None:
  33. """Get the lineage backend if defined in the configs."""
  34. clazz = conf.getimport("lineage", "backend", fallback=None)
  35. if clazz:
  36. if not issubclass(clazz, LineageBackend):
  37. raise TypeError(
  38. f"Your custom Lineage class `{clazz.__name__}` "
  39. f"is not a subclass of `{LineageBackend.__name__}`."
  40. )
  41. else:
  42. return clazz()
  43. return None
  44. def _render_object(obj: Any, context: Context) -> dict:
  45. ti = context["ti"]
  46. if TYPE_CHECKING:
  47. assert ti.task
  48. return ti.task.render_template(obj, context)
# Type variable letting the decorators in this module be annotated as returning
# the same callable type they receive.
T = TypeVar("T", bound=Callable)
  50. def apply_lineage(func: T) -> T:
  51. """
  52. Conditionally send lineage to the backend.
  53. Saves the lineage to XCom and if configured to do so sends it
  54. to the backend.
  55. """
  56. _backend = get_backend()
  57. @wraps(func)
  58. def wrapper(self, context, *args, **kwargs):
  59. self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)
  60. ret_val = func(self, context, *args, **kwargs)
  61. outlets = list(self.outlets)
  62. inlets = list(self.inlets)
  63. if outlets:
  64. self.xcom_push(context, key=PIPELINE_OUTLETS, value=outlets)
  65. if inlets:
  66. self.xcom_push(context, key=PIPELINE_INLETS, value=inlets)
  67. if _backend:
  68. _backend.send_lineage(operator=self, inlets=self.inlets, outlets=self.outlets, context=context)
  69. return ret_val
  70. return cast(T, wrapper)
  71. def prepare_lineage(func: T) -> T:
  72. """
  73. Prepare the lineage inlets and outlets.
  74. Inlets can be:
  75. * "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that
  76. if A -> B -> C and B does not have outlets but A does, these are provided as inlets.
  77. * "list of task_ids" -> picks up outlets from the upstream task_ids
  78. * "list of datasets" -> manually defined list of data
  79. """
  80. @wraps(func)
  81. def wrapper(self, context, *args, **kwargs):
  82. from airflow.models.abstractoperator import AbstractOperator
  83. self.log.debug("Preparing lineage inlets and outlets")
  84. if isinstance(self.inlets, (str, AbstractOperator)):
  85. self.inlets = [self.inlets]
  86. if self.inlets and isinstance(self.inlets, list):
  87. # get task_ids that are specified as parameter and make sure they are upstream
  88. task_ids = {o for o in self.inlets if isinstance(o, str)}.union(
  89. op.task_id for op in self.inlets if isinstance(op, AbstractOperator)
  90. ).intersection(self.get_flat_relative_ids(upstream=True))
  91. # pick up unique direct upstream task_ids if AUTO is specified
  92. if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets:
  93. task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))
  94. # Remove auto and task_ids
  95. self.inlets = [i for i in self.inlets if not isinstance(i, str)]
  96. # We manually create a session here since xcom_pull returns a
  97. # LazySelectSequence proxy. If we do not pass a session, a new one
  98. # will be created, but that session will not be properly closed.
  99. # After we are done iterating, we can safely close this session.
  100. with create_session() as session:
  101. _inlets = self.xcom_pull(
  102. context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
  103. )
  104. self.inlets.extend(i for it in _inlets for i in it)
  105. elif self.inlets:
  106. raise AttributeError("inlets is not a list, operator, string or attr annotated object")
  107. if not isinstance(self.outlets, list):
  108. self.outlets = [self.outlets]
  109. # render inlets and outlets
  110. self.inlets = [_render_object(i, context) for i in self.inlets]
  111. self.outlets = [_render_object(i, context) for i in self.outlets]
  112. self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
  113. return func(self, context, *args, **kwargs)
  114. return cast(T, wrapper)