sqlalchemy.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import json
import logging
from importlib import metadata
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload

from dateutil import relativedelta
from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.timezone import make_naive, utc

if TYPE_CHECKING:
    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql import ColumnElement, Select
    from sqlalchemy.sql.expression import ColumnOperators
    from sqlalchemy.types import TypeEngine

log = logging.getLogger(__name__)


class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.

    - Never silently accepts a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but always a time
      zone aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
    """

    impl = TIMESTAMP(timezone=True)

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if not isinstance(value, datetime.datetime):
            if value is None:
                return None
            raise TypeError(f"expected datetime.datetime, not {value!r}")
        elif value.tzinfo is None:
            raise ValueError("naive datetime is disallowed")
        elif dialect.name == "mysql":
            # For MySQL versions prior to 8.0.19 we should send timestamps as naive values in UTC.
            # See: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
            return make_naive(value, timezone=utc)
        return value.astimezone(utc)

    def process_result_value(self, value, dialect):
        """
        Process DateTimes from the DB, making sure to always return UTC.

        Not using timezone.convert_to_utc as that converts to the configured TIMEZONE
        while the DB might be running with some other setting. We assume UTC
        datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)
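

# Illustrative usage sketch (not part of the original module): declaring a model column
# with UtcDateTime. The model and table names below are hypothetical and exist only to
# show the bind-time behavior described in the docstring above.
def _example_utc_datetime_column():
    from sqlalchemy import Column, Integer
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class ExampleEvent(Base):
        __tablename__ = "example_event"

        id = Column(Integer, primary_key=True)
        created_at = Column(UtcDateTime)

    # Aware datetimes are normalized to UTC on the way in; naive ones raise ValueError.
    aware = datetime.datetime(2024, 1, 1, 12, 0, tzinfo=datetime.timezone.utc)
    return ExampleEvent(created_at=aware)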


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
    """

    impl = Text

    cache_ok = True

    should_evaluate_none = True

    def load_dialect_impl(self, dialect) -> TypeEngine:
        return dialect.type_descriptor(JSON)

    def process_bind_param(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None
        return BaseSerialization.serialize(value)

    def process_result_value(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None
        return BaseSerialization.deserialize(value)
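

# Illustrative sketch (not part of the original module): declaring an ExtendedJSON column
# on a Core table. Values pass through Airflow's BaseSerialization on the way in and out,
# so types that the extended serializer understands survive the round trip. The table and
# column names below are hypothetical.
def _example_extended_json_column():
    from sqlalchemy import Column, Integer, MetaData, Table

    return Table(
        "example_record",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("payload", ExtendedJSON),
    )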


def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict.... but *safely*.

    When pod objects created with one k8s version are unpickled in a Python
    env with a more recent k8s version (in which the object attrs may have
    changed) the unpickled obj may throw an error because the attr
    expected on new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, long, float, bool, return directly.
    If obj is datetime.datetime, datetime.date
        convert to string in iso8601 format.
    If obj is list, sanitize each element in the list.
    If obj is dict, return the dict.
    If obj is OpenAPI model, return the properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
    """
    if obj is None:
        return None
    elif isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    elif isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    elif isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change, and we just add default=None for getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
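

# Illustrative sketch (not part of the original module): sanitizing a minimal V1Pod.
# Unset or unknown attributes are simply skipped instead of raising AttributeError,
# which is the whole point compared to the upstream api_client serializer.
def _example_sanitize_pod():
    from kubernetes.client import models as k8s

    pod = k8s.V1Pod(
        metadata=k8s.V1ObjectMeta(name="example-pod", namespace="default"),
        spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base", image="alpine:3.19")]),
    )
    # Returns a plain, JSON-friendly dict such as {"metadata": {"name": "example-pod", ...}, ...}
    return sanitize_for_serialization(pod)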


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to json and back so that pod is safe.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        try:
            from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
        except ImportError:
            from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import (  # type: ignore[assignment]
                PodGenerator,
            )

        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
    """

    cache_ok = True

    def bind_processor(self, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            val_copy = copy.copy(value)
            if isinstance(val_copy, dict) and "pod_override" in val_copy:
                val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
            return super_process(val_copy)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and "pod_override" in value:
                pod_override = value["pod_override"]

                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                    value["pod_override"] = BaseSerialization.deserialize(pod_override)
                else:
                    # backcompat path
                    # we no longer pickle raw pods but this code may be reached
                    # when accessing executor configs created in a prior version
                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
                    if new_pod:
                        value["pod_override"] = new_pod
            return value

        return process

    def compare_values(self, x, y):
        """
        Compare x and y using self.comparator if available. Else, use __eq__.

        The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.
        If the installed library version has changed since the object was originally pickled,
        due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
        we may encounter attribute errors. In this case we should replace the stored object.

        From https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from earlier versions, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        else:
            try:
                return x == y
            except AttributeError:
                return False


class Interval(TypeDecorator):
    """Base class representing a time interval."""

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data
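

# Illustrative sketch (not part of the original module): what Interval actually stores.
# A timedelta is flattened into a small JSON document keyed by its attribute names, and
# rebuilt from that document on the way back out. The dialect argument is unused by both
# methods, so passing None is enough for a demonstration.
def _example_interval_roundtrip():
    interval = Interval()
    raw = interval.process_bind_param(datetime.timedelta(days=1, seconds=30), None)
    # raw == '{"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}'
    return interval.process_result_value(raw, None)  # back to timedelta(days=1, seconds=30)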


def nulls_first(col, session: Session) -> ColumnElement:
    """
    Specify *NULLS FIRST* to the column ordering.

    This is only done to Postgres, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
    """
    if session.bind.dialect.name == "postgresql":
        return nullsfirst(col)
    else:
        return col
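

# Illustrative sketch (not part of the original module): ordering a query so that rows
# with a NULL column come first on every backend. The query and column are supplied by
# the caller; nothing here is specific to any one model.
def _example_nulls_first(session: Session, query: Query, col):
    # On Postgres this appends "NULLS FIRST"; on other backends the column is used as-is,
    # because NULLs already sort before other values in ascending order there.
    return query.order_by(nulls_first(col, session=session))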


USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Query,
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    **kwargs,
) -> Query:
    """
    Apply with_for_update to the SQLAlchemy query if row level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those we neither support nor recommend running an HA
    scheduler. If a user ignores this and tries anyway, everything will still
    work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
    """
    dialect = session.bind.dialect

    # Don't use row level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support it.
    if not USE_ROW_LEVEL_LOCKING:
        return query
    if dialect.name == "mysql" and not dialect.supports_for_update_of:
        return query
    if nowait:
        kwargs["nowait"] = True
    if skip_locked:
        kwargs["skip_locked"] = True
    return query.with_for_update(**kwargs)
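

# Illustrative sketch (not part of the original module): one way a caller might lock
# TaskInstance rows, skipping rows another scheduler already holds. The filter value is
# hypothetical; only with_row_locks itself comes from this module.
def _example_with_row_locks(session: Session):
    from airflow.models.taskinstance import TaskInstance

    query = session.query(TaskInstance).filter(TaskInstance.state == "queued")
    # On backends that support it this becomes SELECT ... FOR UPDATE SKIP LOCKED;
    # on backends without FOR UPDATE OF support the query is returned unchanged.
    locked = with_row_locks(query, session=session, skip_locked=True)
    return locked.all()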


@contextlib.contextmanager
def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
    """
    Lock database rows during the context manager block.

    This is a convenient method for ``with_row_locks`` when we don't need the
    locked rows.

    :meta private:
    """
    locked_rows = with_row_locks(query, session)
    yield
    del locked_rows


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self):
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in RuntimeError.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the Error is about not being able to acquire lock."""
    # DB specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #              is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction
    #              (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
    # importing it. This doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False
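

# Illustrative sketch (not part of the original module): a typical caller pattern,
# backing off instead of failing hard when a NOWAIT lock attempt loses the race.
def _example_handle_lock_error(session: Session, query: Query):
    from sqlalchemy.exc import OperationalError

    try:
        return with_row_locks(query, session, nowait=True).all()
    except OperationalError as e:
        if is_lock_not_available_error(e):
            session.rollback()
            return None  # caller can retry on the next loop iteration
        raise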


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-in-collection operator to use in ``.where()``.

    For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
    clause.

    :meta private:
    """
    return tuple_(*columns).in_(collection)


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-not-in-collection operator to use in ``.where()``.

    This is similar to ``tuple_in_condition`` except generating ``NOT IN``.

    :meta private:
    """
    return tuple_(*columns).not_in(collection)
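

# Illustrative sketch (not part of the original module): filtering TaskInstance rows by
# (dag_id, task_id) pairs with a single tuple IN clause instead of many ORed conditions.
# The dag and task ids are hypothetical.
def _example_tuple_in_condition(session: Session):
    from airflow.models.taskinstance import TaskInstance

    keys = [("example_dag", "extract"), ("example_dag", "load")]
    condition = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), keys)
    return session.query(TaskInstance).filter(condition).all()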


def get_orm_mapper():
    """Get the correct ORM mapper for the installed SQLAlchemy version."""
    import sqlalchemy.orm.mapper

    return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper


def is_sqlalchemy_v1() -> bool:
    return version.parse(metadata.version("sqlalchemy")).major == 1