serialized_dag.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Serialized DAG table in database."""
  19. from __future__ import annotations
  20. import logging
  21. import zlib
  22. from datetime import timedelta
  23. from typing import TYPE_CHECKING, Collection
  24. import sqlalchemy_jsonfield
  25. from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, exc, or_, select
  26. from sqlalchemy.orm import backref, foreign, relationship
  27. from sqlalchemy.sql.expression import func, literal
  28. from airflow.api_internal.internal_api_call import internal_api_call
  29. from airflow.exceptions import TaskNotFound
  30. from airflow.models.base import ID_LEN, Base
  31. from airflow.models.dag import DagModel
  32. from airflow.models.dagcode import DagCode
  33. from airflow.models.dagrun import DagRun
  34. from airflow.serialization.dag_dependency import DagDependency
  35. from airflow.serialization.serialized_objects import SerializedDAG
  36. from airflow.settings import COMPRESS_SERIALIZED_DAGS, MIN_SERIALIZED_DAG_UPDATE_INTERVAL, json
  37. from airflow.utils import timezone
  38. from airflow.utils.hashlib_wrapper import md5
  39. from airflow.utils.session import NEW_SESSION, provide_session
  40. from airflow.utils.sqlalchemy import UtcDateTime
  41. if TYPE_CHECKING:
  42. from datetime import datetime
  43. from sqlalchemy.orm import Session
  44. from airflow.models import Operator
  45. from airflow.models.dag import DAG
  46. log = logging.getLogger(__name__)
  class SerializedDagModel(Base):
      """
      ORM mapping of the ``serialized_dag`` table.

      The table is a snapshot of the DAG files, kept in sync by the scheduler.
      The webserver loads DAGs from it: reading from the database is lightweight
      compared to importing from files, which solves the webserver scalability issue.

      Related configuration options:

      * ``[core] min_serialized_dag_update_interval = 30`` (s):
        serialized DAGs are updated in the DB when the scheduler processes a file;
        this minimal interval between updates keeps the DB write rate down.
      * ``[scheduler] dag_dir_list_interval = 300`` (s):
        interval at which serialized DAGs of deleted files are removed from the DB;
        a smaller interval such as 60 is suggested.
      * ``[core] compress_serialized_dags``:
        whether the DAG data is compressed when written to the database.
      """

      __tablename__ = "serialized_dag"

      dag_id = Column(String(ID_LEN), primary_key=True)
      fileloc = Column(String(2000), nullable=False)
      # ``fileloc`` may be longer than what can be indexed, so its hash is indexed instead.
      fileloc_hash = Column(BigInteger(), nullable=False)
      # Payload columns: ``__init__`` fills exactly one of them, depending on
      # COMPRESS_SERIALIZED_DAGS, and sets the other to None.
      _data = Column("data", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
      _data_compressed = Column("data_compressed", LargeBinary, nullable=True)
      last_updated = Column(UtcDateTime, nullable=False)
      # Hex MD5 digest of the serialized JSON (see ``__init__``).
      dag_hash = Column(String(32), nullable=False)
      processor_subdir = Column(String(2000), nullable=True)

      __table_args__ = (Index("idx_fileloc_hash", fileloc_hash, unique=False),)

      # Runs whose dag_id equals this row's dag_id; DagRun gets a ``serialized_dag`` backref.
      dag_runs = relationship(
          DagRun,
          primaryjoin=dag_id == foreign(DagRun.dag_id),  # type: ignore
          backref=backref("serialized_dag", uselist=False, innerjoin=True),
      )

      # The DagModel row with the same dag_id; DagModel gets a ``serialized_dag`` backref.
      dag_model = relationship(
          DagModel,
          primaryjoin=dag_id == DagModel.dag_id,  # type: ignore
          foreign_keys=dag_id,
          uselist=False,
          innerjoin=True,
          backref=backref("serialized_dag", uselist=False, innerjoin=True),
      )

      # Copied onto ``SerializedDAG._load_operator_extra_links`` by the ``dag`` property.
      load_op_links = True
  def __init__(self, dag: DAG, processor_subdir: str | None = None) -> None:
      """
      Build the table row for ``dag``.

      The DAG is serialized to a dict, dumped to JSON with sorted keys, and the
      MD5 hex digest of those bytes is stored as ``dag_hash``. The payload goes
      either into the compressed column or into the JSON column, depending on
      ``COMPRESS_SERIALIZED_DAGS``.

      :param dag: the DAG to serialize
      :param processor_subdir: dag processor subdir stored alongside the row
      """
      self.dag_id = dag.dag_id
      self.fileloc = dag.fileloc
      self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
      self.last_updated = timezone.utcnow()
      self.processor_subdir = processor_subdir

      serialized = SerializedDAG.to_dict(dag)
      # Keys are sorted so the digest does not depend on dict ordering.
      payload = json.dumps(serialized, sort_keys=True).encode("utf-8")
      self.dag_hash = md5(payload).hexdigest()

      # Exactly one of the two payload columns is populated.
      self._data, self._data_compressed = (
          (None, zlib.compress(payload)) if COMPRESS_SERIALIZED_DAGS else (serialized, None)
      )

      # Keep the dict around so the ``data`` property need not decompress and
      # load it again when COMPRESS_SERIALIZED_DAGS is True.
      self.__data_cache = serialized
  107. def __repr__(self) -> str:
  108. return f"<SerializedDag: {self.dag_id}>"
  @classmethod
  @provide_session
  def write_dag(
      cls,
      dag: DAG,
      min_update_interval: int | None = None,
      processor_subdir: str | None = None,
      session: Session = NEW_SESSION,
  ) -> bool:
      """
      Serialize a DAG and write it to the database.

      Nothing is written when the stored row is younger than
      ``min_update_interval`` seconds, or when both its hash and its
      ``processor_subdir`` already match the freshly serialized DAG.

      :param dag: a DAG to be written into database
      :param min_update_interval: minimal interval in seconds to update serialized DAG;
          ``None`` disables the age check
      :param processor_subdir: dag processor subdir to store with the row
      :param session: ORM Session
      :returns: Boolean indicating if the DAG was written to the DB
      """
      dag_id = dag.dag_id

      # Throttle: skip when a row for this DAG was written after the cutoff.
      # This runs before serializing so that a throttled call stays cheap.
      if min_update_interval is not None:
          cutoff = timezone.utcnow() - timedelta(seconds=min_update_interval)
          recently_written = session.scalar(
              select(literal(True)).where(
                  cls.dag_id == dag_id,
                  cutoff < cls.last_updated,
              )
          )
          if recently_written:
              return False

      log.debug("Checking if DAG (%s) changed", dag_id)
      candidate = cls(dag, processor_subdir)
      stored = session.execute(
          select(cls.dag_hash, cls.processor_subdir).where(cls.dag_id == dag_id)
      ).first()

      unchanged = (
          stored is not None
          and stored.dag_hash == candidate.dag_hash
          and stored.processor_subdir == candidate.processor_subdir
      )
      if unchanged:
          log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag_id)
          return False

      log.debug("Writing Serialized DAG: %s to the DB", dag_id)
      # merge() inserts the row or updates the one with the same primary key.
      session.merge(candidate)
      log.debug("DAG: %s written to the DB", dag_id)
      return True
  @classmethod
  @provide_session
  def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]:
      """
      Deserialize every row of the serialized_dag table.

      Rows whose deserialized DAG reports a different ``dag_id`` than the row
      itself are logged and left out of the result.

      :param session: ORM Session
      :returns: a dict of DAGs read from database, keyed by dag_id
      """
      rows = session.scalars(select(cls))
      result = {}
      for row in rows:
          log.debug("Deserializing DAG: %s", row.dag_id)
          loaded = row.dag

          # Coherence check: the payload must describe the DAG the row is keyed by.
          if loaded.dag_id != row.dag_id:
              log.warning(
                  "dag_id Mismatch in DB: Row with dag_id '%s' has Serialised DAG with '%s' dag_id",
                  row.dag_id,
                  loaded.dag_id,
              )
              continue

          result[row.dag_id] = loaded
      return result
  177. @property
  178. def data(self) -> dict | None:
  179. # use __data_cache to avoid decompress and loads
  180. if not hasattr(self, "__data_cache") or self.__data_cache is None:
  181. if self._data_compressed:
  182. self.__data_cache = json.loads(zlib.decompress(self._data_compressed))
  183. else:
  184. self.__data_cache = self._data
  185. return self.__data_cache
  @property
  def dag(self) -> SerializedDAG:
      """
      The DAG deserialized from this row's ``data`` payload.

      :raises ValueError: when the payload is neither a dict nor a JSON string
      """
      # Class-level switch on SerializedDAG, set from this model's ``load_op_links``
      # before deserializing.
      SerializedDAG._load_operator_extra_links = self.load_op_links

      payload = self.data
      if isinstance(payload, str):
          # The payload may be stored as JSON text; decode it first.
          payload = json.loads(payload)
      elif not isinstance(payload, dict):
          raise ValueError("invalid or missing serialized DAG data")
      return SerializedDAG.from_dict(payload)
  @classmethod
  @provide_session
  def remove_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> None:
      """
      Delete the serialized DAG row for the given dag_id.

      :param dag_id: dag_id to be deleted
      :param session: ORM Session.
      """
      # Table-level DELETE: no ORM objects are loaded for this.
      delete_stmt = cls.__table__.delete().where(cls.dag_id == dag_id)
      session.execute(delete_stmt)
  @classmethod
  @internal_api_call
  @provide_session
  def remove_deleted_dags(
      cls,
      alive_dag_filelocs: Collection[str],
      processor_subdir: str | None = None,
      session: Session = NEW_SESSION,
  ) -> None:
      """
      Delete the serialized DAGs whose file is not in ``alive_dag_filelocs``.

      Only rows whose ``processor_subdir`` is NULL or equal to the given
      ``processor_subdir`` are considered.

      :param alive_dag_filelocs: file paths of alive DAGs
      :param processor_subdir: dag processor subdir
      :param session: ORM Session
      """
      alive_hashes = list(map(DagCode.dag_fileloc_hash, alive_dag_filelocs))

      log.debug(
          "Deleting Serialized DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__
      )

      # A row is stale only when both its hash and its full path are unknown.
      hash_not_alive = cls.fileloc_hash.notin_(alive_hashes)
      path_not_alive = cls.fileloc.notin_(alive_dag_filelocs)
      subdir_matches = or_(
          cls.processor_subdir.is_(None),
          cls.processor_subdir == processor_subdir,
      )

      stale_rows = and_(hash_not_alive, path_not_alive, subdir_matches)
      session.execute(cls.__table__.delete().where(stale_rows))
  @classmethod
  @provide_session
  def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
      """
      Tell whether a DAG exists in the serialized_dag table.

      :param dag_id: the DAG to check
      :param session: ORM Session
      :returns: True when a row with this dag_id exists
      """
      # Select a constant instead of the row so the payload is never fetched.
      probe = select(literal(True)).where(cls.dag_id == dag_id).limit(1)
      found = session.scalar(probe)
      return found is not None
  @classmethod
  @provide_session
  def get_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDAG | None:
      """
      Fetch the row for ``dag_id`` via :meth:`get` and return its deserialized DAG.

      :param dag_id: the DAG to fetch
      :param session: ORM Session
      :returns: the deserialized DAG, or None when no row was found
      """
      found = cls.get(dag_id, session=session)
      return found.dag if found else None
  @classmethod
  @provide_session
  def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel | None:
      """
      Get the serialized DAG row for the given dag ID.

      When no row matches ``dag_id`` directly (e.g. it is the ID of a subdag),
      the root dag_id is looked up in the DAG table and the row of that root
      DAG is returned instead.

      :param dag_id: the DAG to fetch
      :param session: ORM Session
      :returns: the matching row, or None when neither lookup finds one
      """

      def _row_for(wanted_id):
          # Single-row lookup by primary key value.
          return session.scalar(select(cls).where(cls.dag_id == wanted_id))

      found = _row_for(dag_id)
      if not found:
          # No direct match: ask the DAG table for the root dag and use its row.
          root_dag_id = session.scalar(select(DagModel.root_dag_id).where(DagModel.dag_id == dag_id))
          found = _row_for(root_dag_id)
      return found
  @staticmethod
  @provide_session
  def bulk_sync_to_db(
      dags: list[DAG],
      processor_subdir: str | None = None,
      session: Session = NEW_SESSION,
  ) -> None:
      """
      Save DAGs as Serialized DAG objects in the database.

      Each DAG is written through its own :meth:`write_dag` call, using
      ``MIN_SERIALIZED_DAG_UPDATE_INTERVAL`` as the minimal update interval.
      Sub-DAGs are skipped.

      :param dags: the DAG objects to save to the DB
      :param processor_subdir: dag processor subdir passed on to ``write_dag``
      :param session: ORM Session
      :return: None
      """
      for dag in dags:
          if dag.is_subdag:
              continue
          SerializedDagModel.write_dag(
              dag=dag,
              min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
              processor_subdir=processor_subdir,
              session=session,
          )
  @classmethod
  @provide_session
  def get_last_updated_datetime(cls, dag_id: str, session: Session = NEW_SESSION) -> datetime | None:
      """
      Get the time the serialized DAG of ``dag_id`` was last updated in the serialized_dag table.

      :param dag_id: DAG ID
      :param session: ORM Session
      :returns: the ``last_updated`` value, or None when the query yields no row
      """
      last_updated_stmt = select(cls.last_updated).where(cls.dag_id == dag_id)
      return session.scalar(last_updated_stmt)
  @classmethod
  @provide_session
  def get_max_last_updated_datetime(cls, session: Session = NEW_SESSION) -> datetime | None:
      """
      Get the most recent ``last_updated`` value across the serialized_dag table.

      :param session: ORM Session
      :returns: the maximum ``last_updated`` value found by the query
      """
      newest_stmt = select(func.max(cls.last_updated))
      return session.scalar(newest_stmt)
  @classmethod
  @provide_session
  def get_latest_version_hash(cls, dag_id: str, session: Session = NEW_SESSION) -> str | None:
      """
      Get the hash stored for the given DAG ID, which identifies its latest version.

      :param dag_id: DAG ID
      :param session: ORM Session
      :return: DAG Hash, or None if the DAG is not found
      """
      hash_stmt = select(cls.dag_hash).where(cls.dag_id == dag_id)
      return session.scalar(hash_stmt)
  @classmethod
  def get_latest_version_hash_and_updated_datetime(
      cls,
      dag_id: str,
      *,
      session: Session,
  ) -> tuple[str, datetime] | None:
      """
      Get the hash of a DAG together with the time its row was last updated.

      Unlike the sibling getters, ``session`` is a required keyword argument here.

      :meta private:
      :param dag_id: DAG ID
      :param session: ORM Session
      :return: A tuple of DAG Hash and last updated datetime, or None if the DAG is not found
      """
      version_stmt = select(cls.dag_hash, cls.last_updated).where(cls.dag_id == dag_id)
      result = session.execute(version_stmt)
      return result.one_or_none()
  @classmethod
  @provide_session
  def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[DagDependency]]:
      """
      Get the dependencies between DAGs.

      The ``dag.dag_dependencies`` entry is extracted from the JSON ``data``
      column by the database itself, so full payloads are not loaded. DAGs
      for which nothing is extracted map to an empty list.

      :param session: ORM Session
      :returns: a dict mapping each dag_id to its list of dependencies
      """
      # sqlite and mysql use json_extract with a path expression; every other
      # dialect goes through json_extract_path.
      uses_json_extract = session.bind.dialect.name in ("sqlite", "mysql")
      if uses_json_extract:
          extracted = func.json_extract(cls._data, "$.dag.dag_dependencies")
      else:
          extracted = func.json_extract_path(cls._data, "dag", "dag_dependencies")

      dependencies = {}
      for dag_id, deps_data in session.execute(select(cls.dag_id, extracted)):
          if uses_json_extract and deps_data:
              # Decoded here for these dialects (presumably the value comes back as JSON text).
              deps_data = json.loads(deps_data)
          dependencies[dag_id] = [DagDependency(**dep) for dep in (deps_data or [])]
      return dependencies
  @staticmethod
  @internal_api_call
  @provide_session
  def get_serialized_dag(dag_id: str, task_id: str, session: Session = NEW_SESSION) -> Operator | None:
      """
      Return the task ``task_id`` of the serialized DAG ``dag_id``.

      Despite its name, this returns an operator rather than a DAG.

      :param dag_id: ID of the DAG to look up by primary key
      :param task_id: ID of the task to fetch from that DAG
      :param session: ORM Session
      :returns: the task, or None when the DAG row or the task does not exist
      """
      from airflow.models.serialized_dag import SerializedDagModel

      try:
          model = session.get(SerializedDagModel, dag_id)
          task = model.dag.get_task(task_id) if model else None
      except (exc.NoResultFound, TaskNotFound):
          task = None
      return task
  363. return None