123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- import functools
- import logging
- from inspect import signature
- from typing import Callable, TypeVar, overload
- from sqlalchemy.exc import DBAPIError
- from airflow.configuration import conf
# Default cap on attempts (the first call included) for the DB retry helpers in
# this module; configurable via the ``[database] max_db_retries`` option.
MAX_DB_RETRIES = conf.getint(
    "database",
    "max_db_retries",
    fallback=3,
)

# Type variable bound to callables, so decorators can declare that they return
# the same callable type they were given.
F = TypeVar("F", bound=Callable)
def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: logging.Logger | None = None, **kwargs):
    """
    Return Tenacity Retrying object with project specific default.

    Defaults applied to the returned ``tenacity.Retrying``:

    * retry only when a ``DBAPIError`` is raised,
    * wait a random exponential back-off (multiplier 0.5, at most 5 seconds),
    * stop after ``max_retries`` attempts in total,
    * re-raise the last exception once attempts are exhausted.

    :param max_retries: Maximum number of attempts, counting the first one.
    :param logger: When a ``logging.Logger`` is given, a ``before_sleep`` hook is
        installed that logs each retry at DEBUG level with exception info.
    :param kwargs: Extra keyword arguments for ``tenacity.Retrying``. They override
        the defaults above (e.g. ``wait=...``). The one exception is ``before_sleep``,
        which is replaced by the logger hook when ``logger`` is given.
    :return: A configured ``tenacity.Retrying`` instance.
    """
    import tenacity

    # Defaults first, caller-supplied kwargs last so they take precedence.
    # The previous ``dict(retry=..., **kwargs)`` form raised
    # "got multiple values for keyword argument" whenever a caller tried to
    # override one of these defaults.
    retry_kwargs = {
        "retry": tenacity.retry_if_exception_type(exception_types=(DBAPIError,)),
        "wait": tenacity.wait_random_exponential(multiplier=0.5, max=5),
        "stop": tenacity.stop_after_attempt(max_retries),
        "reraise": True,
        **kwargs,
    }
    if isinstance(logger, logging.Logger):
        # Third positional argument is ``exc_info=True``: include the traceback.
        retry_kwargs["before_sleep"] = tenacity.before_sleep_log(logger, logging.DEBUG, True)
    return tenacity.Retrying(**retry_kwargs)
@overload
def retry_db_transaction(*, retries: int = MAX_DB_RETRIES, **retry_kwargs) -> Callable[[F], F]: ...


@overload
def retry_db_transaction(_func: F) -> F: ...


def retry_db_transaction(_func: Callable | None = None, *, retries: int = MAX_DB_RETRIES, **retry_kwargs):
    """
    Retry functions in case of ``DBAPIError`` from DB.

    Can be applied bare (``@retry_db_transaction``) or with arguments
    (``@retry_db_transaction(retries=5)``). The decorated function must have a
    ``session`` parameter. When a call raises ``DBAPIError``, ``session.rollback()``
    is called before the error is re-raised to the retry loop.

    It should not be used with ``@provide_session``.

    :param _func: The function being decorated, when used without arguments.
    :param retries: Maximum number of attempts, counting the first one.
    :param retry_kwargs: Extra keyword arguments forwarded to ``run_with_db_retries``.
    :raises ValueError: At decoration time, if the function has no ``session`` parameter.
    """

    def retry_decorator(func: Callable) -> Callable:
        # Get Positional argument for 'session' (resolved once, at decoration time)
        func_params = signature(func).parameters
        try:
            # func_params is an ordered dict -- this is the "recommended" way of getting the position
            session_args_idx = tuple(func_params).index("session")
        except ValueError:
            # ``from None``: the tuple.index() failure is not useful context for the caller.
            raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
        # We don't need this anymore -- ensure we don't keep a reference to it by mistake
        del func_params

        @functools.wraps(func)
        def wrapped_function(*args, **kwargs):
            # Use the first positional argument's logger when it exposes one
            # (presumably ``self``/``cls`` for methods); otherwise the module logger.
            if args and hasattr(args[0], "logger"):
                logger = args[0].logger()
            elif args and hasattr(args[0], "log"):
                logger = args[0].log
            else:
                logger = logging.getLogger(func.__module__)

            # Get session from args or kwargs
            if "session" in kwargs:
                session = kwargs["session"]
            elif len(args) > session_args_idx:
                session = args[session_args_idx]
            else:
                raise TypeError(f"session is a required argument for {func.__qualname__}")

            for attempt in run_with_db_retries(max_retries=retries, logger=logger, **retry_kwargs):
                with attempt:
                    logger.debug(
                        "Running %s with retries. Try %d of %d",
                        func.__qualname__,
                        attempt.retry_state.attempt_number,
                        retries,
                    )
                    try:
                        return func(*args, **kwargs)
                    except DBAPIError:
                        # Roll back the failed transaction, then let the retry loop decide.
                        session.rollback()
                        raise

        return wrapped_function

    # Allow using decorator with and without arguments
    if _func is None:
        return retry_decorator
    return retry_decorator(_func)
|