base.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from __future__ import annotations
  2. import functools
  3. from abc import ABC, abstractmethod
  4. from deprecated.sphinx import versionadded
  5. from limits import errors
  6. from limits.storage.registry import StorageRegistry
  7. from limits.typing import (
  8. Any,
  9. Awaitable,
  10. Callable,
  11. P,
  12. R,
  13. cast,
  14. )
  15. from limits.util import LazyDependency
  16. def _wrap_errors(
  17. fn: Callable[P, Awaitable[R]],
  18. ) -> Callable[P, Awaitable[R]]:
  19. @functools.wraps(fn)
  20. async def inner(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
  21. instance = cast(Storage, args[0])
  22. try:
  23. return await fn(*args, **kwargs)
  24. except instance.base_exceptions as exc:
  25. if instance.wrap_exceptions:
  26. raise errors.StorageError(exc) from exc
  27. raise
  28. return inner
  29. @versionadded(version="2.1")
  30. class Storage(LazyDependency, metaclass=StorageRegistry):
  31. """
  32. Base class to extend when implementing an async storage backend.
  33. """
  34. STORAGE_SCHEME: list[str] | None
  35. """The storage schemes to register against this implementation"""
  36. def __init_subclass__(cls, **kwargs: Any) -> None: # type:ignore[explicit-any]
  37. super().__init_subclass__(**kwargs)
  38. for method in {
  39. "incr",
  40. "get",
  41. "get_expiry",
  42. "check",
  43. "reset",
  44. "clear",
  45. }:
  46. setattr(cls, method, _wrap_errors(getattr(cls, method)))
  47. super().__init_subclass__(**kwargs)
  48. def __init__(
  49. self,
  50. uri: str | None = None,
  51. wrap_exceptions: bool = False,
  52. **options: float | str | bool,
  53. ) -> None:
  54. """
  55. :param wrap_exceptions: Whether to wrap storage exceptions in
  56. :exc:`limits.errors.StorageError` before raising it.
  57. """
  58. super().__init__()
  59. self.wrap_exceptions = wrap_exceptions
  60. @property
  61. @abstractmethod
  62. def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
  63. raise NotImplementedError
  64. @abstractmethod
  65. async def incr(
  66. self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
  67. ) -> int:
  68. """
  69. increments the counter for a given rate limit key
  70. :param key: the key to increment
  71. :param expiry: amount in seconds for the key to expire in
  72. :param elastic_expiry: whether to keep extending the rate limit
  73. window every hit.
  74. :param amount: the number to increment by
  75. """
  76. raise NotImplementedError
  77. @abstractmethod
  78. async def get(self, key: str) -> int:
  79. """
  80. :param key: the key to get the counter value for
  81. """
  82. raise NotImplementedError
  83. @abstractmethod
  84. async def get_expiry(self, key: str) -> float:
  85. """
  86. :param key: the key to get the expiry for
  87. """
  88. raise NotImplementedError
  89. @abstractmethod
  90. async def check(self) -> bool:
  91. """
  92. check if storage is healthy
  93. """
  94. raise NotImplementedError
  95. @abstractmethod
  96. async def reset(self) -> int | None:
  97. """
  98. reset storage to clear limits
  99. """
  100. raise NotImplementedError
  101. @abstractmethod
  102. async def clear(self, key: str) -> None:
  103. """
  104. resets the rate limit key
  105. :param key: the key to clear rate limits for
  106. """
  107. raise NotImplementedError
  108. class MovingWindowSupport(ABC):
  109. """
  110. Abstract base class for async storages that support
  111. the :ref:`strategies:moving window` strategy
  112. """
  113. def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
  114. for method in {
  115. "acquire_entry",
  116. "get_moving_window",
  117. }:
  118. setattr(
  119. cls,
  120. method,
  121. _wrap_errors(getattr(cls, method)),
  122. )
  123. super().__init_subclass__(**kwargs)
  124. @abstractmethod
  125. async def acquire_entry(
  126. self, key: str, limit: int, expiry: int, amount: int = 1
  127. ) -> bool:
  128. """
  129. :param key: rate limit key to acquire an entry in
  130. :param limit: amount of entries allowed
  131. :param expiry: expiry of the entry
  132. :param amount: the number of entries to acquire
  133. """
  134. raise NotImplementedError
  135. @abstractmethod
  136. async def get_moving_window(
  137. self, key: str, limit: int, expiry: int
  138. ) -> tuple[float, int]:
  139. """
  140. returns the starting point and the number of entries in the moving
  141. window
  142. :param key: rate limit key
  143. :param expiry: expiry of entry
  144. :return: (start of window, number of acquired entries)
  145. """
  146. raise NotImplementedError
  147. class SlidingWindowCounterSupport(ABC):
  148. """
  149. Abstract base class for async storages that support
  150. the :ref:`strategies:sliding window counter` strategy
  151. """
  152. def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
  153. for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
  154. setattr(
  155. cls,
  156. method,
  157. _wrap_errors(getattr(cls, method)),
  158. )
  159. super().__init_subclass__(**kwargs)
  160. @abstractmethod
  161. async def acquire_sliding_window_entry(
  162. self,
  163. key: str,
  164. limit: int,
  165. expiry: int,
  166. amount: int = 1,
  167. ) -> bool:
  168. """
  169. Acquire an entry if the weighted count of the current and previous
  170. windows is less than or equal to the limit
  171. :param key: rate limit key to acquire an entry in
  172. :param limit: amount of entries allowed
  173. :param expiry: expiry of the entry
  174. :param amount: the number of entries to acquire
  175. """
  176. raise NotImplementedError
  177. @abstractmethod
  178. async def get_sliding_window(
  179. self, key: str, expiry: int
  180. ) -> tuple[int, float, int, float]:
  181. """
  182. Return the previous and current window information.
  183. :param key: the rate limit key
  184. :param expiry: the rate limit expiry, needed to compute the key in some implementations
  185. :return: a tuple of (int, float, int, float) with the following information:
  186. - previous window counter
  187. - previous window TTL
  188. - current window counter
  189. - current window TTL
  190. """
  191. raise NotImplementedError