session.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. import contextlib
  19. import os
  20. from functools import wraps
  21. from inspect import signature
  22. from typing import Callable, Generator, TypeVar, cast
  23. from sqlalchemy.orm import Session as SASession
  24. from airflow import settings
  25. from airflow.api_internal.internal_api_call import InternalApiConfig
  26. from airflow.settings import TracebackSession, TracebackSessionForTests
  27. from airflow.typing_compat import ParamSpec
  28. @contextlib.contextmanager
  29. def create_session() -> Generator[SASession, None, None]:
  30. """Contextmanager that will create and teardown a session."""
  31. if InternalApiConfig.get_use_internal_api():
  32. if os.environ.get("RUN_TESTS_WITH_DATABASE_ISOLATION", "false").lower() == "true":
  33. traceback_session_for_tests = TracebackSessionForTests()
  34. try:
  35. yield traceback_session_for_tests
  36. if traceback_session_for_tests.current_db_session:
  37. traceback_session_for_tests.current_db_session.commit()
  38. except Exception:
  39. traceback_session_for_tests.current_db_session.rollback()
  40. raise
  41. finally:
  42. traceback_session_for_tests.current_db_session.close()
  43. else:
  44. yield TracebackSession()
  45. return
  46. Session = getattr(settings, "Session", None)
  47. if Session is None:
  48. raise RuntimeError("Session must be set before!")
  49. session = Session()
  50. try:
  51. yield session
  52. session.commit()
  53. except Exception:
  54. session.rollback()
  55. raise
  56. finally:
  57. session.close()
  58. PS = ParamSpec("PS")
  59. RT = TypeVar("RT")
  60. def find_session_idx(func: Callable[PS, RT]) -> int:
  61. """Find session index in function call parameter."""
  62. func_params = signature(func).parameters
  63. try:
  64. # func_params is an ordered dict -- this is the "recommended" way of getting the position
  65. session_args_idx = tuple(func_params).index("session")
  66. except ValueError:
  67. raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
  68. return session_args_idx
  69. def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
  70. """
  71. Provide a session if it isn't provided.
  72. If you want to reuse a session or run the function as part of a
  73. database transaction, you pass it to the function, if not this wrapper
  74. will create one and close it for you.
  75. """
  76. session_args_idx = find_session_idx(func)
  77. @wraps(func)
  78. def wrapper(*args, **kwargs) -> RT:
  79. if "session" in kwargs or session_args_idx < len(args):
  80. return func(*args, **kwargs)
  81. else:
  82. with create_session() as session:
  83. return func(*args, session=session, **kwargs)
  84. return wrapper
  85. # A fake session to use in functions decorated by provide_session. This allows
  86. # the 'session' argument to be of type Session instead of Session | None,
  87. # making it easier to type hint the function body without dealing with the None
  88. # case that can never happen at runtime.
  89. NEW_SESSION: SASession = cast(SASession, None)