dagcode.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import struct
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable

from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.base import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

log = logging.getLogger(__name__)


class DagCode(Base):
    """
    A table for DAG code.

    The dag_code table contains the code of DAG files synchronized by the scheduler.

    For details on DAG serialization see SerializedDagModel.
    """

    __tablename__ = "dag_code"

    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session: Session = NEW_SESSION) -> None:
        """
        Write code into the database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)
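
    # A minimal usage sketch (illustrative, not part of the original module),
    # assuming an initialized Airflow metadata database; the path below is
    # hypothetical. ``provide_session`` supplies the session automatically.
    #
    #     DagCode("/dags/example_dag.py").sync_to_db()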

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
        """
        Write code in bulk into the database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        existing_orm_dag_codes = session.scalars(
            select(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
        ).all()

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )
            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)
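
    # A minimal bulk-sync sketch (hypothetical paths, illustrative only): the
    # scheduler passes every parsed file location; new files are inserted, and
    # existing rows are refreshed only when the file's mtime is newer than
    # ``last_updated``.
    #
    #     DagCode.bulk_sync_to_db(["/dags/dag_a.py", "/dags/dag_b.py"])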

    @classmethod
    @internal_api_call
    @provide_session
    def remove_deleted_code(
        cls,
        alive_dag_filelocs: Collection[str],
        processor_subdir: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Delete code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param processor_subdir: DAG processor subdir
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table", cls.__tablename__)

        session.execute(
            delete(cls)
            .where(
                cls.fileloc_hash.notin_(alive_fileloc_hashes),
                cls.fileloc.notin_(alive_dag_filelocs),
                cls.fileloc.contains(processor_subdir),
            )
            .execution_options(synchronize_session="fetch")
        )
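
    # A minimal cleanup sketch (hypothetical values, illustrative only): a row
    # is deleted when its hash and fileloc are both absent from the alive set
    # and the fileloc lives under the processor's subdir; checking the raw
    # fileloc as well as the hash guards against hash collisions.
    #
    #     DagCode.remove_deleted_code(
    #         alive_dag_filelocs=["/dags/dag_a.py"], processor_subdir="/dags"
    #     )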

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
        """
        Check if a file exists in the dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return (
            session.scalars(select(literal(True)).where(cls.fileloc_hash == fileloc_hash)).one_or_none()
            is not None
        )
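
    # A minimal existence-check sketch (hypothetical path, illustrative only):
    # ``has_dag`` selects a literal ``True`` by hash instead of loading the
    # row, so no source code is fetched just to test for presence.
    #
    #     if DagCode.has_dag("/dags/example_dag.py"):
    #         print(DagCode.get_code_by_fileloc("/dags/example_dag.py"))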

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """
        Return source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    @provide_session
    def code(cls, fileloc, session: Session = NEW_SESSION) -> str:
        """
        Return source code for the DAG file at the given fileloc.

        :return: source code as string
        """
        return cls._get_code_from_db(fileloc, session)

    @staticmethod
    def _get_code_from_file(fileloc):
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
        dag_code = session.scalar(select(cls).where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)))
        if not dag_code:
            raise DagCodeNotFound()
        else:
            code = dag_code.source_code
        return code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """
        Hash the file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8
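

if __name__ == "__main__":
    # A minimal sketch of the hashing scheme above (illustrative, not part of
    # the original module): sha1 yields 20 bytes; the last 8 are unpacked as
    # an unsigned big-endian integer and shifted right by 8 bits, so the hash
    # uses at most 56 bits and always fits a signed 64-bit BigInteger column.
    example_hash = DagCode.dag_fileloc_hash("/dags/example_dag.py")  # hypothetical path
    assert 0 <= example_hash < 2**56
    print(f"fileloc hash: {example_hash}")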