lowlevel.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from __future__ import annotations
  2. import enum
  3. from dataclasses import dataclass
  4. from typing import Any, Generic, Literal, TypeVar, overload
  5. from weakref import WeakKeyDictionary
  6. from ._core._eventloop import get_async_backend
# Type of the value held by a RunVar (and remembered by a RunvarToken).
T = TypeVar("T")
# Type of the fallback value a caller may pass to RunVar.get(default=...).
D = TypeVar("D")
  9. async def checkpoint() -> None:
  10. """
  11. Check for cancellation and allow the scheduler to switch to another task.
  12. Equivalent to (but more efficient than)::
  13. await checkpoint_if_cancelled()
  14. await cancel_shielded_checkpoint()
  15. .. versionadded:: 3.0
  16. """
  17. await get_async_backend().checkpoint()
  18. async def checkpoint_if_cancelled() -> None:
  19. """
  20. Enter a checkpoint if the enclosing cancel scope has been cancelled.
  21. This does not allow the scheduler to switch to a different task.
  22. .. versionadded:: 3.0
  23. """
  24. await get_async_backend().checkpoint_if_cancelled()
  25. async def cancel_shielded_checkpoint() -> None:
  26. """
  27. Allow the scheduler to switch to another task but without checking for cancellation.
  28. Equivalent to (but potentially more efficient than)::
  29. with CancelScope(shield=True):
  30. await checkpoint()
  31. .. versionadded:: 3.0
  32. """
  33. await get_async_backend().cancel_shielded_checkpoint()
  34. def current_token() -> object:
  35. """
  36. Return a backend specific token object that can be used to get back to the event
  37. loop.
  38. """
  39. return get_async_backend().current_token()
# Per-event-loop RunVar storage: maps the object returned by current_token() to
# a dict of variable name -> value. Because the keys are held weakly, an entry
# is dropped automatically once its token object is garbage collected.
_run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary()
# NOTE(review): not referenced by any code visible in this module; presumably
# kept for backwards compatibility — confirm before removing.
_token_wrappers: dict[Any, _TokenWrapper] = {}
  42. @dataclass(frozen=True)
  43. class _TokenWrapper:
  44. __slots__ = "_token", "__weakref__"
  45. _token: object
  46. class _NoValueSet(enum.Enum):
  47. NO_VALUE_SET = enum.auto()
  48. class RunvarToken(Generic[T]):
  49. __slots__ = "_var", "_value", "_redeemed"
  50. def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]):
  51. self._var = var
  52. self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value
  53. self._redeemed = False
  54. class RunVar(Generic[T]):
  55. """
  56. Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.
  57. """
  58. __slots__ = "_name", "_default"
  59. NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
  60. _token_wrappers: set[_TokenWrapper] = set()
  61. def __init__(
  62. self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  63. ):
  64. self._name = name
  65. self._default = default
  66. @property
  67. def _current_vars(self) -> dict[str, T]:
  68. token = current_token()
  69. try:
  70. return _run_vars[token]
  71. except KeyError:
  72. run_vars = _run_vars[token] = {}
  73. return run_vars
  74. @overload
  75. def get(self, default: D) -> T | D: ...
  76. @overload
  77. def get(self) -> T: ...
  78. def get(
  79. self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  80. ) -> T | D:
  81. try:
  82. return self._current_vars[self._name]
  83. except KeyError:
  84. if default is not RunVar.NO_VALUE_SET:
  85. return default
  86. elif self._default is not RunVar.NO_VALUE_SET:
  87. return self._default
  88. raise LookupError(
  89. f'Run variable "{self._name}" has no value and no default set'
  90. )
  91. def set(self, value: T) -> RunvarToken[T]:
  92. current_vars = self._current_vars
  93. token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET))
  94. current_vars[self._name] = value
  95. return token
  96. def reset(self, token: RunvarToken[T]) -> None:
  97. if token._var is not self:
  98. raise ValueError("This token does not belong to this RunVar")
  99. if token._redeemed:
  100. raise ValueError("This token has already been used")
  101. if token._value is _NoValueSet.NO_VALUE_SET:
  102. try:
  103. del self._current_vars[self._name]
  104. except KeyError:
  105. pass
  106. else:
  107. self._current_vars[self._name] = token._value
  108. token._redeemed = True
  109. def __repr__(self) -> str:
  110. return f"<RunVar name={self._name!r}>"