memory.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from __future__ import annotations
  2. import threading
  3. import time
  4. from collections import Counter, defaultdict
  5. from math import floor
  6. import limits.typing
  7. from limits.storage.base import (
  8. MovingWindowSupport,
  9. SlidingWindowCounterSupport,
  10. Storage,
  11. TimestampedSlidingWindow,
  12. )
  13. class Entry:
  14. def __init__(self, expiry: float) -> None:
  15. self.atime = time.time()
  16. self.expiry = self.atime + expiry
  17. class MemoryStorage(
  18. Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
  19. ):
  20. """
  21. rate limit storage using :class:`collections.Counter`
  22. as an in memory storage for fixed and elastic window strategies,
  23. and a simple list to implement moving window strategy.
  24. """
  25. STORAGE_SCHEME = ["memory"]
  26. def __init__(self, uri: str | None = None, wrap_exceptions: bool = False, **_: str):
  27. self.storage: limits.typing.Counter[str] = Counter()
  28. self.locks: defaultdict[str, threading.RLock] = defaultdict(threading.RLock)
  29. self.expirations: dict[str, float] = {}
  30. self.events: dict[str, list[Entry]] = {}
  31. self.timer: threading.Timer = threading.Timer(0.01, self.__expire_events)
  32. self.timer.start()
  33. super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
  34. def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
  35. state = self.__dict__.copy()
  36. del state["timer"]
  37. del state["locks"]
  38. return state
  39. def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
  40. self.__dict__.update(state)
  41. self.locks = defaultdict(threading.RLock)
  42. self.timer = threading.Timer(0.01, self.__expire_events)
  43. self.timer.start()
  44. def __expire_events(self) -> None:
  45. for key in list(self.events.keys()):
  46. with self.locks[key]:
  47. for event in list(self.events[key]):
  48. if event.expiry <= time.time() and event in self.events[key]:
  49. self.events[key].remove(event)
  50. if not self.events.get(key, None):
  51. self.locks.pop(key, None)
  52. for key in list(self.expirations.keys()):
  53. if self.expirations[key] <= time.time():
  54. self.storage.pop(key, None)
  55. self.expirations.pop(key, None)
  56. self.locks.pop(key, None)
  57. def __schedule_expiry(self) -> None:
  58. if not self.timer.is_alive():
  59. self.timer = threading.Timer(0.01, self.__expire_events)
  60. self.timer.start()
  61. @property
  62. def base_exceptions(
  63. self,
  64. ) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
  65. return ValueError
  66. def incr(
  67. self, key: str, expiry: float, elastic_expiry: bool = False, amount: int = 1
  68. ) -> int:
  69. """
  70. increments the counter for a given rate limit key
  71. :param key: the key to increment
  72. :param expiry: amount in seconds for the key to expire in
  73. :param elastic_expiry: whether to keep extending the rate limit
  74. window every hit.
  75. :param amount: the number to increment by
  76. """
  77. self.get(key)
  78. self.__schedule_expiry()
  79. with self.locks[key]:
  80. self.storage[key] += amount
  81. if elastic_expiry or self.storage[key] == amount:
  82. self.expirations[key] = time.time() + expiry
  83. return self.storage.get(key, 0)
  84. def decr(self, key: str, amount: int = 1) -> int:
  85. """
  86. decrements the counter for a given rate limit key
  87. :param key: the key to decrement
  88. :param amount: the number to decrement by
  89. """
  90. self.get(key)
  91. self.__schedule_expiry()
  92. with self.locks[key]:
  93. self.storage[key] = max(self.storage[key] - amount, 0)
  94. return self.storage.get(key, 0)
  95. def get(self, key: str) -> int:
  96. """
  97. :param key: the key to get the counter value for
  98. """
  99. if self.expirations.get(key, 0) <= time.time():
  100. self.storage.pop(key, None)
  101. self.expirations.pop(key, None)
  102. self.locks.pop(key, None)
  103. return self.storage.get(key, 0)
  104. def clear(self, key: str) -> None:
  105. """
  106. :param key: the key to clear rate limits for
  107. """
  108. self.storage.pop(key, None)
  109. self.expirations.pop(key, None)
  110. self.events.pop(key, None)
  111. self.locks.pop(key, None)
  112. def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
  113. """
  114. :param key: rate limit key to acquire an entry in
  115. :param limit: amount of entries allowed
  116. :param expiry: expiry of the entry
  117. :param amount: the number of entries to acquire
  118. """
  119. if amount > limit:
  120. return False
  121. self.__schedule_expiry()
  122. with self.locks[key]:
  123. self.events.setdefault(key, [])
  124. timestamp = time.time()
  125. try:
  126. entry = self.events[key][limit - amount]
  127. except IndexError:
  128. entry = None
  129. if entry and entry.atime >= timestamp - expiry:
  130. return False
  131. else:
  132. self.events[key][:0] = [Entry(expiry) for _ in range(amount)]
  133. return True
  134. def get_expiry(self, key: str) -> float:
  135. """
  136. :param key: the key to get the expiry for
  137. """
  138. return self.expirations.get(key, time.time())
  139. def get_num_acquired(self, key: str, expiry: int) -> int:
  140. """
  141. returns the number of entries already acquired
  142. :param key: rate limit key to acquire an entry in
  143. :param expiry: expiry of the entry
  144. """
  145. timestamp = time.time()
  146. return (
  147. len([k for k in self.events.get(key, []) if k.atime >= timestamp - expiry])
  148. if self.events.get(key)
  149. else 0
  150. )
  151. def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
  152. """
  153. returns the starting point and the number of entries in the moving
  154. window
  155. :param key: rate limit key
  156. :param expiry: expiry of entry
  157. :return: (start of window, number of acquired entries)
  158. """
  159. timestamp = time.time()
  160. acquired = self.get_num_acquired(key, expiry)
  161. for item in self.events.get(key, [])[::-1]:
  162. if item.atime >= timestamp - expiry:
  163. return item.atime, acquired
  164. return timestamp, acquired
  165. def acquire_sliding_window_entry(
  166. self,
  167. key: str,
  168. limit: int,
  169. expiry: int,
  170. amount: int = 1,
  171. ) -> bool:
  172. if amount > limit:
  173. return False
  174. now = time.time()
  175. previous_key, current_key = self.sliding_window_keys(key, expiry, now)
  176. (
  177. previous_count,
  178. previous_ttl,
  179. current_count,
  180. _,
  181. ) = self._get_sliding_window_info(previous_key, current_key, expiry, now)
  182. weighted_count = previous_count * previous_ttl / expiry + current_count
  183. if floor(weighted_count) + amount > limit:
  184. return False
  185. else:
  186. # Hit, increase the current counter.
  187. # If the counter doesn't exist yet, set twice the theorical expiry.
  188. current_count = self.incr(current_key, 2 * expiry, amount=amount)
  189. weighted_count = previous_count * previous_ttl / expiry + current_count
  190. if floor(weighted_count) > limit:
  191. # Another hit won the race condition: revert the incrementation and refuse this hit
  192. # Limitation: during high concurrency at the end of the window,
  193. # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
  194. self.decr(current_key, amount)
  195. return False
  196. return True
  197. def _get_sliding_window_info(
  198. self,
  199. previous_key: str,
  200. current_key: str,
  201. expiry: int,
  202. now: float,
  203. ) -> tuple[int, float, int, float]:
  204. previous_count = self.get(previous_key)
  205. current_count = self.get(current_key)
  206. if previous_count == 0:
  207. previous_ttl = float(0)
  208. else:
  209. previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
  210. current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
  211. return previous_count, previous_ttl, current_count, current_ttl
  212. def get_sliding_window(
  213. self, key: str, expiry: int
  214. ) -> tuple[int, float, int, float]:
  215. now = time.time()
  216. previous_key, current_key = self.sliding_window_keys(key, expiry, now)
  217. return self._get_sliding_window_info(previous_key, current_key, expiry, now)
  218. def check(self) -> bool:
  219. """
  220. check if storage is healthy
  221. """
  222. return True
  223. def reset(self) -> int | None:
  224. num_items = max(len(self.storage), len(self.events))
  225. self.storage.clear()
  226. self.expirations.clear()
  227. self.events.clear()
  228. self.locks.clear()
  229. return num_items