# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
import urllib.parse
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast

import attr
from sqlalchemy import select

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.serialization.dag_dependency import DagDependency
from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
    from urllib.parse import SplitResult

    from sqlalchemy.orm.session import Session

from airflow.configuration import conf

__all__ = ["Dataset", "DatasetAll", "DatasetAny"]


def normalize_noop(parts: SplitResult) -> SplitResult:
    """
    Place-hold a :class:`~urllib.parse.SplitResult` normalizer.

    :meta private:
    """
    return parts


def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
    if scheme == "file":
        return normalize_noop
    from airflow.providers_manager import ProvidersManager

    return ProvidersManager().dataset_uri_handlers.get(scheme)


def _get_normalized_scheme(uri: str) -> str:
    parsed = urllib.parse.urlsplit(uri)
    return parsed.scheme.lower()


def _sanitize_uri(uri: str) -> str:
    """
    Sanitize a dataset URI.

    This checks for URI validity, and normalizes the URI if needed. A fully
    normalized URI is returned.
    """
    if not uri:
        raise ValueError("Dataset URI cannot be empty")
    if uri.isspace():
        raise ValueError("Dataset URI cannot be just whitespace")
    if not uri.isascii():
        raise ValueError("Dataset URI must only consist of ASCII characters")
    parsed = urllib.parse.urlsplit(uri)
    if not parsed.scheme and not parsed.netloc:  # Does not look like a URI.
        return uri
    if not (normalized_scheme := _get_normalized_scheme(uri)):
        return uri
    if normalized_scheme.startswith("x-"):
        return uri
    if normalized_scheme == "airflow":
        raise ValueError("Dataset scheme 'airflow' is reserved")
    _, auth_exists, normalized_netloc = parsed.netloc.rpartition("@")
    if auth_exists:
        # TODO: Collect this into a DagWarning.
        warnings.warn(
            "A dataset URI should not contain auth info (e.g. username or "
            "password). It has been automatically dropped.",
            UserWarning,
            stacklevel=3,
        )
    if parsed.query:
        normalized_query = urllib.parse.urlencode(sorted(urllib.parse.parse_qsl(parsed.query)))
    else:
        normalized_query = ""
    parsed = parsed._replace(
        scheme=normalized_scheme,
        netloc=normalized_netloc,
        path=parsed.path.rstrip("/") or "/",  # Remove all trailing slashes.
        query=normalized_query,
        fragment="",  # Ignore any fragments.
    )
    if (normalizer := _get_uri_normalizer(normalized_scheme)) is not None:
        try:
            parsed = normalizer(parsed)
        except ValueError as exception:
            if conf.getboolean("core", "strict_dataset_uri_validation", fallback=False):
                raise
            warnings.warn(
                f"The dataset URI {uri} is not AIP-60 compliant: {exception}. "
                f"In Airflow 3, this will raise an exception.",
                UserWarning,
                stacklevel=3,
            )
    return urllib.parse.urlunsplit(parsed)
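

# A quick sketch of what the sanitizer above does. The URIs are hypothetical,
# and the results assume no provider-registered normalizer rewrites these
# schemes further:
#
#   _sanitize_uri("postgres://user:pass@localhost/db?b=2&a=1")
#   # -> "postgres://localhost/db?a=1&b=2"
#   #    (auth dropped with a UserWarning, query keys sorted)
#   _sanitize_uri("s3://bucket/key/")
#   # -> "s3://bucket/key"   (trailing slashes removed)
#   _sanitize_uri("just-a-name")
#   # -> "just-a-name"       (not URI-like; returned unchanged)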


def extract_event_key(value: str | Dataset | DatasetAlias) -> str:
    """
    Extract the key of an inlet or an outlet event.

    If the input value is a string, it is treated as a URI and sanitized. If the
    input is a :class:`Dataset`, the URI it contains is considered sanitized and
    returned directly. If the input is a :class:`DatasetAlias`, the name it contains
    will be returned directly.

    :meta private:
    """
    if isinstance(value, DatasetAlias):
        return value.name
    if isinstance(value, Dataset):
        return value.uri
    return _sanitize_uri(str(value))
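

# Illustration of the dispatch above (hypothetical values):
#
#   extract_event_key("s3://bucket/key")           # -> "s3://bucket/key" (sanitized)
#   extract_event_key(Dataset("s3://bucket/key"))  # -> "s3://bucket/key" (URI as-is)
#   extract_event_key(DatasetAlias("my-alias"))    # -> "my-alias" (alias name)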


@internal_api_call
@provide_session
def expand_alias_to_datasets(
    alias: str | DatasetAlias, *, session: Session = NEW_SESSION
) -> list[BaseDataset]:
    """Expand dataset alias to resolved datasets."""
    from airflow.models.dataset import DatasetAliasModel

    alias_name = alias.name if isinstance(alias, DatasetAlias) else alias
    dataset_alias_obj = session.scalar(
        select(DatasetAliasModel).where(DatasetAliasModel.name == alias_name).limit(1)
    )
    if dataset_alias_obj:
        return [Dataset(uri=dataset.uri, extra=dataset.extra) for dataset in dataset_alias_obj.datasets]
    return []
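

# Usage sketch (requires an initialized Airflow metadata database; the alias
# name and URIs are hypothetical):
#
#   expand_alias_to_datasets("my-alias")
#   # -> [Dataset(uri="s3://bucket/obj", extra={...}), ...], or [] if the
#   #    alias has not yet been resolved to any dataset.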


class BaseDataset:
    """
    Protocol for all dataset triggers to use in ``DAG(schedule=...)``.

    :meta private:
    """

    def __bool__(self) -> bool:
        return True

    def __or__(self, other: BaseDataset) -> BaseDataset:
        if not isinstance(other, BaseDataset):
            return NotImplemented
        return DatasetAny(self, other)

    def __and__(self, other: BaseDataset) -> BaseDataset:
        if not isinstance(other, BaseDataset):
            return NotImplemented
        return DatasetAll(self, other)

    def as_expression(self) -> Any:
        """
        Serialize the dataset into its scheduling expression.

        The return value is stored in DagModel for display purposes. It must be
        JSON-compatible.

        :meta private:
        """
        raise NotImplementedError

    def evaluate(self, statuses: dict[str, bool]) -> bool:
        raise NotImplementedError

    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
        raise NotImplementedError

    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
        raise NotImplementedError

    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
        """
        Iterate a base dataset as dag dependency.

        :meta private:
        """
        raise NotImplementedError
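

# The operators above are what make dataset expressions composable in
# ``DAG(schedule=...)``. A sketch with hypothetical URIs:
#
#   d1, d2 = Dataset("s3://bucket/a"), Dataset("s3://bucket/b")
#   d1 | d2  # -> DatasetAny(d1, d2): trigger when either dataset is updated
#   d1 & d2  # -> DatasetAll(d1, d2): trigger only when both are updated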


@attr.define()
class DatasetAlias(BaseDataset):
    """A representation of a dataset alias, used to create datasets at runtime."""

    name: str

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, DatasetAlias):
            return self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
        """
        Iterate a dataset alias as dag dependency.

        :meta private:
        """
        yield DagDependency(
            source=source or "dataset-alias",
            target=target or "dataset-alias",
            dependency_type="dataset-alias",
            dependency_id=self.name,
        )


class DatasetAliasEvent(TypedDict):
    """A representation of a dataset event to be triggered by a dataset alias."""

    source_alias_name: str
    dest_dataset_uri: str
    extra: dict[str, Any]


@attr.define()
class Dataset(os.PathLike, BaseDataset):
    """A representation of data dependencies between workflows."""

    uri: str = attr.field(
        converter=_sanitize_uri,
        validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
    )
    extra: dict[str, Any] | None = None

    __version__: ClassVar[int] = 1

    def __fspath__(self) -> str:
        return self.uri

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return self.uri == other.uri
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.uri)

    @property
    def normalized_uri(self) -> str | None:
        """
        Return the normalized and AIP-60-compliant URI whenever possible.

        Returns None if the scheme cannot be retrieved from the URI, if no
        normalizer is registered for the scheme, or if parsing fails. If a
        normalizer for the scheme exists and parsing succeeds, the normalizer's
        result is returned.
        """
        if not (normalized_scheme := _get_normalized_scheme(self.uri)):
            return None
        if (normalizer := _get_uri_normalizer(normalized_scheme)) is None:
            return None
        parsed = urllib.parse.urlsplit(self.uri)
        try:
            normalized_uri = normalizer(parsed)
            return urllib.parse.urlunsplit(normalized_uri)
        except ValueError:
            return None

    def as_expression(self) -> Any:
        """
        Serialize the dataset into its scheduling expression.

        :meta private:
        """
        return self.uri

    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
        yield self.uri, self

    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
        return iter(())

    def evaluate(self, statuses: dict[str, bool]) -> bool:
        return statuses.get(self.uri, False)

    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
        """
        Iterate a dataset as dag dependency.

        :meta private:
        """
        yield DagDependency(
            source=source or "dataset",
            target=target or "dataset",
            dependency_type="dataset",
            dependency_id=self.uri,
        )
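

# Because Dataset implements os.PathLike, its URI can be used wherever a path
# is accepted, and evaluate() is a plain dict lookup. A sketch with a
# hypothetical URI:
#
#   d = Dataset("file:///tmp/data.csv")
#   os.fspath(d)                                # -> "file:///tmp/data.csv"
#   d.evaluate({"file:///tmp/data.csv": True})  # -> True
#   d.evaluate({})                              # -> False (missing defaults to False)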


class _DatasetBooleanCondition(BaseDataset):
    """Base class for dataset boolean logic."""

    agg_func: Callable[[Iterable], bool]

    def __init__(self, *objects: BaseDataset) -> None:
        if not all(isinstance(o, BaseDataset) for o in objects):
            raise TypeError("expect dataset expressions in condition")
        self.objects = [
            _DatasetAliasCondition(obj.name) if isinstance(obj, DatasetAlias) else obj for obj in objects
        ]

    def evaluate(self, statuses: dict[str, bool]) -> bool:
        return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)

    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
        seen = set()  # We want to keep the first instance.
        for o in self.objects:
            for k, v in o.iter_datasets():
                if k in seen:
                    continue
                yield k, v
                seen.add(k)

    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
        """Filter dataset aliases in the condition."""
        for o in self.objects:
            yield from o.iter_dataset_aliases()

    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
        """
        Iterate datasets, dataset aliases and their resolved datasets as dag dependency.

        :meta private:
        """
        for obj in self.objects:
            yield from obj.iter_dag_dependencies(source=source, target=target)
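

# Sketch of the aggregation (hypothetical URIs): agg_func, supplied by the
# subclasses below as any/all, folds each member's evaluate() result into one
# boolean, while iter_datasets() de-duplicates by URI, keeping first occurrences.
#
#   cond = DatasetAny(Dataset("s3://b/x"), Dataset("s3://b/y"), Dataset("s3://b/x"))
#   [uri for uri, _ in cond.iter_datasets()]              # -> ["s3://b/x", "s3://b/y"]
#   cond.evaluate({"s3://b/x": True, "s3://b/y": False})  # -> True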


class DatasetAny(_DatasetBooleanCondition):
    """Use to combine datasets schedule references in an "or" relationship."""

    agg_func = any

    def __or__(self, other: BaseDataset) -> BaseDataset:
        if not isinstance(other, BaseDataset):
            return NotImplemented
        # Optimization: X | (Y | Z) is equivalent to X | Y | Z.
        return DatasetAny(*self.objects, other)

    def __repr__(self) -> str:
        return f"DatasetAny({', '.join(map(str, self.objects))})"

    def as_expression(self) -> dict[str, Any]:
        """
        Serialize the dataset into its scheduling expression.

        :meta private:
        """
        return {"any": [o.as_expression() for o in self.objects]}
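

# Example of the flattening optimization and the serialized form, with
# hypothetical URIs:
#
#   expr = Dataset("s3://b/x") | Dataset("s3://b/y") | Dataset("s3://b/z")
#   expr                  # -> a single DatasetAny with three members, not a nested pair
#   expr.as_expression()  # -> {"any": ["s3://b/x", "s3://b/y", "s3://b/z"]}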


class _DatasetAliasCondition(DatasetAny):
    """
    Use to expand a DatasetAlias as a DatasetAny of its resolved Datasets.

    :meta private:
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.objects = expand_alias_to_datasets(name)

    def __repr__(self) -> str:
        return f"_DatasetAliasCondition({', '.join(map(str, self.objects))})"

    def as_expression(self) -> Any:
        """
        Serialize the dataset alias into its scheduling expression.

        :meta private:
        """
        return {"alias": self.name}

    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
        yield DatasetAlias(self.name)

    def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
        """
        Iterate a dataset alias and its resolved datasets as dag dependency.

        :meta private:
        """
        if self.objects:
            for obj in self.objects:
                dataset = cast(Dataset, obj)
                uri = dataset.uri
                # dataset
                yield DagDependency(
                    source=f"dataset-alias:{self.name}" if source else "dataset",
                    target="dataset" if source else f"dataset-alias:{self.name}",
                    dependency_type="dataset",
                    dependency_id=uri,
                )
                # dataset alias
                yield DagDependency(
                    source=source or f"dataset:{uri}",
                    target=target or f"dataset:{uri}",
                    dependency_type="dataset-alias",
                    dependency_id=self.name,
                )
        else:
            yield DagDependency(
                source=source or "dataset-alias",
                target=target or "dataset-alias",
                dependency_type="dataset-alias",
                dependency_id=self.name,
            )


class DatasetAll(_DatasetBooleanCondition):
    """Use to combine datasets schedule references in an "and" relationship."""

    agg_func = all

    def __and__(self, other: BaseDataset) -> BaseDataset:
        if not isinstance(other, BaseDataset):
            return NotImplemented
        # Optimization: X & (Y & Z) is equivalent to X & Y & Z.
        return DatasetAll(*self.objects, other)

    def __repr__(self) -> str:
        return f"DatasetAll({', '.join(map(str, self.objects))})"

    def as_expression(self) -> Any:
        """
        Serialize the dataset into its scheduling expression.

        :meta private:
        """
        return {"all": [o.as_expression() for o in self.objects]}
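

# Mixed conditions nest naturally (hypothetical URIs): "&" between a DatasetAny
# and a Dataset yields a DatasetAll whose expression embeds the inner "any".
#
#   expr = (Dataset("s3://b/x") | Dataset("s3://b/y")) & Dataset("s3://b/z")
#   expr.as_expression()
#   # -> {"all": [{"any": ["s3://b/x", "s3://b/y"]}, "s3://b/z"]}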