123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- import contextlib
- import os
- from functools import wraps
- from inspect import signature
- from typing import Callable, Generator, TypeVar, cast
- from sqlalchemy.orm import Session as SASession
- from airflow import settings
- from airflow.api_internal.internal_api_call import InternalApiConfig
- from airflow.settings import TracebackSession, TracebackSessionForTests
- from airflow.typing_compat import ParamSpec
@contextlib.contextmanager
def create_session() -> Generator[SASession, None, None]:
    """
    Contextmanager that will create and teardown a session.

    On the regular path a session is created from ``settings.Session``. It is
    committed when the managed block exits normally, and rolled back (with the
    exception re-raised) when the block raises an ``Exception``. It is always
    closed afterwards.

    When the internal API is in use, no session is created from
    ``settings.Session``:

    * if the ``RUN_TESTS_WITH_DATABASE_ISOLATION`` environment variable is
      ``"true"`` (case-insensitive), a ``TracebackSessionForTests`` is yielded,
      and its ``current_db_session`` (when set) gets the same
      commit/rollback/close treatment;
    * otherwise a ``TracebackSession`` is yielded and nothing is torn down.

    :raises RuntimeError: If ``settings.Session`` has not been set.
    """
    if InternalApiConfig.get_use_internal_api():
        if os.environ.get("RUN_TESTS_WITH_DATABASE_ISOLATION", "false").lower() == "true":
            test_session = TracebackSessionForTests()
            # Every use of ``current_db_session`` is guarded the same way.
            # Guarding only ``commit()`` would let an unset ``current_db_session``
            # raise AttributeError from ``rollback()``/``close()``. On the error
            # path, that AttributeError would replace the caller's real exception.
            try:
                yield test_session
                if test_session.current_db_session:
                    test_session.current_db_session.commit()
            except Exception:
                if test_session.current_db_session:
                    test_session.current_db_session.rollback()
                raise
            finally:
                if test_session.current_db_session:
                    test_session.current_db_session.close()
        else:
            yield TracebackSession()
        return

    Session = getattr(settings, "Session", None)
    if Session is None:
        raise RuntimeError("Session must be set before!")
    session = Session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
- PS = ParamSpec("PS")
- RT = TypeVar("RT")
def find_session_idx(func: Callable[PS, RT]) -> int:
    """
    Return the position of the ``session`` parameter in ``func``'s signature.

    Positions are zero-based and counted over every declared parameter, in
    declaration order, regardless of the parameter's kind.

    :param func: Callable whose signature is inspected.
    :return: Index of the parameter named ``session``.
    :raises ValueError: If ``func`` declares no ``session`` parameter.
    """
    for position, param_name in enumerate(signature(func).parameters):
        if param_name == "session":
            return position
    # ``from None`` marks the error as having no chained context.
    raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
    """
    Provide a session if it isn't provided.

    If you want to reuse a session or run the function as part of a
    database transaction, you pass it to the function, if not this wrapper
    will create one and close it for you.

    A session counts as "provided" when it is passed as the ``session`` keyword
    argument (any value, including ``None``), or positionally in the slot that
    ``session`` occupies. Otherwise the call runs inside ``create_session()``,
    and the new session is passed as ``session=...``.

    :param func: Callable that declares a ``session`` parameter.
    :return: The wrapped callable.
    :raises ValueError: If ``func`` has no ``session`` parameter.
    """
    session_args_idx = find_session_idx(func)
    session_param = signature(func).parameters["session"]
    # A keyword-only ``session`` (e.g. declared after ``*args``) can never be
    # filled positionally. Its index in the signature then says nothing about
    # the number of positional arguments, so the positional check must not
    # apply to it.
    session_is_keyword_only = session_param.kind == session_param.KEYWORD_ONLY

    @wraps(func)
    def wrapper(*args, **kwargs) -> RT:
        passed_positionally = not session_is_keyword_only and session_args_idx < len(args)
        if "session" in kwargs or passed_positionally:
            return func(*args, **kwargs)
        with create_session() as session:
            return func(*args, session=session, **kwargs)

    return wrapper
# Placeholder session for the ``session`` parameter of functions decorated with
# provide_session. Typing it as a Session (rather than ``Session | None``) lets
# those function bodies use ``session`` without None checks. provide_session
# supplies a real session whenever the caller passes none, so this None
# placeholder is not expected to reach the function body.
NEW_SESSION: SASession = cast("SASession", None)
|