123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- from __future__ import annotations
- import inspect
- import threading
- import time
- import urllib.parse
- from collections.abc import Iterable
- from math import ceil, floor
- from types import ModuleType
- from limits.errors import ConfigurationError
- from limits.storage.base import (
- SlidingWindowCounterSupport,
- Storage,
- TimestampedSlidingWindow,
- )
- from limits.typing import (
- Any,
- Callable,
- MemcachedClientP,
- P,
- R,
- cast,
- )
- from limits.util import get_dependency
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
    """
    Rate limit storage with memcached as backend.

    Depends on :pypi:`pymemcache`.
    """

    STORAGE_SCHEME = ["memcached"]
    """The storage scheme for memcached"""
    DEPENDENCIES = ["pymemcache"]

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        **options: str | Callable[[], MemcachedClientP],
    ) -> None:
        """
        :param uri: memcached location of the form
         ``memcached://host:port,host:port``,
         ``memcached:///var/tmp/path/to/sock``
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`pymemcache.client.base.PooledClient`
         or :class:`pymemcache.client.hash.HashClient` (if there are more than
         one hosts specified). The keys ``library``, ``cluster_library`` and
         ``client_getter`` are consumed here and are not forwarded.
        :raise ConfigurationError: when :pypi:`pymemcache` is not available
        """
        parsed = urllib.parse.urlparse(uri)
        self.hosts = []

        for loc in parsed.netloc.strip().split(","):
            if not loc:
                continue
            host, port = loc.split(":")
            self.hosts.append((host, int(port)))

        # No network location at all: treat the URI path as the filesystem
        # path of a unix domain socket.
        if parsed.path and not parsed.netloc:
            self.hosts = [parsed.path]  # type: ignore

        self.dependency = self.dependencies["pymemcache"].module
        self.library = str(options.pop("library", "pymemcache.client"))
        self.cluster_library = str(
            options.pop("cluster_library", "pymemcache.client.hash")
        )
        self.client_getter = cast(
            Callable[[ModuleType, list[tuple[str, int]]], MemcachedClientP],
            options.pop("client_getter", self.get_client),
        )
        self.options = options

        # ``get_dependency`` is indexed with ``[0]`` wherever its result is
        # used (see :attr:`storage`), so the module is the first element.
        # Testing the returned sequence itself would always be truthy.
        library_dependency = get_dependency(self.library)
        if not library_dependency or not library_dependency[0]:
            raise ConfigurationError(
                f"memcached prerequisite not available. please install {self.library}"
            )  # pragma: no cover

        # One client per thread, created lazily by :attr:`storage`.
        self.local_storage = threading.local()
        self.local_storage.storage = None
        super().__init__(uri, wrap_exceptions=wrap_exceptions)

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        """
        The exception type exposed by the memcached dependency
        (its ``MemcacheError``).
        """
        return self.dependency.MemcacheError  # type: ignore[no-any-return]

    def get_client(
        self, module: ModuleType, hosts: list[tuple[str, int]], **kwargs: str
    ) -> MemcachedClientP:
        """
        returns a memcached client.

        :param module: the memcached module
        :param hosts: list of memcached hosts
        :param kwargs: forwarded unchanged to the client constructor
        :return: ``module.HashClient`` when more than one host is given,
         otherwise ``module.PooledClient``
        """
        return cast(
            MemcachedClientP,
            (
                module.HashClient(hosts, **kwargs)
                if len(hosts) > 1
                else module.PooledClient(*hosts, **kwargs)
            ),
        )

    def call_memcached_func(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """
        Call ``func`` with the given arguments, dropping the ``noreply``
        keyword argument when the signature of ``func`` cannot accept it.

        :param func: the client function to call
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :return: whatever ``func`` returns
        """
        if "noreply" in kwargs:
            argspec = inspect.getfullargspec(func)
            accepts_noreply = (
                "noreply" in argspec.args
                # a keyword-only ``noreply`` is still a supported parameter
                or "noreply" in argspec.kwonlyargs
                or argspec.varkw is not None
            )
            if not accepts_noreply:
                kwargs.pop("noreply")

        return func(*args, **kwargs)

    @property
    def storage(self) -> MemcachedClientP:
        """
        lazily creates a memcached client instance using a thread local

        :raise ConfigurationError: when the client module cannot be imported
        """
        if not getattr(self.local_storage, "storage", None):
            # Multiple hosts need the cluster (hashing) client module.
            module_name = (
                self.cluster_library if len(self.hosts) > 1 else self.library
            )
            dependency = get_dependency(module_name)[0]

            if not dependency:
                # Report the module that was actually requested.
                raise ConfigurationError(f"Unable to import {module_name}")
            self.local_storage.storage = self.client_getter(
                dependency, self.hosts, **self.options
            )

        return cast(MemcachedClientP, self.local_storage.storage)

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the stored value converted to ``int``
        """
        return int(self.storage.get(key, "0"))

    def get_many(self, keys: Iterable[str]) -> dict[str, Any]:  # type:ignore[explicit-any]
        """
        Return multiple counters at once

        :param keys: the keys to get the counter values for
        """
        return self.storage.get_many(keys)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        self.storage.delete(key)

    def _store_expiration(self, key: str, expiry: float) -> None:
        """
        Record the absolute expiration time of ``key`` under its
        companion expiration key.

        :param key: the counter key the expiration belongs to
        :param expiry: amount in seconds until the counter expires
        """
        self.call_memcached_func(
            self.storage.set,
            self._expiration_key(key),
            expiry + time.time(),
            expire=ceil(expiry),
            noreply=False,
        )

    def incr(
        self,
        key: str,
        expiry: float,
        elastic_expiry: bool = False,
        amount: int = 1,
        set_expiration_key: bool = True,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :param amount: the number to increment by
        :param set_expiration_key: set the expiration key with the expiration
         time if needed. If set to False, the key will still expire, but
         memcached cannot provide the expiration time.
        :return: the value of the counter after the increment
        """
        value = self.call_memcached_func(self.storage.incr, key, amount, noreply=False)

        if value is None:
            # ``incr`` yielded nothing, so try to create the counter.
            created = self.call_memcached_func(
                self.storage.add, key, amount, ceil(expiry), noreply=False
            )
            if created:
                if set_expiration_key:
                    self._store_expiration(key, expiry)

                return amount

            # ``add`` did not store the value (the key exists by now),
            # so increment it instead.
            value = self.storage.incr(key, amount) or amount

        if elastic_expiry:
            # Every hit pushes the expiry of an existing counter forward.
            self.call_memcached_func(self.storage.touch, key, ceil(expiry))
            if set_expiration_key:
                self._store_expiration(key, expiry)

        return value

    def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        :return: the value stored under the expiration key of ``key``, or the
         current time when nothing is stored there
        """
        return float(self.storage.get(self._expiration_key(key)) or time.time())

    def _expiration_key(self, key: str) -> str:
        """
        Return the expiration key for the given counter key.

        Memcached doesn't natively return the expiration time or TTL for a given key,
        so we implement the expiration time on a separate key.
        """
        return key + "/expires"

    def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``

        :return: ``True`` when the call succeeds, ``False`` when it raises
        """
        try:
            self.call_memcached_func(self.storage.get, "limiter-check")

            return True
        except Exception:  # noqa
            # Any failure means "unhealthy"; interpreter-exit signals such as
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            return False

    def reset(self) -> int | None:
        """
        Not implemented for this storage.

        :raise NotImplementedError: always
        """
        raise NotImplementedError

    def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        Try to record ``amount`` hits against the sliding window of ``key``.

        The hit count is estimated as the previous window's count weighted by
        the fraction of ``expiry`` it still has to live, plus the current
        window's count.

        :param key: the rate limit key
        :param limit: the maximum number of hits allowed in the window
        :param expiry: the window length in seconds
        :param amount: the number of hits to record
        :return: ``True`` if the hits were recorded, ``False`` if recording
         them would exceed ``limit``
        """
        if amount > limit:
            return False

        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
            previous_key, current_key, expiry, now=now
        )
        weighted_count = previous_count * previous_ttl / expiry + current_count

        if floor(weighted_count) + amount > limit:
            return False

        # Hit, increase the current counter.
        # If the counter doesn't exist yet, set twice the theoretical expiry.
        # We don't need the expiration key as it is estimated with the timestamps directly.
        current_count = self.incr(
            current_key, 2 * expiry, amount=amount, set_expiration_key=False
        )
        # Account for the time spent since ``now`` was sampled. The remaining
        # TTL of the previous window can shrink to zero but never below it.
        actualised_previous_ttl = max(0, previous_ttl - (time.time() - now))
        weighted_count = (
            previous_count * actualised_previous_ttl / expiry + current_count
        )

        if floor(weighted_count) > limit:
            # Another hit won the race condition: revert the incrementation and refuse this hit
            # Limitation: during high concurrency at the end of the window,
            # the counter is shifted and cannot be decremented, so less requests than expected are allowed.
            self.call_memcached_func(
                self.storage.decr,
                current_key,
                amount,
                noreply=True,
            )

            return False

        return True

    def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Return the sliding window state of ``key`` at the current time.

        :param key: the rate limit key
        :param expiry: the window length in seconds
        :return: a tuple of (previous window count, previous window TTL,
         current window count, current window TTL)
        """
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)

        return self._get_sliding_window_info(previous_key, current_key, expiry, now)

    def _get_sliding_window_info(
        self, previous_key: str, current_key: str, expiry: int, now: float
    ) -> tuple[int, float, int, float]:
        """
        Fetch both window counters in one call and derive their TTLs from
        ``now`` and ``expiry`` (no TTL is read from memcached).

        :param previous_key: the key of the previous window's counter
        :param current_key: the key of the current window's counter
        :param expiry: the window length in seconds
        :param now: the timestamp to compute the TTLs against
        :return: a tuple of (previous window count, previous window TTL,
         current window count, current window TTL); missing counters count
         as ``0`` and the previous TTL is ``0`` when its count is ``0``
        """
        result = self.get_many([previous_key, current_key])
        previous_count = int(result.get(previous_key, 0))
        current_count = int(result.get(current_key, 0))

        if previous_count == 0:
            previous_ttl = float(0)
        else:
            previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry

        # Time left in the current window plus one extra ``expiry``.
        current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry

        return previous_count, previous_ttl, current_count, current_ttl
|