mongodb.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. from __future__ import annotations
  2. import datetime
  3. import time
  4. from abc import ABC, abstractmethod
  5. from deprecated.sphinx import versionadded, versionchanged
  6. from limits.typing import (
  7. MongoClient,
  8. MongoCollection,
  9. MongoDatabase,
  10. cast,
  11. )
  12. from ..util import get_dependency
  13. from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
  14. class MongoDBStorageBase(
  15. Storage, MovingWindowSupport, SlidingWindowCounterSupport, ABC
  16. ):
  17. """
  18. Rate limit storage with MongoDB as backend.
  19. Depends on :pypi:`pymongo`.
  20. """
  21. DEPENDENCIES = ["pymongo"]
  22. def __init__(
  23. self,
  24. uri: str,
  25. database_name: str = "limits",
  26. counter_collection_name: str = "counters",
  27. window_collection_name: str = "windows",
  28. wrap_exceptions: bool = False,
  29. **options: int | str | bool,
  30. ) -> None:
  31. """
  32. :param uri: uri of the form ``mongodb://[user:password]@host:port?...``,
  33. This uri is passed directly to :class:`~pymongo.mongo_client.MongoClient`
  34. :param database_name: The database to use for storing the rate limit
  35. collections.
  36. :param counter_collection_name: The collection name to use for individual counters
  37. used in fixed window strategies
  38. :param window_collection_name: The collection name to use for sliding & moving window
  39. storage
  40. :param wrap_exceptions: Whether to wrap storage exceptions in
  41. :exc:`limits.errors.StorageError` before raising it.
  42. :param options: all remaining keyword arguments are passed to the
  43. constructor of :class:`~pymongo.mongo_client.MongoClient`
  44. :raise ConfigurationError: when the :pypi:`pymongo` library is not available
  45. """
  46. super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
  47. self._database_name = database_name
  48. self._collection_mapping = {
  49. "counters": counter_collection_name,
  50. "windows": window_collection_name,
  51. }
  52. self.lib = self.dependencies["pymongo"].module
  53. self.lib_errors, _ = get_dependency("pymongo.errors")
  54. self._storage_uri = uri
  55. self._storage_options = options
  56. self._storage: MongoClient | None = None
  57. @property
  58. def storage(self) -> MongoClient:
  59. if self._storage is None:
  60. self._storage = self._init_mongo_client(
  61. self._storage_uri, **self._storage_options
  62. )
  63. self.__initialize_database()
  64. return self._storage
  65. @property
  66. def _database(self) -> MongoDatabase:
  67. return self.storage[self._database_name]
  68. @property
  69. def counters(self) -> MongoCollection:
  70. return self._database[self._collection_mapping["counters"]]
  71. @property
  72. def windows(self) -> MongoCollection:
  73. return self._database[self._collection_mapping["windows"]]
  74. @abstractmethod
  75. def _init_mongo_client(
  76. self, uri: str | None, **options: int | str | bool
  77. ) -> MongoClient:
  78. raise NotImplementedError()
  79. @property
  80. def base_exceptions(
  81. self,
  82. ) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
  83. return self.lib_errors.PyMongoError # type: ignore
  84. def __initialize_database(self) -> None:
  85. self.counters.create_index("expireAt", expireAfterSeconds=0)
  86. self.windows.create_index("expireAt", expireAfterSeconds=0)
  87. def reset(self) -> int | None:
  88. """
  89. Delete all rate limit keys in the rate limit collections (counters, windows)
  90. """
  91. num_keys = self.counters.count_documents({}) + self.windows.count_documents({})
  92. self.counters.drop()
  93. self.windows.drop()
  94. return int(num_keys)
  95. def clear(self, key: str) -> None:
  96. """
  97. :param key: the key to clear rate limits for
  98. """
  99. self.counters.find_one_and_delete({"_id": key})
  100. self.windows.find_one_and_delete({"_id": key})
  101. def get_expiry(self, key: str) -> float:
  102. """
  103. :param key: the key to get the expiry for
  104. """
  105. counter = self.counters.find_one({"_id": key})
  106. return (
  107. (counter["expireAt"] if counter else datetime.datetime.now())
  108. .replace(tzinfo=datetime.timezone.utc)
  109. .timestamp()
  110. )
  111. def get(self, key: str) -> int:
  112. """
  113. :param key: the key to get the counter value for
  114. """
  115. counter = self.counters.find_one(
  116. {
  117. "_id": key,
  118. "expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
  119. },
  120. projection=["count"],
  121. )
  122. return counter and counter["count"] or 0
  123. def incr(
  124. self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
  125. ) -> int:
  126. """
  127. increments the counter for a given rate limit key
  128. :param key: the key to increment
  129. :param expiry: amount in seconds for the key to expire in
  130. :param amount: the number to increment by
  131. """
  132. expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
  133. seconds=expiry
  134. )
  135. return int(
  136. self.counters.find_one_and_update(
  137. {"_id": key},
  138. [
  139. {
  140. "$set": {
  141. "count": {
  142. "$cond": {
  143. "if": {"$lt": ["$expireAt", "$$NOW"]},
  144. "then": amount,
  145. "else": {"$add": ["$count", amount]},
  146. }
  147. },
  148. "expireAt": {
  149. "$cond": {
  150. "if": {"$lt": ["$expireAt", "$$NOW"]},
  151. "then": expiration,
  152. "else": (
  153. expiration if elastic_expiry else "$expireAt"
  154. ),
  155. }
  156. },
  157. }
  158. },
  159. ],
  160. upsert=True,
  161. projection=["count"],
  162. return_document=self.lib.ReturnDocument.AFTER,
  163. )["count"]
  164. )
  165. def check(self) -> bool:
  166. """
  167. Check if storage is healthy by calling :meth:`pymongo.mongo_client.MongoClient.server_info`
  168. """
  169. try:
  170. self.storage.server_info()
  171. return True
  172. except: # noqa: E722
  173. return False
  174. def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
  175. """
  176. returns the starting point and the number of entries in the moving
  177. window
  178. :param key: rate limit key
  179. :param expiry: expiry of entry
  180. :return: (start of window, number of acquired entries)
  181. """
  182. timestamp = time.time()
  183. result = list(
  184. self.windows.aggregate(
  185. [
  186. {"$match": {"_id": key}},
  187. {
  188. "$project": {
  189. "entries": {
  190. "$filter": {
  191. "input": "$entries",
  192. "as": "entry",
  193. "cond": {"$gte": ["$$entry", timestamp - expiry]},
  194. }
  195. }
  196. }
  197. },
  198. {"$unwind": "$entries"},
  199. {
  200. "$group": {
  201. "_id": "$_id",
  202. "min": {"$min": "$entries"},
  203. "count": {"$sum": 1},
  204. }
  205. },
  206. ]
  207. )
  208. )
  209. if result:
  210. return result[0]["min"], result[0]["count"]
  211. return timestamp, 0
  212. def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
  213. """
  214. :param key: rate limit key to acquire an entry in
  215. :param limit: amount of entries allowed
  216. :param expiry: expiry of the entry
  217. :param amount: the number of entries to acquire
  218. """
  219. if amount > limit:
  220. return False
  221. timestamp = time.time()
  222. try:
  223. updates: dict[
  224. str,
  225. dict[str, datetime.datetime | dict[str, list[float] | int]],
  226. ] = {
  227. "$push": {
  228. "entries": {
  229. "$each": [timestamp] * amount,
  230. "$position": 0,
  231. "$slice": limit,
  232. }
  233. },
  234. "$set": {
  235. "expireAt": (
  236. datetime.datetime.now(datetime.timezone.utc)
  237. + datetime.timedelta(seconds=expiry)
  238. )
  239. },
  240. }
  241. self.windows.update_one(
  242. {
  243. "_id": key,
  244. f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
  245. },
  246. updates,
  247. upsert=True,
  248. )
  249. return True
  250. except self.lib.errors.DuplicateKeyError:
  251. return False
  252. def get_sliding_window(
  253. self, key: str, expiry: int
  254. ) -> tuple[int, float, int, float]:
  255. expiry_ms = expiry * 1000
  256. if result := self.windows.find_one_and_update(
  257. {"_id": key},
  258. [
  259. {
  260. "$set": {
  261. "previousCount": {
  262. "$cond": {
  263. "if": {
  264. "$lte": [
  265. {"$subtract": ["$expiresAt", "$$NOW"]},
  266. expiry_ms,
  267. ]
  268. },
  269. "then": {"$ifNull": ["$currentCount", 0]},
  270. "else": {"$ifNull": ["$previousCount", 0]},
  271. }
  272. },
  273. "currentCount": {
  274. "$cond": {
  275. "if": {
  276. "$lte": [
  277. {"$subtract": ["$expiresAt", "$$NOW"]},
  278. expiry_ms,
  279. ]
  280. },
  281. "then": 0,
  282. "else": {"$ifNull": ["$currentCount", 0]},
  283. }
  284. },
  285. "expiresAt": {
  286. "$cond": {
  287. "if": {
  288. "$lte": [
  289. {"$subtract": ["$expiresAt", "$$NOW"]},
  290. expiry_ms,
  291. ]
  292. },
  293. "then": {
  294. "$add": ["$expiresAt", expiry_ms],
  295. },
  296. "else": "$expiresAt",
  297. }
  298. },
  299. }
  300. }
  301. ],
  302. return_document=self.lib.ReturnDocument.AFTER,
  303. projection=["currentCount", "previousCount", "expiresAt"],
  304. ):
  305. expires_at = (
  306. (result["expiresAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
  307. if result.get("expiresAt")
  308. else time.time()
  309. )
  310. current_ttl = max(0, expires_at - time.time())
  311. prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
  312. return (
  313. result["previousCount"],
  314. prev_ttl,
  315. result["currentCount"],
  316. current_ttl,
  317. )
  318. return 0, 0.0, 0, 0.0
  319. def acquire_sliding_window_entry(
  320. self, key: str, limit: int, expiry: int, amount: int = 1
  321. ) -> bool:
  322. expiry_ms = expiry * 1000
  323. result = self.windows.find_one_and_update(
  324. {"_id": key},
  325. [
  326. {
  327. "$set": {
  328. "previousCount": {
  329. "$cond": {
  330. "if": {
  331. "$lte": [
  332. {"$subtract": ["$expiresAt", "$$NOW"]},
  333. expiry_ms,
  334. ]
  335. },
  336. "then": {"$ifNull": ["$currentCount", 0]},
  337. "else": {"$ifNull": ["$previousCount", 0]},
  338. }
  339. },
  340. }
  341. },
  342. {
  343. "$set": {
  344. "currentCount": {
  345. "$cond": {
  346. "if": {
  347. "$lte": [
  348. {"$subtract": ["$expiresAt", "$$NOW"]},
  349. expiry_ms,
  350. ]
  351. },
  352. "then": 0,
  353. "else": {"$ifNull": ["$currentCount", 0]},
  354. }
  355. },
  356. "expiresAt": {
  357. "$cond": {
  358. "if": {
  359. "$lte": [
  360. {"$subtract": ["$expiresAt", "$$NOW"]},
  361. expiry_ms,
  362. ]
  363. },
  364. "then": {
  365. "$cond": {
  366. "if": {"$gt": ["$expiresAt", 0]},
  367. "then": {"$add": ["$expiresAt", expiry_ms]},
  368. "else": {"$add": ["$$NOW", 2 * expiry_ms]},
  369. }
  370. },
  371. "else": "$expiresAt",
  372. }
  373. },
  374. }
  375. },
  376. {
  377. "$set": {
  378. "curWeightedCount": {
  379. "$floor": {
  380. "$add": [
  381. {
  382. "$multiply": [
  383. "$previousCount",
  384. {
  385. "$divide": [
  386. {
  387. "$max": [
  388. 0,
  389. {
  390. "$subtract": [
  391. "$expiresAt",
  392. {
  393. "$add": [
  394. "$$NOW",
  395. expiry_ms,
  396. ]
  397. },
  398. ]
  399. },
  400. ]
  401. },
  402. expiry_ms,
  403. ]
  404. },
  405. ]
  406. },
  407. "$currentCount",
  408. ]
  409. }
  410. }
  411. }
  412. },
  413. {
  414. "$set": {
  415. "currentCount": {
  416. "$cond": {
  417. "if": {
  418. "$lte": [
  419. {"$add": ["$curWeightedCount", amount]},
  420. limit,
  421. ]
  422. },
  423. "then": {"$add": ["$currentCount", amount]},
  424. "else": "$currentCount",
  425. }
  426. }
  427. }
  428. },
  429. {
  430. "$set": {
  431. "_acquired": {
  432. "$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
  433. }
  434. }
  435. },
  436. {"$unset": ["curWeightedCount"]},
  437. ],
  438. return_document=self.lib.ReturnDocument.AFTER,
  439. upsert=True,
  440. )
  441. return cast(bool, result["_acquired"])
  442. def __del__(self) -> None:
  443. if self.storage:
  444. self.storage.close()
  445. @versionadded(version="2.1")
  446. @versionchanged(
  447. version="3.14.0",
  448. reason="Added option to select custom collection names for windows & counters",
  449. )
  450. class MongoDBStorage(MongoDBStorageBase):
  451. STORAGE_SCHEME = ["mongodb", "mongodb+srv"]
  452. def _init_mongo_client(
  453. self, uri: str | None, **options: int | str | bool
  454. ) -> MongoClient:
  455. return cast(MongoClient, self.lib.MongoClient(uri, **options))