# hook.py
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import hashlib
  20. import json
  21. from collections import defaultdict
  22. from typing import TYPE_CHECKING, Union
  23. import attr
  24. from airflow.datasets import Dataset
  25. from airflow.providers_manager import ProvidersManager
  26. from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
    from airflow.hooks.base import BaseHook
    from airflow.io.path import ObjectStoragePath

# A LineageContext identifies what reported the lineage: either the hook
# performing the operation or the ObjectStoragePath being accessed.
LineageContext = Union[BaseHook, ObjectStoragePath]

# Process-wide singleton, lazily created by get_hook_lineage_collector().
_hook_lineage_collector: HookLineageCollector | None = None
  33. @attr.define
  34. class DatasetLineageInfo:
  35. """
  36. Holds lineage information for a single dataset.
  37. This class represents the lineage information for a single dataset, including the dataset itself,
  38. the count of how many times it has been encountered, and the context in which it was encountered.
  39. """
  40. dataset: Dataset
  41. count: int
  42. context: LineageContext
  43. @attr.define
  44. class HookLineage:
  45. """
  46. Holds lineage collected by HookLineageCollector.
  47. This class represents the lineage information collected by the `HookLineageCollector`. It stores
  48. the input and output datasets, each with an associated count indicating how many times the dataset
  49. has been encountered during the hook execution.
  50. """
  51. inputs: list[DatasetLineageInfo] = attr.ib(factory=list)
  52. outputs: list[DatasetLineageInfo] = attr.ib(factory=list)
  53. class HookLineageCollector(LoggingMixin):
  54. """
  55. HookLineageCollector is a base class for collecting hook lineage information.
  56. It is used to collect the input and output datasets of a hook execution.
  57. """
  58. def __init__(self, **kwargs):
  59. super().__init__(**kwargs)
  60. # Dictionary to store input datasets, counted by unique key (dataset URI, MD5 hash of extra
  61. # dictionary, and LineageContext's unique identifier)
  62. self._inputs: dict[str, tuple[Dataset, LineageContext]] = {}
  63. self._outputs: dict[str, tuple[Dataset, LineageContext]] = {}
  64. self._input_counts: dict[str, int] = defaultdict(int)
  65. self._output_counts: dict[str, int] = defaultdict(int)
  66. def _generate_key(self, dataset: Dataset, context: LineageContext) -> str:
  67. """
  68. Generate a unique key for the given dataset and context.
  69. This method creates a unique key by combining the dataset URI, the MD5 hash of the dataset's extra
  70. dictionary, and the LineageContext's unique identifier. This ensures that the generated key is
  71. unique for each combination of dataset and context.
  72. """
  73. extra_str = json.dumps(dataset.extra, sort_keys=True)
  74. extra_hash = hashlib.md5(extra_str.encode()).hexdigest()
  75. return f"{dataset.uri}_{extra_hash}_{id(context)}"
  76. def create_dataset(
  77. self, scheme: str | None, uri: str | None, dataset_kwargs: dict | None, dataset_extra: dict | None
  78. ) -> Dataset | None:
  79. """
  80. Create a Dataset instance using the provided parameters.
  81. This method attempts to create a Dataset instance using the given parameters.
  82. It first checks if a URI is provided and falls back to using the default dataset factory
  83. with the given URI if no other information is available.
  84. If a scheme is provided but no URI, it attempts to find a dataset factory that matches
  85. the given scheme. If no such factory is found, it logs an error message and returns None.
  86. If dataset_kwargs is provided, it is used to pass additional parameters to the Dataset
  87. factory. The dataset_extra parameter is also passed to the factory as an ``extra`` parameter.
  88. """
  89. if uri:
  90. # Fallback to default factory using the provided URI
  91. return Dataset(uri=uri, extra=dataset_extra)
  92. if not scheme:
  93. self.log.debug(
  94. "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset."
  95. )
  96. return None
  97. dataset_factory = ProvidersManager().dataset_factories.get(scheme)
  98. if not dataset_factory:
  99. self.log.debug("Unsupported scheme: %s. Please provide a valid URI to create a Dataset.", scheme)
  100. return None
  101. dataset_kwargs = dataset_kwargs or {}
  102. try:
  103. return dataset_factory(**dataset_kwargs, extra=dataset_extra)
  104. except Exception as e:
  105. self.log.debug("Failed to create dataset. Skipping. Error: %s", e)
  106. return None
  107. def add_input_dataset(
  108. self,
  109. context: LineageContext,
  110. scheme: str | None = None,
  111. uri: str | None = None,
  112. dataset_kwargs: dict | None = None,
  113. dataset_extra: dict | None = None,
  114. ):
  115. """Add the input dataset and its corresponding hook execution context to the collector."""
  116. dataset = self.create_dataset(
  117. scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
  118. )
  119. if dataset:
  120. key = self._generate_key(dataset, context)
  121. if key not in self._inputs:
  122. self._inputs[key] = (dataset, context)
  123. self._input_counts[key] += 1
  124. def add_output_dataset(
  125. self,
  126. context: LineageContext,
  127. scheme: str | None = None,
  128. uri: str | None = None,
  129. dataset_kwargs: dict | None = None,
  130. dataset_extra: dict | None = None,
  131. ):
  132. """Add the output dataset and its corresponding hook execution context to the collector."""
  133. dataset = self.create_dataset(
  134. scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
  135. )
  136. if dataset:
  137. key = self._generate_key(dataset, context)
  138. if key not in self._outputs:
  139. self._outputs[key] = (dataset, context)
  140. self._output_counts[key] += 1
  141. @property
  142. def collected_datasets(self) -> HookLineage:
  143. """Get the collected hook lineage information."""
  144. return HookLineage(
  145. [
  146. DatasetLineageInfo(dataset=dataset, count=self._input_counts[key], context=context)
  147. for key, (dataset, context) in self._inputs.items()
  148. ],
  149. [
  150. DatasetLineageInfo(dataset=dataset, count=self._output_counts[key], context=context)
  151. for key, (dataset, context) in self._outputs.items()
  152. ],
  153. )
  154. @property
  155. def has_collected(self) -> bool:
  156. """Check if any datasets have been collected."""
  157. return len(self._inputs) != 0 or len(self._outputs) != 0
  158. class NoOpCollector(HookLineageCollector):
  159. """
  160. NoOpCollector is a hook lineage collector that does nothing.
  161. It is used when you want to disable lineage collection.
  162. """
  163. def add_input_dataset(self, *_, **__):
  164. pass
  165. def add_output_dataset(self, *_, **__):
  166. pass
  167. @property
  168. def collected_datasets(
  169. self,
  170. ) -> HookLineage:
  171. self.log.warning(
  172. "Data lineage tracking is disabled. Register a hook lineage reader to start tracking hook lineage."
  173. )
  174. return HookLineage([], [])
  175. class HookLineageReader(LoggingMixin):
  176. """Class used to retrieve the hook lineage information collected by HookLineageCollector."""
  177. def __init__(self, **kwargs):
  178. self.lineage_collector = get_hook_lineage_collector()
  179. def retrieve_hook_lineage(self) -> HookLineage:
  180. """Retrieve hook lineage from HookLineageCollector."""
  181. hook_lineage = self.lineage_collector.collected_datasets
  182. return hook_lineage
  183. def get_hook_lineage_collector() -> HookLineageCollector:
  184. """Get singleton lineage collector."""
  185. global _hook_lineage_collector
  186. if not _hook_lineage_collector:
  187. from airflow import plugins_manager
  188. plugins_manager.initialize_hook_lineage_readers_plugins()
  189. if plugins_manager.hook_lineage_reader_classes:
  190. _hook_lineage_collector = HookLineageCollector()
  191. else:
  192. _hook_lineage_collector = NoOpCollector()
  193. return _hook_lineage_collector