context.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright Amethyst Reese
  2. # Licensed under the MIT license
  3. from collections.abc import Coroutine, Generator
  4. from contextlib import AbstractAsyncContextManager
  5. from functools import wraps
  6. from typing import Any, Callable, TypeVar
  7. from .cursor import Cursor
  8. _T = TypeVar("_T")
  9. class Result(AbstractAsyncContextManager[_T], Coroutine[Any, Any, _T]):
  10. __slots__ = ("_coro", "_obj")
  11. def __init__(self, coro: Coroutine[Any, Any, _T]):
  12. self._coro = coro
  13. self._obj: _T
  14. def send(self, value) -> None:
  15. return self._coro.send(value)
  16. def throw(self, typ, val=None, tb=None) -> None:
  17. if val is None:
  18. return self._coro.throw(typ)
  19. if tb is None:
  20. return self._coro.throw(typ, val)
  21. return self._coro.throw(typ, val, tb)
  22. def close(self) -> None:
  23. return self._coro.close()
  24. def __await__(self) -> Generator[Any, None, _T]:
  25. return self._coro.__await__()
  26. async def __aenter__(self) -> _T:
  27. self._obj = await self._coro
  28. return self._obj
  29. async def __aexit__(self, exc_type, exc, tb) -> None:
  30. if isinstance(self._obj, Cursor):
  31. await self._obj.close()
  32. def contextmanager(
  33. method: Callable[..., Coroutine[Any, Any, _T]]
  34. ) -> Callable[..., Result[_T]]:
  35. @wraps(method)
  36. def wrapper(self, *args, **kwargs) -> Result[_T]:
  37. return Result(method(self, *args, **kwargs))
  38. return wrapper