# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import struct
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable

from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.base import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

log = logging.getLogger(__name__)


class DagCode(Base):
    """
    A table for DAG code.

    The dag_code table contains the source code of DAG files synchronized by the scheduler.

    For details on DAG serialization see SerializedDagModel.
    """

    __tablename__ = "dag_code"

    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the index key limit, so the hash of the
    # path (fileloc_hash) serves as the primary key instead.
    last_updated = Column(UtcDateTime, nullable=False)
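    # Use MEDIUMTEXT on MySQL: its plain TEXT type caps at 64 KB, which DAG files
    # can exceed; other dialects fall back to the generic Text type.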
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session: Session = NEW_SESSION) -> None:
        """
        Write code into the database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
        """
        Write code in bulk into the database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
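        # Row-lock the matching records (SELECT ... FOR UPDATE) so concurrent sync
        # runs cannot write conflicting versions of the same files.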
        existing_orm_dag_codes = session.scalars(
            select(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
        ).all()

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
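        # A stored row whose hash matches one of ours but whose path does not means two
        # distinct paths hash to the same value; abort rather than let one file's code
        # silently overwrite the other's.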
        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)
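        # For files already present, re-read the source only when the file on disk
        # (or the zip archive containing it) is newer than the stored copy.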

        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @internal_api_call
    @provide_session
    def remove_deleted_code(
        cls,
        alive_dag_filelocs: Collection[str],
        processor_subdir: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Delete code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param processor_subdir: DAG processor subdir
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table", cls.__tablename__)
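        # Delete only rows that match neither an alive hash nor an alive path, and that
        # live under this processor's subdir, so one DAG processor cannot remove code
        # owned by another.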
        session.execute(
            delete(cls)
            .where(
                cls.fileloc_hash.notin_(alive_fileloc_hashes),
                cls.fileloc.notin_(alive_dag_filelocs),
                cls.fileloc.contains(processor_subdir),
            )
            .execution_options(synchronize_session="fetch")
        )

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
        """
        Check whether a file exists in the dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return (
            session.scalars(select(literal(True)).where(cls.fileloc_hash == fileloc_hash)).one_or_none()
            is not None
        )

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """
        Return source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    @provide_session
    def code(cls, fileloc, session: Session = NEW_SESSION) -> str:
        """
        Return the source code stored for the given file location.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls._get_code_from_db(fileloc, session)

    @staticmethod
    def _get_code_from_file(fileloc):
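        # open_maybe_zipped transparently reads either a plain file or a path that
        # points inside a zip archive (packaged DAGs).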
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
        dag_code = session.scalar(select(cls).where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)))
        if not dag_code:
            raise DagCodeNotFound()
        return dag_code.source_code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """
        Hash a file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because fileloc can be up to 2000 characters by Airflow
        # convention, which is over the index key limit.
        import hashlib

        # Keep only 7 bytes (56 bits): a MySQL BigInteger is 8 bytes signed, so the
        # shift guarantees a non-negative value that always fits.
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8
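
# Usage sketch (illustrative only, not part of this module). Assumes an initialized
# Airflow metadata database; sessions are supplied by @provide_session:
#
#   from airflow.models.dagcode import DagCode
#
#   DagCode.bulk_sync_to_db(["/opt/airflow/dags/example_dag.py"])
#   if DagCode.has_dag("/opt/airflow/dags/example_dag.py"):
#       print(DagCode.code("/opt/airflow/dags/example_dag.py"))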