datasets.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import typing

from airflow.datasets import BaseDataset, DatasetAll
from airflow.exceptions import AirflowTimetableInvalid
from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule
from airflow.utils.types import DagRunType

if typing.TYPE_CHECKING:
    from collections.abc import Collection

    import pendulum

    from airflow.datasets import Dataset
    from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable


class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
    """Combine time-based scheduling with event-based scheduling."""

    def __init__(
        self,
        *,
        timetable: Timetable,
        datasets: Collection[Dataset] | BaseDataset,
    ) -> None:
        self.timetable = timetable
        # A bare collection is AND-combined; a pre-built condition is kept as-is.
        if isinstance(datasets, BaseDataset):
            self.dataset_condition = datasets
        else:
            self.dataset_condition = DatasetAll(*datasets)

        # Mirror the wrapped timetable's scheduling attributes so the scheduler
        # treats this combined schedule the same way as the time-based one.
        self.description = f"Triggered by datasets or {timetable.description}"
        self.periodic = timetable.periodic
        self._can_be_scheduled = timetable._can_be_scheduled
        self.active_runs_limit = timetable.active_runs_limit
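
    # Illustrative note (not part of the upstream module) on the two accepted
    # forms of ``datasets``; ``t``, ``d1``, and ``d2`` are placeholder names:
    #
    #     DatasetOrTimeSchedule(timetable=t, datasets=[d1, d2])
    #         # condition becomes DatasetAll(d1, d2): every dataset must update
    #     DatasetOrTimeSchedule(timetable=t, datasets=(d1 | d2))
    #         # ``d1 | d2`` builds a DatasetAny: a single update is enough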
    @classmethod
    def deserialize(cls, data: dict[str, typing.Any]) -> Timetable:
        from airflow.serialization.serialized_objects import decode_dataset_condition, decode_timetable

        return cls(
            datasets=decode_dataset_condition(data["dataset_condition"]),
            timetable=decode_timetable(data["timetable"]),
        )

    def serialize(self) -> dict[str, typing.Any]:
        from airflow.serialization.serialized_objects import encode_dataset_condition, encode_timetable

        return {
            "dataset_condition": encode_dataset_condition(self.dataset_condition),
            "timetable": encode_timetable(self.timetable),
        }
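
    # Round-trip sketch (illustrative): serialize() emits a plain dict, stored
    # alongside the serialized DAG, which deserialize() can reconstruct:
    #
    #     data = schedule.serialize()
    #     # -> {"dataset_condition": {...}, "timetable": {...}}
    #     restored = DatasetOrTimeSchedule.deserialize(data)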
    def validate(self) -> None:
        if isinstance(self.timetable, DatasetTriggeredSchedule):
            raise AirflowTimetableInvalid("cannot nest dataset timetables")
        if not isinstance(self.dataset_condition, BaseDataset):
            raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets")

    @property
    def summary(self) -> str:
        return f"Dataset or {self.timetable.summary}"

    def infer_manual_data_interval(self, *, run_after: pendulum.DateTime) -> DataInterval:
        return self.timetable.infer_manual_data_interval(run_after=run_after)

    def next_dagrun_info(
        self, *, last_automated_data_interval: DataInterval | None, restriction: TimeRestriction
    ) -> DagRunInfo | None:
        return self.timetable.next_dagrun_info(
            last_automated_data_interval=last_automated_data_interval,
            restriction=restriction,
        )

    def generate_run_id(self, *, run_type: DagRunType, **kwargs: typing.Any) -> str:
        # Dataset-triggered runs keep the parent class's run-id scheme; every
        # other run type delegates to the wrapped timetable.
        if run_type != DagRunType.DATASET_TRIGGERED:
            return self.timetable.generate_run_id(run_type=run_type, **kwargs)
        return super().generate_run_id(run_type=run_type, **kwargs)
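
# Usage sketch (not part of this module; the DAG id and URI are made-up
# placeholders): run the DAG on a weekly cron schedule, or earlier whenever the
# dataset is updated. CronTriggerTimetable and Dataset are real Airflow classes.
#
#     from airflow import DAG
#     from airflow.datasets import Dataset
#     from airflow.timetables.datasets import DatasetOrTimeSchedule
#     from airflow.timetables.trigger import CronTriggerTimetable
#
#     with DAG(
#         dag_id="example_dataset_or_time",
#         schedule=DatasetOrTimeSchedule(
#             timetable=CronTriggerTimetable("0 1 * * 3", timezone="UTC"),
#             datasets=[Dataset("s3://example-bucket/output.csv")],
#         ),
#         catchup=False,
#     ):
#         ...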