dataset.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from urllib.parse import urlsplit

import sqlalchemy_jsonfield
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset, DatasetAlias
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime

alias_association_table = Table(
    "dataset_alias_dataset",
    Base.metadata,
    Column("alias_id", ForeignKey("dataset_alias.id", ondelete="CASCADE"), primary_key=True),
    Column("dataset_id", ForeignKey("dataset.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dataset_alias_dataset_alias_id", "alias_id"),
    Index("idx_dataset_alias_dataset_alias_dataset_id", "dataset_id"),
    ForeignKeyConstraint(
        ("alias_id",),
        ["dataset_alias.id"],
        name="ds_dsa_alias_id",
        ondelete="CASCADE",
    ),
    ForeignKeyConstraint(
        ("dataset_id",),
        ["dataset.id"],
        name="ds_dsa_dataset_id",
        ondelete="CASCADE",
    ),
)

dataset_alias_dataset_event_association_table = Table(
    "dataset_alias_dataset_event",
    Base.metadata,
    Column("alias_id", ForeignKey("dataset_alias.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dataset_alias_dataset_event_alias_id", "alias_id"),
    Index("idx_dataset_alias_dataset_event_event_id", "event_id"),
    ForeignKeyConstraint(
        ("alias_id",),
        ["dataset_alias.id"],
        name="dss_de_alias_id",
        ondelete="CASCADE",
    ),
    ForeignKeyConstraint(
        ("event_id",),
        ["dataset_event.id"],
        name="dss_de_event_id",
        ondelete="CASCADE",
    ),
)
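
# The two Core tables above are plain many-to-many link tables: the first maps
# an alias to the concrete datasets it has resolved to, the second maps an alias
# to the dataset events emitted through it. The explicit ForeignKeyConstraint
# entries duplicate the column-level ForeignKey definitions so that the
# constraints carry stable names for migrations.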


class DatasetAliasModel(Base):
    """
    A table to store dataset aliases.

    :param name: a string that uniquely identifies the dataset alias
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )

    __tablename__ = "dataset_alias"
    __table_args__ = (
        Index("idx_name_unique", name, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    datasets = relationship(
        "DatasetModel",
        secondary=alias_association_table,
        backref="aliases",
    )
    dataset_events = relationship(
        "DatasetEvent",
        secondary=dataset_alias_dataset_event_association_table,
        back_populates="source_aliases",
    )
    consuming_dags = relationship("DagScheduleDatasetAliasReference", back_populates="dataset_alias")

    @classmethod
    def from_public(cls, obj: DatasetAlias) -> DatasetAliasModel:
        return cls(name=obj.name)

    def __repr__(self):
        return f"{self.__class__.__name__}(name={self.name!r})"

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, DatasetAlias)):
            return self.name == other.name
        else:
            return NotImplemented
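
# Illustrative only (not executed by this module): __eq__ accepts both the ORM
# model and the public DatasetAlias class, so the two compare directly by name:
#
#     from airflow.datasets import DatasetAlias
#
#     DatasetAliasModel(name="nightly") == DatasetAlias(name="nightly")  # True
#     DatasetAliasModel.from_public(DatasetAlias(name="nightly"))  # equivalent row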


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("URI must be ascii")
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
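
# Illustrative only: __init__ validates the URI before handing it to the ORM
# constructor, so invalid datasets fail fast at construction time:
#
#     DatasetModel(uri="s3://bucket/key")  # ok
#     DatasetModel(uri="airflow://x")  # ValueError: Scheme `airflow` is reserved.
#     DatasetModel(uri="s3://bücket")  # ValueError: URI must be ascii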


class DagScheduleDatasetAliasReference(Base):
    """References from a DAG to a dataset alias of which it is a consumer."""

    alias_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset_alias = relationship("DatasetAliasModel", back_populates="consuming_dags")
    dag = relationship("DagModel", back_populates="schedule_dataset_alias_references")

    __tablename__ = "dag_schedule_dataset_alias_reference"
    __table_args__ = (
        PrimaryKeyConstraint(alias_id, dag_id, name="dsdar_pkey"),
        ForeignKeyConstraint(
            (alias_id,),
            ["dataset_alias.id"],
            name="dsdar_dataset_alias_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdar_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dag_schedule_dataset_alias_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.alias_id == other.alias_id and self.dag_id == other.dag_id
        return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
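
# A DAG scheduled on an alias gets one reference row here per alias; the
# concrete datasets and events the alias currently resolves to are tracked
# separately in the association tables at the top of this module.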


class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")
    dag = relationship("DagModel", back_populates="schedule_dataset_references")
    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dag_schedule_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
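
# The queue_records primaryjoin needs the foreign() annotations because there is
# no database-level foreign key from dataset_dag_run_queue back to this table;
# the join exists only at the ORM layer. The cascade means deleting a reference
# row also removes its pending DatasetDagRunQueue rows from the session.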


class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey"),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_task_outlet_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
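
# One row per (dataset, dag, task) triple: every task that declares a dataset in
# its outlets is recorded here, which allows consumers of the metadata DB to
# answer "which tasks produce this dataset?" without re-parsing DAG files.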


class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", viewonly=True)

    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dataset_dag_run_queue_target_dag_id", target_dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
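
# Illustrative sketch (assumes an active SQLAlchemy session): when a dataset
# receives an event, one row is enqueued per consuming DAG. The composite
# primary key makes the enqueue idempotent per (dataset, dag) pair:
#
#     session.merge(DatasetDagRunQueue(dataset_id=dataset.id, target_dag_id="consumer_dag"))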


association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)


class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )
    source_aliases = relationship(
        "DatasetAliasModel",
        secondary=dataset_alias_dataset_event_association_table,
        back_populates="dataset_events",
    )
    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
            "source_aliases",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
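
# Illustrative only: the deliberate lack of foreign keys on the source_* columns
# means an event outlives its source run, while the viewonly relationships still
# let you walk the lineage for rows that do exist:
#
#     event = session.query(DatasetEvent).first()
#     event.uri  # via the `dataset` relationship
#     event.source_dag_run  # DagRun that emitted the event, if still present
#     event.created_dagruns  # runs this event helped trigger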