- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- import contextlib
- import copy
- import datetime
- import json
- import logging
- from importlib import metadata
- from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
- from dateutil import relativedelta
- from packaging import version
- from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
- from sqlalchemy.dialects import mysql
- from sqlalchemy.types import JSON, Text, TypeDecorator
- from airflow.configuration import conf
- from airflow.serialization.enums import Encoding
- from airflow.utils.timezone import make_naive, utc
- if TYPE_CHECKING:
- from kubernetes.client.models.v1_pod import V1Pod
- from sqlalchemy.exc import OperationalError
- from sqlalchemy.orm import Query, Session
- from sqlalchemy.sql import ColumnElement, Select
- from sqlalchemy.sql.expression import ColumnOperators
- from sqlalchemy.types import TypeEngine
- log = logging.getLogger(__name__)
- class UtcDateTime(TypeDecorator):
- """
- Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.
- - Never silently accepts a naive :class:`~datetime.datetime`; instead it
- always raises :exc:`ValueError` unless the value is time zone aware.
- - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
- is always converted to UTC.
- - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
- it never returns a naive :class:`~datetime.datetime`, but a time zone
- aware value, even with SQLite or MySQL.
- - Always returns TIMESTAMP in UTC.
- """
- impl = TIMESTAMP(timezone=True)
- cache_ok = True
- def process_bind_param(self, value, dialect):
- if not isinstance(value, datetime.datetime):
- if value is None:
- return None
- raise TypeError(f"expected datetime.datetime, not {value!r}")
- elif value.tzinfo is None:
- raise ValueError("naive datetime is disallowed")
- elif dialect.name == "mysql":
- # For MySQL versions prior to 8.0.19 we should send timestamps as naive values in UTC
- # see: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
- return make_naive(value, timezone=utc)
- return value.astimezone(utc)
- def process_result_value(self, value, dialect):
- """
- Process DateTimes from the DB, making sure to always return UTC.
- We do not use timezone.convert_to_utc here because that converts to the configured
- TIMEZONE, while the DB might be running with some other setting. We assume UTC
- datetimes in the database.
- """
- if value is not None:
- if value.tzinfo is None:
- value = value.replace(tzinfo=utc)
- else:
- value = value.astimezone(utc)
- return value
- def load_dialect_impl(self, dialect):
- if dialect.name == "mysql":
- return mysql.TIMESTAMP(fsp=6)
- return super().load_dialect_impl(dialect)
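- # Illustrative sketch (not part of the upstream module): how a model column might use
- # UtcDateTime. The table and column names are assumptions for the example; binding a
- # naive datetime to such a column raises ValueError, and values read back are UTC-aware.
- def _example_utc_datetime_column():
-     from sqlalchemy import Column, Integer, MetaData, Table
-     metadata = MetaData()
-     return Table(
-         "example_run",  # hypothetical table name
-         metadata,
-         Column("id", Integer, primary_key=True),
-         # Values bound here must be timezone-aware; they are normalized to UTC.
-         Column("created_at", UtcDateTime()),
-     )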
- class ExtendedJSON(TypeDecorator):
- """
- A version of the JSON column that uses the Airflow extended JSON serialization.
- See airflow.serialization.
- """
- impl = Text
- cache_ok = True
- should_evaluate_none = True
- def load_dialect_impl(self, dialect) -> TypeEngine:
- return dialect.type_descriptor(JSON)
- def process_bind_param(self, value, dialect):
- from airflow.serialization.serialized_objects import BaseSerialization
- if value is None:
- return None
- return BaseSerialization.serialize(value)
- def process_result_value(self, value, dialect):
- from airflow.serialization.serialized_objects import BaseSerialization
- if value is None:
- return None
- return BaseSerialization.deserialize(value)
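- # Illustrative sketch (not part of the upstream module): ExtendedJSON round-trips any
- # Airflow-serializable value through BaseSerialization. The dialect argument is unused
- # by these two methods, so passing None is enough for the demonstration.
- def _example_extended_json_roundtrip():
-     col_type = ExtendedJSON()
-     stored = col_type.process_bind_param(datetime.timedelta(minutes=5), dialect=None)
-     # Returns the original timedelta reconstructed from the serialized form.
-     return col_type.process_result_value(stored, dialect=None)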
- def sanitize_for_serialization(obj: V1Pod):
- """
- Convert pod to dict.... but *safely*.
- When pod objects created with one k8s version are unpickled in a Python
- environment with a more recent k8s version (in which the object attrs may have
- changed), the unpickled object may throw an error because an attr
- expected on the new object may not be there on the unpickled one.
- This function still converts the pod to a dict; the only difference is that
- missing attrs are tolerated instead of raising AttributeError. You may compare with
- https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202
- If obj is None, return None.
- If obj is str, int, float, bool or bytes, return directly.
- If obj is datetime.datetime, datetime.date
- convert to string in iso8601 format.
- If obj is list, sanitize each element in the list.
- If obj is dict, sanitize each value in the dict.
- If obj is OpenAPI model, return the properties dict.
- :param obj: The data to serialize.
- :return: The serialized form of data.
- :meta private:
- """
- if obj is None:
- return None
- elif isinstance(obj, (float, bool, bytes, str, int)):
- return obj
- elif isinstance(obj, list):
- return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
- elif isinstance(obj, tuple):
- return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
- elif isinstance(obj, (datetime.datetime, datetime.date)):
- return obj.isoformat()
- if isinstance(obj, dict):
- obj_dict = obj
- else:
- obj_dict = {
- obj.attribute_map[attr]: getattr(obj, attr)
- for attr, _ in obj.openapi_types.items()
- # below is the only line we change, and we just add default=None for getattr
- if getattr(obj, attr, None) is not None
- }
- return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
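- # Illustrative sketch (not part of the upstream module): sanitizing a V1Pod built with
- # the installed kubernetes client. Attributes that are unset (or that the unpickled
- # object lacks after a client upgrade) are skipped rather than raising AttributeError.
- def _example_sanitize_pod():
-     from kubernetes.client import models as k8s
-     pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="demo", namespace="default"))
-     # Yields a plain dict such as {"metadata": {"name": "demo", "namespace": "default"}}.
-     return sanitize_for_serialization(pod)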
- def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
- """
- Convert the pod to JSON and back so that the pod is safe to use.
- The pod_override in executor_config is a V1Pod object.
- Such objects, created with one k8s version and unpickled in
- an env with an upgraded k8s version, may blow up when
- `to_dict` is called, because the OpenAPI-generated client code calls
- getattr on all attrs in openapi_types for each object, and when
- new attrs have been added to that list, getattr will fail.
- Here we re-serialize it to ensure it is not going to blow up.
- :meta private:
- """
- try:
- # if to_dict works, the pod is fine
- pod.to_dict()
- return pod
- except AttributeError:
- pass
- try:
- from kubernetes.client.models.v1_pod import V1Pod
- except ImportError:
- return None
- if not isinstance(pod, V1Pod):
- return None
- try:
- try:
- from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
- except ImportError:
- from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import ( # type: ignore[assignment]
- PodGenerator,
- )
- # now we actually reserialize / deserialize the pod
- pod_dict = sanitize_for_serialization(pod)
- return PodGenerator.deserialize_model_dict(pod_dict)
- except Exception:
- return None
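- # Illustrative sketch (not part of the upstream module): re-validating a pod object that
- # came out of an old pickled executor_config, falling back to the original when it
- # cannot be rebuilt (the function returns None in that case).
- def _example_revalidate_unpickled_pod(old_pod):
-     new_pod = ensure_pod_is_valid_after_unpickling(old_pod)
-     return new_pod if new_pod is not None else old_pod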
- class ExecutorConfigType(PickleType):
- """
- Adds special handling for K8s executor config.
- If we unpickle a k8s object that was pickled under an earlier k8s library version, then
- the unpickled object may throw an error when to_dict is called. To be more tolerant of
- version changes we convert to JSON using Airflow's serializer before pickling.
- """
- cache_ok = True
- def bind_processor(self, dialect):
- from airflow.serialization.serialized_objects import BaseSerialization
- super_process = super().bind_processor(dialect)
- def process(value):
- val_copy = copy.copy(value)
- if isinstance(val_copy, dict) and "pod_override" in val_copy:
- val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
- return super_process(val_copy)
- return process
- def result_processor(self, dialect, coltype):
- from airflow.serialization.serialized_objects import BaseSerialization
- super_process = super().result_processor(dialect, coltype)
- def process(value):
- value = super_process(value) # unpickle
- if isinstance(value, dict) and "pod_override" in value:
- pod_override = value["pod_override"]
- if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
- # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
- value["pod_override"] = BaseSerialization.deserialize(pod_override)
- else:
- # backcompat path
- # we no longer pickle raw pods but this code may be reached
- # when accessing executor configs created in a prior version
- new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
- if new_pod:
- value["pod_override"] = new_pod
- return value
- return process
- def compare_values(self, x, y):
- """
- Compare x and y using self.comparator if available. Else, use __eq__.
- The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.
- If the installed library version has changed since the object was originally pickled,
- due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
- we may encounter attribute errors. In this case we should replace the stored object.
- Since https://github.com/apache/airflow/pull/24356 we use our serializer to store
- k8s objects, but there could still be raw pickled k8s objects in the database,
- stored from an earlier version, so we still compare them defensively here.
- """
- if self.comparator:
- return self.comparator(x, y)
- else:
- try:
- return x == y
- except AttributeError:
- return False
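- # Illustrative sketch (not part of the upstream module): declaring a column with
- # ExecutorConfigType, similar in spirit to how TaskInstance.executor_config is stored.
- # The table and column names are assumptions for the example.
- def _example_executor_config_column():
-     from sqlalchemy import Column, Integer, MetaData, Table
-     metadata = MetaData()
-     return Table(
-         "example_task",  # hypothetical table name
-         metadata,
-         Column("id", Integer, primary_key=True),
-         # A dict value with a "pod_override" key has that entry serialized with
-         # Airflow's serializer before the whole dict is pickled.
-         Column("executor_config", ExecutorConfigType()),
-     )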
- class Interval(TypeDecorator):
- """Base class representing a time interval."""
- impl = Text
- cache_ok = True
- attr_keys = {
- datetime.timedelta: ("days", "seconds", "microseconds"),
- relativedelta.relativedelta: (
- "years",
- "months",
- "days",
- "leapdays",
- "hours",
- "minutes",
- "seconds",
- "microseconds",
- "year",
- "month",
- "day",
- "hour",
- "minute",
- "second",
- "microsecond",
- ),
- }
- def process_bind_param(self, value, dialect):
- if isinstance(value, tuple(self.attr_keys)):
- attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
- return json.dumps({"type": type(value).__name__, "attrs": attrs})
- return json.dumps(value)
- def process_result_value(self, value, dialect):
- if not value:
- return value
- data = json.loads(value)
- if isinstance(data, dict):
- type_map = {key.__name__: key for key in self.attr_keys}
- return type_map[data["type"]](**data["attrs"])
- return data
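- # Illustrative sketch (not part of the upstream module): Interval stores both timedelta
- # and relativedelta values as JSON text and rebuilds the original type on the way out.
- # The dialect argument is unused by these two methods, so None suffices here.
- def _example_interval_roundtrip():
-     col_type = Interval()
-     stored = col_type.process_bind_param(datetime.timedelta(days=1, seconds=30), dialect=None)
-     restored_td = col_type.process_result_value(stored, dialect=None)  # timedelta(days=1, seconds=30)
-     stored_rd = col_type.process_bind_param(relativedelta.relativedelta(months=2), dialect=None)
-     restored_rd = col_type.process_result_value(stored_rd, dialect=None)  # relativedelta(months=+2)
-     return restored_td, restored_rd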
- def nulls_first(col, session: Session) -> ColumnElement:
- """
- Specify *NULLS FIRST* to the column ordering.
- This is only done for Postgres, currently the only backend that supports it.
- Other databases do not need it since NULL values are considered lower than
- any other values, and appear first when the order is ASC (ascending).
- """
- if session.bind.dialect.name == "postgresql":
- return nullsfirst(col)
- else:
- return col
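- # Illustrative sketch (not part of the upstream module): ordering a query so NULLs sort
- # first on Postgres while other backends keep their default behaviour. The model and
- # column names are assumptions for the example.
- def _example_nulls_first_ordering(session, job_model):
-     return session.query(job_model).order_by(nulls_first(job_model.latest_heartbeat, session))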
- USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)
- def with_row_locks(
- query: Query,
- session: Session,
- *,
- nowait: bool = False,
- skip_locked: bool = False,
- **kwargs,
- ) -> Query:
- """
- Apply with_for_update to the SQLAlchemy query if row level locking is in use.
- This wrapper is needed so we don't use the syntax on unsupported database
- engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
- row locking; on those engines we neither support nor recommend running an HA
- scheduler. If a user ignores this and tries anyway, everything will still work,
- just slightly slower in some circumstances.
- See https://jira.mariadb.org/browse/MDEV-13115
- :param query: An SQLAlchemy Query object
- :param session: ORM Session
- :param nowait: If set to True, will pass NOWAIT to supported database backends.
- :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
- :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
- :return: updated query
- """
- dialect = session.bind.dialect
- if not USE_ROW_LEVEL_LOCKING:
- return query
- # Don't use row-level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support them.
- if dialect.name == "mysql" and not dialect.supports_for_update_of:
- return query
- if nowait:
- kwargs["nowait"] = True
- if skip_locked:
- kwargs["skip_locked"] = True
- return query.with_for_update(**kwargs)
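- # Illustrative sketch (not part of the upstream module): building a query that takes row
- # locks only where the backend supports them. With skip_locked=True the emitted SQL is
- # FOR UPDATE SKIP LOCKED on supporting backends; otherwise the query is returned as-is.
- # The model and column names are assumptions for the example.
- def _example_locked_query(session, ti_model):
-     query = session.query(ti_model).filter(ti_model.state == "scheduled")
-     return with_row_locks(query, session=session, skip_locked=True).all()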
- @contextlib.contextmanager
- def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
- """
- Lock database rows during the context manager block.
- This is a convenient method for ``with_row_locks`` when we don't need the
- locked rows.
- :meta private:
- """
- locked_rows = with_row_locks(query, session)
- yield
- del locked_rows
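- # Illustrative sketch (not part of the upstream module): using lock_rows when only the
- # side effect of holding the locks is needed, not the rows themselves.
- def _example_lock_rows_block(session, query):
-     with lock_rows(query, session):
-         # Do work that relies on the rows matched by `query` not changing underneath us.
-         session.flush()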
- class CommitProhibitorGuard:
- """Context manager class that powers prohibit_commit."""
- expected_commit = False
- def __init__(self, session: Session):
- self.session = session
- def _validate_commit(self, _):
- if self.expected_commit:
- self.expected_commit = False
- return
- raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")
- def __enter__(self):
- event.listen(self.session, "before_commit", self._validate_commit)
- return self
- def __exit__(self, *exc_info):
- event.remove(self.session, "before_commit", self._validate_commit)
- def commit(self):
- """
- Commit the session.
- This is the required way to commit when the guard is in scope.
- """
- self.expected_commit = True
- self.session.commit()
- def prohibit_commit(session):
- """
- Return a context manager that will disallow any commit that isn't done via the context manager.
- The aim of this is to ensure that transaction lifetime is strictly controlled, which is
- especially important in the core scheduler loop. Any commit on the session that is
- _not_ made via this context manager will result in a RuntimeError.
- Example usage:
- .. code:: python
- with prohibit_commit(session) as guard:
- # ... do something with session
- guard.commit()
- # This would throw an error
- # session.commit()
- """
- return CommitProhibitorGuard(session)
- def is_lock_not_available_error(error: OperationalError) -> bool:
- """Check if the Error is about not being able to acquire lock."""
- # DB specific error codes:
- # Postgres: 55P03
- # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
- # is set.'
- # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
- # (when NOWAIT isn't available)
- db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]
- # We could test if error.orig is an instance of
- # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that
- # would require importing those drivers here; checking the error code doesn't.
- if db_err_code in ("55P03", 1205, 3572):
- return True
- return False
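- # Illustrative sketch (not part of the upstream module): distinguishing a "lock not
- # available" failure from other OperationalErrors when using NOWAIT locking.
- def _example_handle_lock_error(session, query):
-     from sqlalchemy.exc import OperationalError
-     try:
-         return with_row_locks(query, session=session, nowait=True).all()
-     except OperationalError as e:
-         if is_lock_not_available_error(e):
-             # Another process holds the lock; roll back and let the caller retry later.
-             session.rollback()
-             return []
-         raise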
- @overload
- def tuple_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any],
- ) -> ColumnOperators: ...
- @overload
- def tuple_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Select,
- *,
- session: Session,
- ) -> ColumnOperators: ...
- def tuple_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any] | Select,
- *,
- session: Session | None = None,
- ) -> ColumnOperators:
- """
- Generate a tuple-in-collection operator to use in ``.where()``.
- For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
- clause.
- :meta private:
- """
- return tuple_(*columns).in_(collection)
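- # Illustrative sketch (not part of the upstream module): filtering on a composite key
- # with a tuple IN clause. The model and column names are assumptions for the example.
- def _example_tuple_in_filter(session, ti_model, keys):
-     # keys is an iterable of (dag_id, task_id) pairs
-     condition = tuple_in_condition((ti_model.dag_id, ti_model.task_id), keys)
-     return session.query(ti_model).filter(condition).all()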
- @overload
- def tuple_not_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any],
- ) -> ColumnOperators: ...
- @overload
- def tuple_not_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Select,
- *,
- session: Session,
- ) -> ColumnOperators: ...
- def tuple_not_in_condition(
- columns: tuple[ColumnElement, ...],
- collection: Iterable[Any] | Select,
- *,
- session: Session | None = None,
- ) -> ColumnOperators:
- """
- Generate a tuple-not-in-collection operator to use in ``.where()``.
- This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
- :meta private:
- """
- return tuple_(*columns).not_in(collection)
- def get_orm_mapper():
- """Get the correct ORM mapper for the installed SQLAlchemy version."""
- import sqlalchemy.orm.mapper
- return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper
- def is_sqlalchemy_v1() -> bool:
- """Return True if the installed SQLAlchemy is a 1.x release."""
- return version.parse(metadata.version("sqlalchemy")).major == 1
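- # Illustrative sketch (not part of the upstream module): registering a mapper-level event
- # in a way that works on both SQLAlchemy 1.x and 2.x, where the "all mappers" target
- # moved from the legacy ``sqlalchemy.orm.mapper`` function to ``sqlalchemy.orm.Mapper``.
- def _example_register_mapper_event(callback):
-     event.listen(get_orm_mapper(), "before_configured", callback)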