manager.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING

from sqlalchemy import exc, select
from sqlalchemy.orm import joinedload

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.models.dataset import (
    DagScheduleDatasetAliasReference,
    DagScheduleDatasetReference,
    DatasetAliasModel,
    DatasetDagRunQueue,
    DatasetEvent,
    DatasetModel,
)
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.models.dag import DagModel
    from airflow.models.taskinstance import TaskInstance

class DatasetManager(LoggingMixin):
    """
    A pluggable class that manages operations for datasets.

    The intent is to have one place to handle all Dataset-related operations, so different
    Airflow deployments can use plugins that broadcast dataset events to each other.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
        """Create new datasets."""
        for dataset_model in dataset_models:
            session.add(dataset_model)
        session.flush()

        for dataset_model in dataset_models:
            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
    @classmethod
    @internal_api_call
    @provide_session
    def register_dataset_change(
        cls,
        *,
        task_instance: TaskInstance | None = None,
        dataset: Dataset,
        extra=None,
        session: Session = NEW_SESSION,
        source_alias_names: Iterable[str] | None = None,
        **kwargs,
    ) -> DatasetEvent | None:
        """
        Register dataset-related changes.

        For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
        the dataset event.
        """
        # todo: add test so that all usages of internal_api_call are added to rpc endpoint
        dataset_model = session.scalar(
            select(DatasetModel)
            .where(DatasetModel.uri == dataset.uri)
            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
        )
        if not dataset_model:
            cls.logger().warning("DatasetModel %s not found", dataset)
            return None

        # Record the dataset event, attributing it to the source task instance if given.
        event_kwargs = {
            "dataset_id": dataset_model.id,
            "extra": extra,
        }
        if task_instance:
            event_kwargs.update(
                {
                    "source_task_id": task_instance.task_id,
                    "source_dag_id": task_instance.dag_id,
                    "source_run_id": task_instance.run_id,
                    "source_map_index": task_instance.map_index,
                }
            )
        dataset_event = DatasetEvent(**event_kwargs)
        session.add(dataset_event)

        # Queue only active, unpaused DAGs that consume this dataset directly.
        dags_to_queue_from_dataset = {
            ref.dag for ref in dataset_model.consuming_dags if ref.dag.is_active and not ref.dag.is_paused
        }

        # Attach the event to any source aliases and collect DAGs consuming via those aliases.
        dags_to_queue_from_dataset_alias = set()
        if source_alias_names:
            dataset_alias_models = session.scalars(
                select(DatasetAliasModel)
                .where(DatasetAliasModel.name.in_(source_alias_names))
                .options(
                    joinedload(DatasetAliasModel.consuming_dags).joinedload(
                        DagScheduleDatasetAliasReference.dag
                    )
                )
            ).unique()
            for dsa in dataset_alias_models:
                dsa.dataset_events.append(dataset_event)
                session.add(dsa)
                dags_to_queue_from_dataset_alias |= {
                    alias_ref.dag
                    for alias_ref in dsa.consuming_dags
                    if alias_ref.dag.is_active and not alias_ref.dag.is_paused
                }

        # DAGs reachable only through an alias are sent for priority re-parsing.
        dags_to_reparse = dags_to_queue_from_dataset_alias - dags_to_queue_from_dataset
        if dags_to_reparse:
            file_locs = {dag.fileloc for dag in dags_to_reparse}
            cls._send_dag_priority_parsing_request(file_locs, session)

        session.flush()
        cls.notify_dataset_changed(dataset=dataset)

        Stats.incr("dataset.updates")

        dags_to_queue = dags_to_queue_from_dataset | dags_to_queue_from_dataset_alias
        cls._queue_dagruns(dataset_id=dataset_model.id, dags_to_queue=dags_to_queue, session=session)
        session.flush()
        return dataset_event
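    # Illustrative usage sketch (not part of this module): a caller such as the scheduler
    # records an event for a task's Dataset outlet roughly like this, where the URI is
    # hypothetical:
    #
    #   dataset_manager.register_dataset_change(
    #       task_instance=ti,
    #       dataset=Dataset("s3://my-bucket/my-key"),
    #       session=session,
    #   )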
    def notify_dataset_created(self, dataset: Dataset):
        """Run applicable notification actions when a dataset is created."""
        get_listener_manager().hook.on_dataset_created(dataset=dataset)

    @classmethod
    def notify_dataset_changed(cls, dataset: Dataset):
        """Run applicable notification actions when a dataset is changed."""
        get_listener_manager().hook.on_dataset_changed(dataset=dataset)

    @classmethod
    def _queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session) -> None:
        # Possible race condition: if multiple dags or multiple (usually
        # mapped) tasks update the same dataset, this can fail with a unique
        # constraint violation.
        #
        # If we support it, use ON CONFLICT to do nothing, otherwise
        # "fallback" to running this in a nested transaction. This is needed
        # so that the adding of these rows happens in the same transaction
        # where `ti.state` is changed.
        if not dags_to_queue:
            return

        if session.bind.dialect.name == "postgresql":
            return cls._postgres_queue_dagruns(dataset_id, dags_to_queue, session)
        return cls._slow_path_queue_dagruns(dataset_id, dags_to_queue, session)
    @classmethod
    def _slow_path_queue_dagruns(
        cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session
    ) -> None:
        def _queue_dagrun_if_needed(dag: DagModel) -> str | None:
            item = DatasetDagRunQueue(target_dag_id=dag.dag_id, dataset_id=dataset_id)
            # Don't error whole transaction when a single RunQueue item conflicts.
            # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
            try:
                with session.begin_nested():
                    session.merge(item)
            except exc.IntegrityError:
                cls.logger().debug("Skipping record %s", item, exc_info=True)
            return dag.dag_id

        queued_results = (_queue_dagrun_if_needed(dag) for dag in dags_to_queue)
        if queued_dag_ids := [r for r in queued_results if r is not None]:
            cls.logger().debug("consuming dag ids %s", queued_dag_ids)
    @classmethod
    def _postgres_queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session) -> None:
        from sqlalchemy.dialects.postgresql import insert

        values = [{"target_dag_id": dag.dag_id} for dag in dags_to_queue]
        stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset_id).on_conflict_do_nothing()
        session.execute(stmt, values)
    @classmethod
    def _send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
        if session.bind.dialect.name == "postgresql":
            return cls._postgres_send_dag_priority_parsing_request(file_locs, session)
        return cls._slow_path_send_dag_priority_parsing_request(file_locs, session)

    @classmethod
    def _slow_path_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
        def _send_dag_priority_parsing_request_if_needed(fileloc: str) -> str | None:
            # Don't error whole transaction when a single DagPriorityParsingRequest item conflicts.
            # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
            req = DagPriorityParsingRequest(fileloc=fileloc)
            try:
                with session.begin_nested():
                    session.merge(req)
            except exc.IntegrityError:
                cls.logger().debug("Skipping request %s, already present", req, exc_info=True)
                return None
            return req.fileloc

        for fileloc in file_locs:
            _send_dag_priority_parsing_request_if_needed(fileloc)
    @classmethod
    def _postgres_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
        from sqlalchemy.dialects.postgresql import insert

        stmt = insert(DagPriorityParsingRequest).on_conflict_do_nothing()
        session.execute(stmt, [{"fileloc": fileloc} for fileloc in file_locs])
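# Illustrative sketch (hypothetical, not part of Airflow): because the manager is
# pluggable, a deployment could subclass it to broadcast dataset events to an
# external system, e.g.
#
#   class BroadcastingDatasetManager(DatasetManager):
#       @classmethod
#       def notify_dataset_changed(cls, dataset: Dataset):
#           super().notify_dataset_changed(dataset)
#           publish_to_external_bus(dataset.uri)  # hypothetical helper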
def resolve_dataset_manager() -> DatasetManager:
    """Retrieve the dataset manager."""
    _dataset_manager_class = conf.getimport(
        section="core",
        key="dataset_manager_class",
        fallback="airflow.datasets.manager.DatasetManager",
    )
    _dataset_manager_kwargs = conf.getjson(
        section="core",
        key="dataset_manager_kwargs",
        fallback={},
    )
    return _dataset_manager_class(**_dataset_manager_kwargs)
dataset_manager = resolve_dataset_manager()
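# Configuration sketch: ``resolve_dataset_manager`` reads these airflow.cfg keys
# (shown with a hypothetical custom class and kwargs; both fall back to the
# defaults used above when unset):
#
#   [core]
#   dataset_manager_class = my_plugin.managers.BroadcastingDatasetManager
#   dataset_manager_kwargs = {"broker_url": "kafka://broker:9092"}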