123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337 |
- """
- Asynchronous rate limiting strategies
- """
- from __future__ import annotations
- import time
- from abc import ABC, abstractmethod
- from math import floor, inf
- from deprecated.sphinx import deprecated, versionadded
- from ..limits import RateLimitItem
- from ..storage import StorageTypes
- from ..typing import cast
- from ..util import WindowStats
- from .storage import MovingWindowSupport, Storage
- from .storage.base import SlidingWindowCounterSupport
class RateLimiter(ABC):
    """
    Abstract base class for the asynchronous rate limiting strategies.

    Subclasses implement :meth:`hit`, :meth:`test` and
    :meth:`get_window_stats` on top of :attr:`storage`; :meth:`clear` is
    shared by all of them.
    """

    def __init__(self, storage: StorageTypes):
        # Only async ``Storage`` instances are accepted; the assert also
        # narrows the ``StorageTypes`` union for type checkers.
        assert isinstance(storage, Storage)
        self.storage: Storage = storage

    @abstractmethod
    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        :return: ``True`` if the hit fits within the limit
        """
        raise NotImplementedError

    @abstractmethod
    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed, without consuming it

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The expected cost to be consumed, default 1
        :return: ``True`` if a hit of ``cost`` would currently be accepted
        """
        raise NotImplementedError

    @abstractmethod
    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining)
        """
        raise NotImplementedError

    async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
        """
        Clear the stored state of the limit identified by ``item`` and
        ``identifiers``.

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        """
        storage_key = item.key_for(*identifiers)
        return await self.storage.clear(storage_key)
class MovingWindowRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:moving window`

    The storage must provide both ``acquire_entry`` (used by :meth:`hit`)
    and ``get_moving_window`` (used by :meth:`test` and
    :meth:`get_window_stats`).
    """

    def __init__(self, storage: StorageTypes) -> None:
        # Require *both* operations up front. Accepting a storage that has
        # only one of them would defer the failure to an AttributeError the
        # first time the missing operation is used. This matches the check
        # done by SlidingWindowCounterRateLimiter.
        if not all(
            hasattr(storage, method)
            for method in ("acquire_entry", "get_moving_window")
        ):
            raise NotImplementedError(
                "MovingWindowRateLimiting is not implemented for storage "
                f"of type {storage.__class__}"
            )
        super().__init__(storage)

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        :return: the result of the storage's ``acquire_entry``
        """
        return await cast(MovingWindowSupport, self.storage).acquire_entry(
            item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
        )

    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The expected cost to be consumed, default 1
        :return: ``True`` if adding ``cost`` to the entries currently in the
         moving window stays within ``item.amount``
        """
        _, window_items = await cast(
            MovingWindowSupport, self.storage
        ).get_moving_window(
            item.key_for(*identifiers),
            item.amount,
            item.get_expiry(),
        )
        return window_items <= item.amount - cost

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit.

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining). Reset time is the window start plus
         the item's expiry. Remaining is never negative.
        """
        expiry = item.get_expiry()
        window_start, window_items = await cast(
            MovingWindowSupport, self.storage
        ).get_moving_window(item.key_for(*identifiers), item.amount, expiry)
        # Clamp at zero, as the fixed and sliding window strategies do, so a
        # window holding more entries than the limit never reports a
        # negative remaining count.
        remaining = max(0, item.amount - window_items)
        return WindowStats(window_start + expiry, remaining)
class FixedWindowRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:fixed window`

    Keeps one counter per window in storage, incremented by each hit.
    """

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        :return: ``True`` if the counter, after being incremented by ``cost``,
         is still within ``item.amount``
        """
        # The counter is incremented even when the hit ends up exceeding the
        # limit; only the comparison decides acceptance.
        counter = await self.storage.incr(
            item.key_for(*identifiers),
            item.get_expiry(),
            elastic_expiry=False,
            amount=cost,
        )
        return counter <= item.amount

    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The expected cost to be consumed, default 1
        :return: ``True`` if ``cost`` more hits would fit in the current window
        """
        consumed = await self.storage.get(item.key_for(*identifiers))
        return consumed < item.amount - cost + 1

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit

        :param item: the rate limit item
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :return: (reset time, remaining), with remaining clamped at zero
        """
        key = item.key_for(*identifiers)
        consumed = await self.storage.get(key)
        remaining = max(0, item.amount - consumed)
        reset = await self.storage.get_expiry(key)
        return WindowStats(reset, remaining)
@versionadded(version="4.1")
class SlidingWindowCounterRateLimiter(RateLimiter):
    """
    Reference: :ref:`strategies:sliding window counter`

    Approximates a moving window using the counts of the previous and the
    current fixed windows.
    """

    def __init__(self, storage: StorageTypes):
        required_methods = ("get_sliding_window", "acquire_sliding_window_entry")
        if not all(hasattr(storage, name) for name in required_methods):
            raise NotImplementedError(
                "SlidingWindowCounterRateLimiting is not implemented for storage "
                f"of type {storage.__class__}"
            )
        super().__init__(storage)

    def _weighted_count(
        self,
        item: RateLimitItem,
        previous_count: int,
        previous_expires_in: float,
        current_count: int,
    ) -> float:
        """
        Return the approximate number of hits in the sliding window.

        The previous window's count is weighted by
        ``previous_expires_in / item.get_expiry()``. The current window's
        count is then added unweighted.
        """
        return current_count + previous_count * previous_expires_in / item.get_expiry()

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: The rate limit item
        :param identifiers: variable list of strings to uniquely identify this
         instance of the limit
        :param cost: The cost of this hit, default 1
        :return: the result of the storage's ``acquire_sliding_window_entry``
        """
        sliding_storage = cast(SlidingWindowCounterSupport, self.storage)
        return await sliding_storage.acquire_sliding_window_entry(
            item.key_for(*identifiers), item.amount, item.get_expiry(), cost
        )

    async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Check if the rate limit can be consumed

        :param item: The rate limit item
        :param identifiers: variable list of strings to uniquely identify this
         instance of the limit
        :param cost: The expected cost to be consumed, default 1
        :return: ``True`` if the weighted count leaves room for ``cost`` hits
        """
        sliding_storage = cast(SlidingWindowCounterSupport, self.storage)
        (
            previous_count,
            previous_expires_in,
            current_count,
            _,
        ) = await sliding_storage.get_sliding_window(
            item.key_for(*identifiers), item.get_expiry()
        )
        weighted = self._weighted_count(
            item, previous_count, previous_expires_in, current_count
        )
        return weighted < item.amount - cost + 1

    async def get_window_stats(
        self, item: RateLimitItem, *identifiers: str
    ) -> WindowStats:
        """
        Query the reset time and remaining amount for the limit.

        :param item: The rate limit item
        :param identifiers: variable list of strings to uniquely identify this
         instance of the limit
        :return: (reset time, remaining). When both windows are empty the
         reset time is "now".
        """
        sliding_storage = cast(SlidingWindowCounterSupport, self.storage)
        window_info = await sliding_storage.get_sliding_window(
            item.key_for(*identifiers), item.get_expiry()
        )
        previous_count, previous_expires_in, current_count, current_expires_in = (
            window_info
        )
        weighted = self._weighted_count(
            item, previous_count, previous_expires_in, current_count
        )
        remaining = max(0, item.amount - floor(weighted))

        now = time.time()
        if not previous_count and not current_count:
            # Neither window holds any hits, so nothing is pending a reset.
            return WindowStats(now, remaining)

        expiry = item.get_expiry()
        # The previous window's weighted contribution is
        # previous_count * previous_expires_in / expiry. It drops by one for
        # every ``expiry / previous_count`` seconds of previous_expires_in
        # (presumably a TTL that shrinks over time).
        previous_reset_in = (
            previous_expires_in % (expiry / previous_count) if previous_count else inf
        )
        current_reset_in = current_expires_in % expiry if current_count else inf
        return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
@deprecated(version="4.1")
class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
    """
    Reference: :ref:`strategies:fixed window with elastic expiry`

    Identical to :class:`FixedWindowRateLimiter` except that :meth:`hit`
    asks the storage for elastic expiry. Presumably each hit then extends
    the window's expiry; confirm against the storage's ``incr``.
    """

    async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
        """
        Consume the rate limit

        :param item: a :class:`limits.limits.RateLimitItem` instance
        :param identifiers: variable list of strings to uniquely identify the
         limit
        :param cost: The cost of this hit, default 1
        :return: ``True`` if the counter, after being incremented by ``cost``,
         is still within ``item.amount``
        """
        key = item.key_for(*identifiers)
        counter = await self.storage.incr(
            key, item.get_expiry(), elastic_expiry=True, amount=cost
        )
        return counter <= item.amount
# Registry of the asynchronous strategies, keyed by their public names.
STRATEGIES: dict[str, type[RateLimiter]] = {
    "sliding-window-counter": SlidingWindowCounterRateLimiter,
    "fixed-window": FixedWindowRateLimiter,
    "fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
    "moving-window": MovingWindowRateLimiter,
}
|