#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING

from sqlalchemy import exc, select
from sqlalchemy.orm import joinedload

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.models.dataset import (
    DagScheduleDatasetAliasReference,
    DagScheduleDatasetReference,
    DatasetAliasModel,
    DatasetDagRunQueue,
    DatasetEvent,
    DatasetModel,
)
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.models.dag import DagModel
    from airflow.models.taskinstance import TaskInstance

class DatasetManager(LoggingMixin):
    """
    A pluggable class that manages operations for datasets.

    The intent is to have one place to handle all Dataset-related operations, so different
    Airflow deployments can use plugins that broadcast dataset events to each other. An
    illustrative subclass sketch follows the class definition below.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
        """Create new datasets."""
        for dataset_model in dataset_models:
            session.add(dataset_model)
        session.flush()  # Persist the new rows (and assign primary keys) before notifying listeners.

        for dataset_model in dataset_models:
            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
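    # Illustrative sketch (kept as a comment so this module is unchanged at
    # runtime): a listener plugin receiving the notification emitted above.
    # ``on_dataset_created`` is the hook this manager invokes through the
    # listener manager; the print body is a placeholder.
    #
    #   from airflow.listeners import hookimpl
    #
    #   @hookimpl
    #   def on_dataset_created(dataset):
    #       print(f"dataset created: {dataset.uri}")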
    @classmethod
    @internal_api_call
    @provide_session
    def register_dataset_change(
        cls,
        *,
        task_instance: TaskInstance | None = None,
        dataset: Dataset,
        extra=None,
        session: Session = NEW_SESSION,
        source_alias_names: Iterable[str] | None = None,
        **kwargs,
    ) -> DatasetEvent | None:
        """
        Register dataset related changes.

        For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
        the dataset event.
        """
        # todo: add test so that all usages of internal_api_call are added to rpc endpoint
        dataset_model = session.scalar(
            select(DatasetModel)
            .where(DatasetModel.uri == dataset.uri)
            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
        )
        if not dataset_model:
            cls.logger().warning("DatasetModel %s not found", dataset)
            return None

        event_kwargs = {
            "dataset_id": dataset_model.id,
            "extra": extra,
        }
        if task_instance:
            event_kwargs.update(
                {
                    "source_task_id": task_instance.task_id,
                    "source_dag_id": task_instance.dag_id,
                    "source_run_id": task_instance.run_id,
                    "source_map_index": task_instance.map_index,
                }
            )
        dataset_event = DatasetEvent(**event_kwargs)
        session.add(dataset_event)

        # DAGs scheduled directly on this dataset; only active, unpaused DAGs get queued.
        dags_to_queue_from_dataset = {
            ref.dag for ref in dataset_model.consuming_dags if ref.dag.is_active and not ref.dag.is_paused
        }

        dags_to_queue_from_dataset_alias = set()
        if source_alias_names:
            dataset_alias_models = session.scalars(
                select(DatasetAliasModel)
                .where(DatasetAliasModel.name.in_(source_alias_names))
                .options(
                    joinedload(DatasetAliasModel.consuming_dags).joinedload(
                        DagScheduleDatasetAliasReference.dag
                    )
                )
            ).unique()

            for dsa in dataset_alias_models:
                dsa.dataset_events.append(dataset_event)
                session.add(dsa)

                dags_to_queue_from_dataset_alias |= {
                    alias_ref.dag
                    for alias_ref in dsa.consuming_dags
                    if alias_ref.dag.is_active and not alias_ref.dag.is_paused
                }

        # DAGs that consume the dataset only through an alias are sent for a
        # priority re-parse so the newly resolved dependency is recorded.
        dags_to_reparse = dags_to_queue_from_dataset_alias - dags_to_queue_from_dataset
        if dags_to_reparse:
            file_locs = {dag.fileloc for dag in dags_to_reparse}
            cls._send_dag_priority_parsing_request(file_locs, session)

        session.flush()
        cls.notify_dataset_changed(dataset=dataset)

        Stats.incr("dataset.updates")

        dags_to_queue = dags_to_queue_from_dataset | dags_to_queue_from_dataset_alias
        cls._queue_dagruns(dataset_id=dataset_model.id, dags_to_queue=dags_to_queue, session=session)
        session.flush()
        return dataset_event
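    # Illustrative sketch (kept as a comment so this module is unchanged at
    # runtime): the producer/consumer pair that exercises the flow above. A
    # task declaring ``outlets`` triggers register_dataset_change on success,
    # and an active, unpaused DAG scheduled on the same dataset is queued via
    # DatasetDagRunQueue. DAG names, the URI, and omitted arguments such as
    # start_date are placeholders.
    #
    #   from airflow.datasets import Dataset
    #   from airflow.decorators import dag, task
    #
    #   example = Dataset("s3://bucket/example.csv")
    #
    #   @dag(schedule=None)
    #   def producer():
    #       @task(outlets=[example])
    #       def update_example():
    #           ...
    #
    #       update_example()
    #
    #   @dag(schedule=[example])
    #   def consumer():
    #       ...
    #
    #   producer()
    #   consumer()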
    def notify_dataset_created(self, dataset: Dataset):
        """Run applicable notification actions when a dataset is created."""
        get_listener_manager().hook.on_dataset_created(dataset=dataset)

    @classmethod
    def notify_dataset_changed(cls, dataset: Dataset):
        """Run applicable notification actions when a dataset is changed."""
        get_listener_manager().hook.on_dataset_changed(dataset=dataset)

    @classmethod
    def _queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session) -> None:
        # Possible race condition: if multiple dags or multiple (usually
        # mapped) tasks update the same dataset, this can fail with a unique
        # constraint violation.
        #
        # If we support it, use ON CONFLICT to do nothing, otherwise
        # "fallback" to running this in a nested transaction. This is needed
        # so that the adding of these rows happens in the same transaction
        # where `ti.state` is changed.
        if not dags_to_queue:
            return

        if session.bind.dialect.name == "postgresql":
            return cls._postgres_queue_dagruns(dataset_id, dags_to_queue, session)
        return cls._slow_path_queue_dagruns(dataset_id, dags_to_queue, session)

    @classmethod
    def _slow_path_queue_dagruns(
        cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session
    ) -> None:
        def _queue_dagrun_if_needed(dag: DagModel) -> str | None:
            item = DatasetDagRunQueue(target_dag_id=dag.dag_id, dataset_id=dataset_id)
            # Don't error whole transaction when a single RunQueue item conflicts.
            # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
            try:
                with session.begin_nested():
                    session.merge(item)
            except exc.IntegrityError:
                cls.logger().debug("Skipping record %s", item, exc_info=True)
            return dag.dag_id

        queued_results = (_queue_dagrun_if_needed(dag) for dag in dags_to_queue)
        if queued_dag_ids := [r for r in queued_results if r is not None]:
            cls.logger().debug("consuming dag ids %s", queued_dag_ids)

    @classmethod
    def _postgres_queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session) -> None:
        from sqlalchemy.dialects.postgresql import insert

        values = [{"target_dag_id": dag.dag_id} for dag in dags_to_queue]
        stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset_id).on_conflict_do_nothing()
        session.execute(stmt, values)

    @classmethod
    def _send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
        if session.bind.dialect.name == "postgresql":
            return cls._postgres_send_dag_priority_parsing_request(file_locs, session)
        return cls._slow_path_send_dag_priority_parsing_request(file_locs, session)

    @classmethod
    def _slow_path_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
        def _send_dag_priority_parsing_request_if_needed(fileloc: str) -> str | None:
            # Don't error whole transaction when a single DagPriorityParsingRequest item conflicts.
            # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
            req = DagPriorityParsingRequest(fileloc=fileloc)
            try:
                with session.begin_nested():
                    session.merge(req)
            except exc.IntegrityError:
                cls.logger().debug("Skipping request %s, already present", req, exc_info=True)
                return None
            return req.fileloc

        for fileloc in file_locs:
            _send_dag_priority_parsing_request_if_needed(fileloc)

    @classmethod
    def _postgres_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
        from sqlalchemy.dialects.postgresql import insert

        stmt = insert(DagPriorityParsingRequest).on_conflict_do_nothing()
        session.execute(stmt, [{"fileloc": fileloc} for fileloc in file_locs])
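# ---------------------------------------------------------------------------
# Illustrative sketch, not part of Airflow: one way a deployment could use the
# plugin point described in the DatasetManager docstring. The class name and
# the logging-only behaviour are assumptions for demonstration; a real plugin
# would typically forward the event to an external system and be activated via
# ``[core] dataset_manager_class`` (see resolve_dataset_manager below).
class ExampleLoggingDatasetManager(DatasetManager):
    """Example manager that logs dataset creations before delegating to the default behaviour."""

    def notify_dataset_created(self, dataset: Dataset):
        # Hypothetical extension point: replace this log call with a broadcast
        # to a message bus or another Airflow deployment.
        self.log.info("Dataset created: %s", dataset.uri)
        super().notify_dataset_created(dataset=dataset)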
def resolve_dataset_manager() -> DatasetManager:
    """Retrieve the dataset manager."""
    _dataset_manager_class = conf.getimport(
        section="core",
        key="dataset_manager_class",
        fallback="airflow.datasets.manager.DatasetManager",
    )
    _dataset_manager_kwargs = conf.getjson(
        section="core",
        key="dataset_manager_kwargs",
        fallback={},
    )
    return _dataset_manager_class(**_dataset_manager_kwargs)
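# Illustrative configuration sketch for the two options read above. The module
# path and the kwargs value are hypothetical placeholders:
#
#   [core]
#   dataset_manager_class = my_plugin.managers.MyDatasetManager
#   dataset_manager_kwargs = {"broadcast_topic": "dataset-events"}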
dataset_manager = resolve_dataset_manager()