strategies.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. """
  2. Asynchronous rate limiting strategies
  3. """
  4. from __future__ import annotations
  5. import time
  6. from abc import ABC, abstractmethod
  7. from math import floor, inf
  8. from deprecated.sphinx import deprecated, versionadded
  9. from ..limits import RateLimitItem
  10. from ..storage import StorageTypes
  11. from ..typing import cast
  12. from ..util import WindowStats
  13. from .storage import MovingWindowSupport, Storage
  14. from .storage.base import SlidingWindowCounterSupport
  15. class RateLimiter(ABC):
  16. def __init__(self, storage: StorageTypes):
  17. assert isinstance(storage, Storage)
  18. self.storage: Storage = storage
  19. @abstractmethod
  20. async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  21. """
  22. Consume the rate limit
  23. :param item: the rate limit item
  24. :param identifiers: variable list of strings to uniquely identify the
  25. limit
  26. :param cost: The cost of this hit, default 1
  27. """
  28. raise NotImplementedError
  29. @abstractmethod
  30. async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  31. """
  32. Check if the rate limit can be consumed
  33. :param item: the rate limit item
  34. :param identifiers: variable list of strings to uniquely identify the
  35. limit
  36. :param cost: The expected cost to be consumed, default 1
  37. """
  38. raise NotImplementedError
  39. @abstractmethod
  40. async def get_window_stats(
  41. self, item: RateLimitItem, *identifiers: str
  42. ) -> WindowStats:
  43. """
  44. Query the reset time and remaining amount for the limit
  45. :param item: the rate limit item
  46. :param identifiers: variable list of strings to uniquely identify the
  47. limit
  48. :return: (reset time, remaining))
  49. """
  50. raise NotImplementedError
  51. async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
  52. return await self.storage.clear(item.key_for(*identifiers))
  53. class MovingWindowRateLimiter(RateLimiter):
  54. """
  55. Reference: :ref:`strategies:moving window`
  56. """
  57. def __init__(self, storage: StorageTypes) -> None:
  58. if not (
  59. hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
  60. ):
  61. raise NotImplementedError(
  62. "MovingWindowRateLimiting is not implemented for storage "
  63. f"of type {storage.__class__}"
  64. )
  65. super().__init__(storage)
  66. async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  67. """
  68. Consume the rate limit
  69. :param item: the rate limit item
  70. :param identifiers: variable list of strings to uniquely identify the
  71. limit
  72. :param cost: The cost of this hit, default 1
  73. """
  74. return await cast(MovingWindowSupport, self.storage).acquire_entry(
  75. item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
  76. )
  77. async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  78. """
  79. Check if the rate limit can be consumed
  80. :param item: the rate limit item
  81. :param identifiers: variable list of strings to uniquely identify the
  82. limit
  83. :param cost: The expected cost to be consumed, default 1
  84. """
  85. res = await cast(MovingWindowSupport, self.storage).get_moving_window(
  86. item.key_for(*identifiers),
  87. item.amount,
  88. item.get_expiry(),
  89. )
  90. amount = res[1]
  91. return amount <= item.amount - cost
  92. async def get_window_stats(
  93. self, item: RateLimitItem, *identifiers: str
  94. ) -> WindowStats:
  95. """
  96. returns the number of requests remaining within this limit.
  97. :param item: the rate limit item
  98. :param identifiers: variable list of strings to uniquely identify the
  99. limit
  100. :return: (reset time, remaining)
  101. """
  102. window_start, window_items = await cast(
  103. MovingWindowSupport, self.storage
  104. ).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
  105. reset = window_start + item.get_expiry()
  106. return WindowStats(reset, item.amount - window_items)
  107. class FixedWindowRateLimiter(RateLimiter):
  108. """
  109. Reference: :ref:`strategies:fixed window`
  110. """
  111. async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  112. """
  113. Consume the rate limit
  114. :param item: the rate limit item
  115. :param identifiers: variable list of strings to uniquely identify the
  116. limit
  117. :param cost: The cost of this hit, default 1
  118. """
  119. return (
  120. await self.storage.incr(
  121. item.key_for(*identifiers),
  122. item.get_expiry(),
  123. elastic_expiry=False,
  124. amount=cost,
  125. )
  126. <= item.amount
  127. )
  128. async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  129. """
  130. Check if the rate limit can be consumed
  131. :param item: the rate limit item
  132. :param identifiers: variable list of strings to uniquely identify the
  133. limit
  134. :param cost: The expected cost to be consumed, default 1
  135. """
  136. return (
  137. await self.storage.get(item.key_for(*identifiers)) < item.amount - cost + 1
  138. )
  139. async def get_window_stats(
  140. self, item: RateLimitItem, *identifiers: str
  141. ) -> WindowStats:
  142. """
  143. Query the reset time and remaining amount for the limit
  144. :param item: the rate limit item
  145. :param identifiers: variable list of strings to uniquely identify the
  146. limit
  147. :return: reset time, remaining
  148. """
  149. remaining = max(
  150. 0,
  151. item.amount - await self.storage.get(item.key_for(*identifiers)),
  152. )
  153. reset = await self.storage.get_expiry(item.key_for(*identifiers))
  154. return WindowStats(reset, remaining)
  155. @versionadded(version="4.1")
  156. class SlidingWindowCounterRateLimiter(RateLimiter):
  157. """
  158. Reference: :ref:`strategies:sliding window counter`
  159. """
  160. def __init__(self, storage: StorageTypes):
  161. if not hasattr(storage, "get_sliding_window") or not hasattr(
  162. storage, "acquire_sliding_window_entry"
  163. ):
  164. raise NotImplementedError(
  165. "SlidingWindowCounterRateLimiting is not implemented for storage "
  166. f"of type {storage.__class__}"
  167. )
  168. super().__init__(storage)
  169. def _weighted_count(
  170. self,
  171. item: RateLimitItem,
  172. previous_count: int,
  173. previous_expires_in: float,
  174. current_count: int,
  175. ) -> float:
  176. """
  177. Return the approximated by weighting the previous window count and adding the current window count.
  178. """
  179. return previous_count * previous_expires_in / item.get_expiry() + current_count
  180. async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  181. """
  182. Consume the rate limit
  183. :param item: The rate limit item
  184. :param identifiers: variable list of strings to uniquely identify this
  185. instance of the limit
  186. :param cost: The cost of this hit, default 1
  187. """
  188. return await cast(
  189. SlidingWindowCounterSupport, self.storage
  190. ).acquire_sliding_window_entry(
  191. item.key_for(*identifiers),
  192. item.amount,
  193. item.get_expiry(),
  194. cost,
  195. )
  196. async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  197. """
  198. Check if the rate limit can be consumed
  199. :param item: The rate limit item
  200. :param identifiers: variable list of strings to uniquely identify this
  201. instance of the limit
  202. :param cost: The expected cost to be consumed, default 1
  203. """
  204. previous_count, previous_expires_in, current_count, _ = await cast(
  205. SlidingWindowCounterSupport, self.storage
  206. ).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
  207. return (
  208. self._weighted_count(
  209. item, previous_count, previous_expires_in, current_count
  210. )
  211. < item.amount - cost + 1
  212. )
  213. async def get_window_stats(
  214. self, item: RateLimitItem, *identifiers: str
  215. ) -> WindowStats:
  216. """
  217. Query the reset time and remaining amount for the limit.
  218. :param item: The rate limit item
  219. :param identifiers: variable list of strings to uniquely identify this
  220. instance of the limit
  221. :return: (reset time, remaining)
  222. """
  223. (
  224. previous_count,
  225. previous_expires_in,
  226. current_count,
  227. current_expires_in,
  228. ) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
  229. item.key_for(*identifiers), item.get_expiry()
  230. )
  231. remaining = max(
  232. 0,
  233. item.amount
  234. - floor(
  235. self._weighted_count(
  236. item, previous_count, previous_expires_in, current_count
  237. )
  238. ),
  239. )
  240. now = time.time()
  241. if not (previous_count or current_count):
  242. return WindowStats(now, remaining)
  243. expiry = item.get_expiry()
  244. previous_reset_in, current_reset_in = inf, inf
  245. if previous_count:
  246. previous_reset_in = previous_expires_in % (expiry / previous_count)
  247. if current_count:
  248. current_reset_in = current_expires_in % expiry
  249. return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
  250. @deprecated(version="4.1")
  251. class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
  252. """
  253. Reference: :ref:`strategies:fixed window with elastic expiry`
  254. """
  255. async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
  256. """
  257. Consume the rate limit
  258. :param item: a :class:`limits.limits.RateLimitItem` instance
  259. :param identifiers: variable list of strings to uniquely identify the
  260. limit
  261. :param cost: The cost of this hit, default 1
  262. """
  263. amount = await self.storage.incr(
  264. item.key_for(*identifiers),
  265. item.get_expiry(),
  266. elastic_expiry=True,
  267. amount=cost,
  268. )
  269. return amount <= item.amount
  270. STRATEGIES = {
  271. "sliding-window-counter": SlidingWindowCounterRateLimiter,
  272. "fixed-window": FixedWindowRateLimiter,
  273. "fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
  274. "moving-window": MovingWindowRateLimiter,
  275. }