base.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. from __future__ import annotations
  2. import functools
  3. from abc import ABC, abstractmethod
  4. from limits import errors
  5. from limits.storage.registry import StorageRegistry
  6. from limits.typing import (
  7. Any,
  8. Callable,
  9. P,
  10. R,
  11. cast,
  12. )
  13. from limits.util import LazyDependency
  14. def _wrap_errors(
  15. fn: Callable[P, R],
  16. ) -> Callable[P, R]:
  17. @functools.wraps(fn)
  18. def inner(*args: P.args, **kwargs: P.kwargs) -> R:
  19. instance = cast(Storage, args[0])
  20. try:
  21. return fn(*args, **kwargs)
  22. except instance.base_exceptions as exc:
  23. if instance.wrap_exceptions:
  24. raise errors.StorageError(exc) from exc
  25. raise
  26. return inner
  27. class Storage(LazyDependency, metaclass=StorageRegistry):
  28. """
  29. Base class to extend when implementing a storage backend.
  30. """
  31. STORAGE_SCHEME: list[str] | None
  32. """The storage schemes to register against this implementation"""
  33. def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
  34. for method in {
  35. "incr",
  36. "get",
  37. "get_expiry",
  38. "check",
  39. "reset",
  40. "clear",
  41. }:
  42. setattr(cls, method, _wrap_errors(getattr(cls, method)))
  43. super().__init_subclass__(**kwargs)
  44. def __init__(
  45. self,
  46. uri: str | None = None,
  47. wrap_exceptions: bool = False,
  48. **options: float | str | bool,
  49. ):
  50. """
  51. :param wrap_exceptions: Whether to wrap storage exceptions in
  52. :exc:`limits.errors.StorageError` before raising it.
  53. """
  54. super().__init__()
  55. self.wrap_exceptions = wrap_exceptions
  56. @property
  57. @abstractmethod
  58. def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
  59. raise NotImplementedError
  60. @abstractmethod
  61. def incr(
  62. self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
  63. ) -> int:
  64. """
  65. increments the counter for a given rate limit key
  66. :param key: the key to increment
  67. :param expiry: amount in seconds for the key to expire in
  68. :param elastic_expiry: whether to keep extending the rate limit
  69. window every hit.
  70. :param amount: the number to increment by
  71. """
  72. raise NotImplementedError
  73. @abstractmethod
  74. def get(self, key: str) -> int:
  75. """
  76. :param key: the key to get the counter value for
  77. """
  78. raise NotImplementedError
  79. @abstractmethod
  80. def get_expiry(self, key: str) -> float:
  81. """
  82. :param key: the key to get the expiry for
  83. """
  84. raise NotImplementedError
  85. @abstractmethod
  86. def check(self) -> bool:
  87. """
  88. check if storage is healthy
  89. """
  90. raise NotImplementedError
  91. @abstractmethod
  92. def reset(self) -> int | None:
  93. """
  94. reset storage to clear limits
  95. """
  96. raise NotImplementedError
  97. @abstractmethod
  98. def clear(self, key: str) -> None:
  99. """
  100. resets the rate limit key
  101. :param key: the key to clear rate limits for
  102. """
  103. raise NotImplementedError
  104. class MovingWindowSupport(ABC):
  105. """
  106. Abstract base class for storages that support
  107. the :ref:`strategies:moving window` strategy
  108. """
  109. def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
  110. for method in {
  111. "acquire_entry",
  112. "get_moving_window",
  113. }:
  114. setattr(
  115. cls,
  116. method,
  117. _wrap_errors(getattr(cls, method)),
  118. )
  119. super().__init_subclass__(**kwargs)
  120. @abstractmethod
  121. def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
  122. """
  123. :param key: rate limit key to acquire an entry in
  124. :param limit: amount of entries allowed
  125. :param expiry: expiry of the entry
  126. :param amount: the number of entries to acquire
  127. """
  128. raise NotImplementedError
  129. @abstractmethod
  130. def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
  131. """
  132. returns the starting point and the number of entries in the moving
  133. window
  134. :param key: rate limit key
  135. :param expiry: expiry of entry
  136. :return: (start of window, number of acquired entries)
  137. """
  138. raise NotImplementedError
  139. class SlidingWindowCounterSupport(ABC):
  140. """
  141. Abstract base class for storages that support
  142. the :ref:`strategies:sliding window counter` strategy.
  143. """
  144. def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
  145. for method in {"acquire_sliding_window_entry", "get_sliding_window"}:
  146. setattr(
  147. cls,
  148. method,
  149. _wrap_errors(getattr(cls, method)),
  150. )
  151. super().__init_subclass__(**kwargs)
  152. @abstractmethod
  153. def acquire_sliding_window_entry(
  154. self, key: str, limit: int, expiry: int, amount: int = 1
  155. ) -> bool:
  156. """
  157. Acquire an entry if the weighted count of the current and previous
  158. windows is less than or equal to the limit
  159. :param key: rate limit key to acquire an entry in
  160. :param limit: amount of entries allowed
  161. :param expiry: expiry of the entry
  162. :param amount: the number of entries to acquire
  163. """
  164. raise NotImplementedError
  165. @abstractmethod
  166. def get_sliding_window(
  167. self, key: str, expiry: int
  168. ) -> tuple[int, float, int, float]:
  169. """
  170. Return the previous and current window information.
  171. :param key: the rate limit key
  172. :param expiry: the rate limit expiry, needed to compute the key in some implementations
  173. :return: a tuple of (int, float, int, float) with the following information:
  174. - previous window counter
  175. - previous window TTL
  176. - current window counter
  177. - current window TTL
  178. """
  179. raise NotImplementedError
  180. class TimestampedSlidingWindow:
  181. """Helper class for storage that support the sliding window counter, with timestamp based keys."""
  182. @classmethod
  183. def sliding_window_keys(cls, key: str, expiry: int, at: float) -> tuple[str, str]:
  184. """
  185. returns the previous and the current window's keys.
  186. :param key: the key to get the window's keys from
  187. :param expiry: the expiry of the limit item, in seconds
  188. :param at: the timestamp to get the keys from. Default to now, ie ``time.time()``
  189. Returns a tuple with the previous and the current key: (previous, current).
  190. Example:
  191. - key = "mykey"
  192. - expiry = 60
  193. - at = 1738576292.6631825
  194. The return value will be the tuple ``("mykey/28976271", "mykey/28976270")``.
  195. """
  196. return f"{key}/{int((at - expiry) / expiry)}", f"{key}/{int(at / expiry)}"