123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522 |
- from __future__ import annotations
- import asyncio
- import datetime
- import time
- from deprecated.sphinx import versionadded, versionchanged
- from limits.aio.storage.base import (
- MovingWindowSupport,
- SlidingWindowCounterSupport,
- Storage,
- )
- from limits.typing import (
- ParamSpec,
- TypeVar,
- cast,
- )
- from limits.util import get_dependency
- P = ParamSpec("P")
- R = TypeVar("R")
@versionadded(version="2.1")
@versionchanged(
    version="3.14.0",
    reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
    """
    Rate limit storage with MongoDB as backend.

    Depends on :pypi:`motor`
    """

    STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
    """
    The storage scheme for MongoDB for use in an async context
    """

    DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]

    def __init__(
        self,
        uri: str,
        database_name: str = "limits",
        counter_collection_name: str = "counters",
        window_collection_name: str = "windows",
        wrap_exceptions: bool = False,
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
         This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :param database_name: The database to use for storing the rate limit
         collections.
        :param counter_collection_name: The collection name to use for individual counters
         used in fixed window strategies
        :param window_collection_name: The collection name to use for sliding & moving window
         storage
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
        :raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
         not available
        """
        # Strip the ``async+`` prefix so the driver receives a plain MongoDB uri.
        uri = uri.replace("async+mongodb", "mongodb", 1)

        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)

        self.dependency = self.dependencies["motor.motor_asyncio"]
        self.proxy_dependency = self.dependencies["pymongo"]
        self.lib_errors, _ = get_dependency("pymongo.errors")

        self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)

        # TODO: Fix this hack. It was noticed when running a benchmark
        # with FastAPI - however - doesn't appear in unit tests or in an isolated
        # use. Reference: https://jira.mongodb.org/browse/MOTOR-822
        self.storage.get_io_loop = asyncio.get_running_loop

        self.__database_name = database_name
        self.__collection_mapping = {
            "counters": counter_collection_name,
            "windows": window_collection_name,
        }
        # Set once the TTL indices exist; cleared again by ``reset`` because
        # dropping a collection also drops its indices.
        self.__indices_created = False

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        return self.lib_errors.PyMongoError  # type: ignore

    @property
    def database(self):  # type: ignore
        return self.storage.get_database(self.__database_name)

    @staticmethod
    def _utc_timestamp(value: datetime.datetime) -> float:
        """
        Convert a datetime read back from MongoDB into a POSIX timestamp.

        PyMongo returns naive datetimes holding UTC by default, so a value
        without ``tzinfo`` is interpreted as UTC. Aware values (e.g. from a
        client created with ``tz_aware=True``) are converted as they are rather
        than having their timezone overwritten.
        """
        if value.tzinfo is None:
            value = value.replace(tzinfo=datetime.timezone.utc)
        return value.timestamp()

    async def create_indices(self) -> None:
        """
        Create TTL indices on ``expireAt`` for the counter and window collections.

        ``expireAfterSeconds=0`` lets MongoDB remove a document once its
        ``expireAt`` time has passed. Skipped after the indices have been
        created, until :meth:`reset` drops the collections.
        """
        if not self.__indices_created:
            await asyncio.gather(
                self.database[self.__collection_mapping["counters"]].create_index(
                    "expireAt", expireAfterSeconds=0
                ),
                self.database[self.__collection_mapping["windows"]].create_index(
                    "expireAt", expireAfterSeconds=0
                ),
            )
            self.__indices_created = True

    async def reset(self) -> int | None:
        """
        Delete all rate limit keys in the rate limit collections (counters, windows)

        :return: the number of documents present in both collections before
         they were dropped
        """
        counters = self.database[self.__collection_mapping["counters"]]
        windows = self.database[self.__collection_mapping["windows"]]

        num_keys = sum(
            await asyncio.gather(
                counters.count_documents({}),
                windows.count_documents({}),
            )
        )
        await asyncio.gather(counters.drop(), windows.drop())
        # Dropping a collection removes its indices too; force the TTL
        # indices to be recreated on the next write.
        self.__indices_created = False

        return cast(int, num_keys)

    async def clear(self, key: str) -> None:
        """
        Remove ``key`` from both the counter and the window collections.

        :param key: the key to clear rate limits for
        """
        await asyncio.gather(
            self.database[self.__collection_mapping["counters"]].find_one_and_delete(
                {"_id": key}
            ),
            self.database[self.__collection_mapping["windows"]].find_one_and_delete(
                {"_id": key}
            ),
        )

    async def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        :return: the counter's ``expireAt`` as a POSIX timestamp, or the
         current time if there is no counter for ``key``
        """
        counter = await self.database[self.__collection_mapping["counters"]].find_one(
            {"_id": key}
        )
        if counter is None:
            # A naive ``datetime.now()`` is local wall-clock time and must not
            # be labelled as UTC; use an unambiguous epoch timestamp instead.
            return time.time()
        return self._utc_timestamp(counter["expireAt"])

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the current count, or 0 if the counter is missing or expired
        """
        counter = await self.database[self.__collection_mapping["counters"]].find_one(
            {
                "_id": key,
                "expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
            },
            projection=["count"],
        )
        return counter["count"] if counter else 0

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :param amount: the number to increment by
        :return: the counter value after incrementing
        """
        await self.create_indices()

        expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
            seconds=expiry
        )
        # True when the stored counter has expired; also true for a fresh
        # upsert, where the missing ``expireAt`` compares as null (below any date).
        expired = {"$lt": ["$expireAt", "$$NOW"]}

        response = await self.database[
            self.__collection_mapping["counters"]
        ].find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        # Restart the count on expiry, otherwise accumulate.
                        "count": {
                            "$cond": {
                                "if": expired,
                                "then": amount,
                                "else": {"$add": ["$count", amount]},
                            }
                        },
                        # Start a new window on expiry; with elastic expiry,
                        # every hit pushes the window end forward.
                        "expireAt": {
                            "$cond": {
                                "if": expired,
                                "then": expiration,
                                "else": (expiration if elastic_expiry else "$expireAt"),
                            }
                        },
                    }
                },
            ],
            upsert=True,
            projection=["count"],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
        )

        return int(response["count"])

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling
        :meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`

        :return: ``True`` if the server responded, ``False`` on any error
        """
        try:
            await self.storage.server_info()
            return True
        except Exception:
            # Any driver/connection error means "unhealthy". BaseExceptions such
            # as asyncio.CancelledError are deliberately allowed to propagate.
            return False

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param str key: rate limit key
        :param int limit: amount of entries allowed (not used by this query)
        :param int expiry: expiry of entry
        :return: (start of window, number of acquired entries)
        """
        timestamp = time.time()
        if result := (
            await self.database[self.__collection_mapping["windows"]]
            .aggregate(
                [
                    {"$match": {"_id": key}},
                    # Keep only entries recorded within the last ``expiry`` seconds.
                    {
                        "$project": {
                            "entries": {
                                "$filter": {
                                    "input": "$entries",
                                    "as": "entry",
                                    "cond": {"$gte": ["$$entry", timestamp - expiry]},
                                }
                            }
                        }
                    },
                    {"$unwind": "$entries"},
                    # Oldest surviving entry and the number of surviving entries.
                    {
                        "$group": {
                            "_id": "$_id",
                            "min": {"$min": "$entries"},
                            "count": {"$sum": 1},
                        }
                    },
                ]
            )
            .to_list(length=1)
        ):
            return result[0]["min"], result[0]["count"]
        return timestamp, 0

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: ``True`` if the entries were acquired, ``False`` otherwise
        """
        await self.create_indices()

        if amount > limit:
            return False

        timestamp = time.time()
        try:
            updates: dict[
                str,
                dict[str, datetime.datetime | dict[str, list[float] | int]],
            ] = {
                # Prepend the new timestamps (so entries stay newest-first) and
                # cap the list at ``limit`` entries.
                "$push": {
                    "entries": {
                        "$each": [timestamp] * amount,
                        "$position": 0,
                        "$slice": limit,
                    }
                },
                "$set": {
                    "expireAt": (
                        datetime.datetime.now(datetime.timezone.utc)
                        + datetime.timedelta(seconds=expiry)
                    )
                },
            }

            # The filter only matches when the entry at index ``limit - amount``
            # is absent or older than the window, i.e. there is room for
            # ``amount`` more entries. If an existing document does not match,
            # the upsert attempts to insert a second document with the same
            # ``_id`` and fails with DuplicateKeyError.
            await self.database[self.__collection_mapping["windows"]].update_one(
                {
                    "_id": key,
                    f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
                },
                updates,
                upsert=True,
            )

            return True
        except self.proxy_dependency.module.errors.DuplicateKeyError:
            return False

    async def acquire_sliding_window_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        Try to acquire ``amount`` entries in the sliding window counter for ``key``.

        :param key: rate limit key to acquire entries in
        :param limit: maximum weighted count allowed
        :param expiry: window length in seconds
        :param amount: the number of entries to acquire
        :return: ``True`` if the entries were acquired, ``False`` otherwise
        """
        # NOTE(review): sliding window documents store ``expiresAt`` while the
        # TTL index from create_indices is on ``expireAt``, so the TTL monitor
        # does not remove these documents — confirm whether that is intended.
        await self.create_indices()
        expiry_ms = expiry * 1000

        # ``expiresAt - NOW <= expiry_ms`` means NOW has reached the end of the
        # current window (``expiresAt - expiry_ms``), so the current count must
        # be rolled into ``previousCount``.
        window_elapsed = {
            "$lte": [{"$subtract": ["$expiresAt", "$$NOW"]}, expiry_ms]
        }
        # Milliseconds left in the current window, never negative.
        remaining_ms = {
            "$max": [
                0,
                {"$subtract": ["$expiresAt", {"$add": ["$$NOW", expiry_ms]}]},
            ]
        }
        within_limit = {"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]}

        result = await self.database[
            self.__collection_mapping["windows"]
        ].find_one_and_update(
            {"_id": key},
            [
                # 1. On rollover, the current count becomes the previous count.
                {
                    "$set": {
                        "previousCount": {
                            "$cond": {
                                "if": window_elapsed,
                                "then": {"$ifNull": ["$currentCount", 0]},
                                "else": {"$ifNull": ["$previousCount", 0]},
                            }
                        },
                    }
                },
                # 2. On rollover, reset the current count and advance the window
                #    end by one window; when ``expiresAt`` does not compare
                #    greater than 0 (e.g. missing on a fresh upsert) it is
                #    initialised to NOW + 2 windows.
                {
                    "$set": {
                        "currentCount": {
                            "$cond": {
                                "if": window_elapsed,
                                "then": 0,
                                "else": {"$ifNull": ["$currentCount", 0]},
                            }
                        },
                        "expiresAt": {
                            "$cond": {
                                "if": window_elapsed,
                                "then": {
                                    "$cond": {
                                        "if": {"$gt": ["$expiresAt", 0]},
                                        "then": {"$add": ["$expiresAt", expiry_ms]},
                                        "else": {"$add": ["$$NOW", 2 * expiry_ms]},
                                    }
                                },
                                "else": "$expiresAt",
                            }
                        },
                    }
                },
                # 3. Weighted count: previous count scaled by the fraction of
                #    the current window still remaining, plus the current count.
                {
                    "$set": {
                        "curWeightedCount": {
                            "$floor": {
                                "$add": [
                                    {
                                        "$multiply": [
                                            "$previousCount",
                                            {"$divide": [remaining_ms, expiry_ms]},
                                        ]
                                    },
                                    "$currentCount",
                                ]
                            }
                        }
                    }
                },
                # 4. Only add ``amount`` when the result stays within ``limit``.
                {
                    "$set": {
                        "currentCount": {
                            "$cond": {
                                "if": within_limit,
                                "then": {"$add": ["$currentCount", amount]},
                                "else": "$currentCount",
                            }
                        }
                    }
                },
                # 5. Report the decision back to the caller.
                {"$set": {"_acquired": within_limit}},
                {"$unset": ["curWeightedCount"]},
            ],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
            upsert=True,
        )

        return cast(bool, result["_acquired"])

    async def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Read (rolling over first, if due) the sliding window counter for ``key``.

        :param key: rate limit key
        :param expiry: window length in seconds
        :return: (previous count, previous ttl, current count, current ttl);
         all zeros when there is no document for ``key``
        """
        expiry_ms = expiry * 1000
        # Same rollover condition as in ``acquire_sliding_window_entry``.
        window_elapsed = {
            "$lte": [{"$subtract": ["$expiresAt", "$$NOW"]}, expiry_ms]
        }

        if result := await self.database[
            self.__collection_mapping["windows"]
        ].find_one_and_update(
            {"_id": key},
            [
                {
                    "$set": {
                        "previousCount": {
                            "$cond": {
                                "if": window_elapsed,
                                "then": {"$ifNull": ["$currentCount", 0]},
                                "else": {"$ifNull": ["$previousCount", 0]},
                            }
                        },
                        "currentCount": {
                            "$cond": {
                                "if": window_elapsed,
                                "then": 0,
                                "else": {"$ifNull": ["$currentCount", 0]},
                            }
                        },
                        "expiresAt": {
                            "$cond": {
                                "if": window_elapsed,
                                "then": {"$add": ["$expiresAt", expiry_ms]},
                                "else": "$expiresAt",
                            }
                        },
                    }
                }
            ],
            return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
            projection=["currentCount", "previousCount", "expiresAt"],
        ):
            expires_at = (
                self._utc_timestamp(result["expiresAt"])
                if result.get("expiresAt")
                else time.time()
            )
            current_ttl = max(0, expires_at - time.time())
            prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)

            return (
                result["previousCount"],
                prev_ttl,
                result["currentCount"],
                current_ttl,
            )
        return 0, 0.0, 0, 0.0

    def __del__(self) -> None:
        # ``storage`` is missing if __init__ raised before the client was
        # created (e.g. a missing dependency); avoid an AttributeError here.
        storage = getattr(self, "storage", None)
        if storage is not None:
            storage.close()
|