memcached.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. from __future__ import annotations
  2. import inspect
  3. import threading
  4. import time
  5. import urllib.parse
  6. from collections.abc import Iterable
  7. from math import ceil, floor
  8. from types import ModuleType
  9. from limits.errors import ConfigurationError
  10. from limits.storage.base import (
  11. SlidingWindowCounterSupport,
  12. Storage,
  13. TimestampedSlidingWindow,
  14. )
  15. from limits.typing import (
  16. Any,
  17. Callable,
  18. MemcachedClientP,
  19. P,
  20. R,
  21. cast,
  22. )
  23. from limits.util import get_dependency
  24. class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
  25. """
  26. Rate limit storage with memcached as backend.
  27. Depends on :pypi:`pymemcache`.
  28. """
  29. STORAGE_SCHEME = ["memcached"]
  30. """The storage scheme for memcached"""
  31. DEPENDENCIES = ["pymemcache"]
  32. def __init__(
  33. self,
  34. uri: str,
  35. wrap_exceptions: bool = False,
  36. **options: str | Callable[[], MemcachedClientP],
  37. ) -> None:
  38. """
  39. :param uri: memcached location of the form
  40. ``memcached://host:port,host:port``,
  41. ``memcached:///var/tmp/path/to/sock``
  42. :param wrap_exceptions: Whether to wrap storage exceptions in
  43. :exc:`limits.errors.StorageError` before raising it.
  44. :param options: all remaining keyword arguments are passed
  45. directly to the constructor of :class:`pymemcache.client.base.PooledClient`
  46. or :class:`pymemcache.client.hash.HashClient` (if there are more than
  47. one hosts specified)
  48. :raise ConfigurationError: when :pypi:`pymemcache` is not available
  49. """
  50. parsed = urllib.parse.urlparse(uri)
  51. self.hosts = []
  52. for loc in parsed.netloc.strip().split(","):
  53. if not loc:
  54. continue
  55. host, port = loc.split(":")
  56. self.hosts.append((host, int(port)))
  57. else:
  58. # filesystem path to UDS
  59. if parsed.path and not parsed.netloc and not parsed.port:
  60. self.hosts = [parsed.path] # type: ignore
  61. self.dependency = self.dependencies["pymemcache"].module
  62. self.library = str(options.pop("library", "pymemcache.client"))
  63. self.cluster_library = str(
  64. options.pop("cluster_library", "pymemcache.client.hash")
  65. )
  66. self.client_getter = cast(
  67. Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
  68. options.pop("client_getter", self.get_client),
  69. )
  70. self.options = options
  71. if not get_dependency(self.library):
  72. raise ConfigurationError(
  73. f"memcached prerequisite not available. please install {self.library}"
  74. ) # pragma: no cover
  75. self.local_storage = threading.local()
  76. self.local_storage.storage = None
  77. super().__init__(uri, wrap_exceptions=wrap_exceptions)
  78. @property
  79. def base_exceptions(
  80. self,
  81. ) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
  82. return self.dependency.MemcacheError # type: ignore[no-any-return]
  83. def get_client(
  84. self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
  85. ) -> MemcachedClientP:
  86. """
  87. returns a memcached client.
  88. :param module: the memcached module
  89. :param hosts: list of memcached hosts
  90. """
  91. return cast(
  92. MemcachedClientP,
  93. (
  94. module.HashClient(hosts, **kwargs)
  95. if len(hosts) > 1
  96. else module.PooledClient(*hosts, **kwargs)
  97. ),
  98. )
  99. def call_memcached_func(
  100. self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
  101. ) -> R:
  102. if "noreply" in kwargs:
  103. argspec = inspect.getfullargspec(func)
  104. if not ("noreply" in argspec.args or argspec.varkw):
  105. kwargs.pop("noreply")
  106. return func(*args, **kwargs)
  107. @property
  108. def storage(self) -> MemcachedClientP:
  109. """
  110. lazily creates a memcached client instance using a thread local
  111. """
  112. if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
  113. dependency = get_dependency(
  114. self.cluster_library if len(self.hosts) > 1 else self.library
  115. )[0]
  116. if not dependency:
  117. raise ConfigurationError(f"Unable to import {self.cluster_library}")
  118. self.local_storage.storage = self.client_getter(
  119. dependency, self.hosts, **self.options
  120. )
  121. return cast(MemcachedClientP, self.local_storage.storage)
  122. def get(self, key: str) -> int:
  123. """
  124. :param key: the key to get the counter value for
  125. """
  126. return int(self.storage.get(key, "0"))
  127. def get_many(self, keys: Iterable[str]) -> dict[str, Any]: # type:ignore[explicit-any]
  128. """
  129. Return multiple counters at once
  130. :param keys: the keys to get the counter values for
  131. """
  132. return self.storage.get_many(keys)
  133. def clear(self, key: str) -> None:
  134. """
  135. :param key: the key to clear rate limits for
  136. """
  137. self.storage.delete(key)
  138. def incr(
  139. self,
  140. key: str,
  141. expiry: float,
  142. elastic_expiry: bool = False,
  143. amount: int = 1,
  144. set_expiration_key: bool = True,
  145. ) -> int:
  146. """
  147. increments the counter for a given rate limit key
  148. :param key: the key to increment
  149. :param expiry: amount in seconds for the key to expire in
  150. :param elastic_expiry: whether to keep extending the rate limit
  151. window every hit.
  152. :param amount: the number to increment by
  153. :param set_expiration_key: set the expiration key with the expiration time if needed. If set to False, the key will still expire, but memcached cannot provide the expiration time.
  154. """
  155. value = self.call_memcached_func(self.storage.incr, key, amount, noreply=False)
  156. if value is not None:
  157. if elastic_expiry:
  158. self.call_memcached_func(self.storage.touch, key, ceil(expiry))
  159. if set_expiration_key:
  160. self.call_memcached_func(
  161. self.storage.set,
  162. self._expiration_key(key),
  163. expiry + time.time(),
  164. expire=ceil(expiry),
  165. noreply=False,
  166. )
  167. return value
  168. else:
  169. if not self.call_memcached_func(
  170. self.storage.add, key, amount, ceil(expiry), noreply=False
  171. ):
  172. value = self.storage.incr(key, amount) or amount
  173. if elastic_expiry:
  174. self.call_memcached_func(self.storage.touch, key, ceil(expiry))
  175. if set_expiration_key:
  176. self.call_memcached_func(
  177. self.storage.set,
  178. self._expiration_key(key),
  179. expiry + time.time(),
  180. expire=ceil(expiry),
  181. noreply=False,
  182. )
  183. return value
  184. else:
  185. if set_expiration_key:
  186. self.call_memcached_func(
  187. self.storage.set,
  188. self._expiration_key(key),
  189. expiry + time.time(),
  190. expire=ceil(expiry),
  191. noreply=False,
  192. )
  193. return amount
  194. def get_expiry(self, key: str) -> float:
  195. """
  196. :param key: the key to get the expiry for
  197. """
  198. return float(self.storage.get(self._expiration_key(key)) or time.time())
  199. def _expiration_key(self, key: str) -> str:
  200. """
  201. Return the expiration key for the given counter key.
  202. Memcached doesn't natively return the expiration time or TTL for a given key,
  203. so we implement the expiration time on a separate key.
  204. """
  205. return key + "/expires"
  206. def check(self) -> bool:
  207. """
  208. Check if storage is healthy by calling the ``get`` command
  209. on the key ``limiter-check``
  210. """
  211. try:
  212. self.call_memcached_func(self.storage.get, "limiter-check")
  213. return True
  214. except: # noqa
  215. return False
  216. def reset(self) -> int | None:
  217. raise NotImplementedError
  218. def acquire_sliding_window_entry(
  219. self,
  220. key: str,
  221. limit: int,
  222. expiry: int,
  223. amount: int = 1,
  224. ) -> bool:
  225. if amount > limit:
  226. return False
  227. now = time.time()
  228. previous_key, current_key = self.sliding_window_keys(key, expiry, now)
  229. previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
  230. previous_key, current_key, expiry, now=now
  231. )
  232. weighted_count = previous_count * previous_ttl / expiry + current_count
  233. if floor(weighted_count) + amount > limit:
  234. return False
  235. else:
  236. # Hit, increase the current counter.
  237. # If the counter doesn't exist yet, set twice the theorical expiry.
  238. # We don't need the expiration key as it is estimated with the timestamps directly.
  239. current_count = self.incr(
  240. current_key, 2 * expiry, amount=amount, set_expiration_key=False
  241. )
  242. actualised_previous_ttl = min(0, previous_ttl - (time.time() - now))
  243. weighted_count = (
  244. previous_count * actualised_previous_ttl / expiry + current_count
  245. )
  246. if floor(weighted_count) > limit:
  247. # Another hit won the race condition: revert the incrementation and refuse this hit
  248. # Limitation: during high concurrency at the end of the window,
  249. # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
  250. self.call_memcached_func(
  251. self.storage.decr,
  252. current_key,
  253. amount,
  254. noreply=True,
  255. )
  256. return False
  257. return True
  258. def get_sliding_window(
  259. self, key: str, expiry: int
  260. ) -> tuple[int, float, int, float]:
  261. now = time.time()
  262. previous_key, current_key = self.sliding_window_keys(key, expiry, now)
  263. return self._get_sliding_window_info(previous_key, current_key, expiry, now)
  264. def _get_sliding_window_info(
  265. self, previous_key: str, current_key: str, expiry: int, now: float
  266. ) -> tuple[int, float, int, float]:
  267. result = self.get_many([previous_key, current_key])
  268. previous_count, current_count = (
  269. int(result.get(previous_key, 0)),
  270. int(result.get(current_key, 0)),
  271. )
  272. if previous_count == 0:
  273. previous_ttl = float(0)
  274. else:
  275. previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
  276. current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
  277. return previous_count, previous_ttl, current_count, current_ttl