mongodb.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. from __future__ import annotations
  2. import asyncio
  3. import datetime
  4. import time
  5. from deprecated.sphinx import versionadded, versionchanged
  6. from limits.aio.storage.base import (
  7. MovingWindowSupport,
  8. SlidingWindowCounterSupport,
  9. Storage,
  10. )
  11. from limits.typing import (
  12. ParamSpec,
  13. TypeVar,
  14. cast,
  15. )
  16. from limits.util import get_dependency
# Module-level typing variables: a parameter specification and a generic
# type variable.
# NOTE(review): neither name is referenced anywhere else in the code visible
# in this file — presumably kept for helpers/decorators defined elsewhere or
# left over from a refactor; confirm before removing.
P = ParamSpec("P")
R = TypeVar("R")
  19. @versionadded(version="2.1")
  20. @versionchanged(
  21. version="3.14.0",
  22. reason="Added option to select custom collection names for windows & counters",
  23. )
  24. class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
  25. """
  26. Rate limit storage with MongoDB as backend.
  27. Depends on :pypi:`motor`
  28. """
  29. STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
  30. """
  31. The storage scheme for MongoDB for use in an async context
  32. """
  33. DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]
  34. def __init__(
  35. self,
  36. uri: str,
  37. database_name: str = "limits",
  38. counter_collection_name: str = "counters",
  39. window_collection_name: str = "windows",
  40. wrap_exceptions: bool = False,
  41. **options: float | str | bool,
  42. ) -> None:
  43. """
  44. :param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
  45. This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
  46. :param database_name: The database to use for storing the rate limit
  47. collections.
  48. :param counter_collection_name: The collection name to use for individual counters
  49. used in fixed window strategies
  50. :param window_collection_name: The collection name to use for sliding & moving window
  51. storage
  52. :param wrap_exceptions: Whether to wrap storage exceptions in
  53. :exc:`limits.errors.StorageError` before raising it.
  54. :param options: all remaining keyword arguments are passed
  55. to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
  56. :raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
  57. not available
  58. """
  59. uri = uri.replace("async+mongodb", "mongodb", 1)
  60. super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
  61. self.dependency = self.dependencies["motor.motor_asyncio"]
  62. self.proxy_dependency = self.dependencies["pymongo"]
  63. self.lib_errors, _ = get_dependency("pymongo.errors")
  64. self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)
  65. # TODO: Fix this hack. It was noticed when running a benchmark
  66. # with FastAPI - however - doesn't appear in unit tests or in an isolated
  67. # use. Reference: https://jira.mongodb.org/browse/MOTOR-822
  68. self.storage.get_io_loop = asyncio.get_running_loop
  69. self.__database_name = database_name
  70. self.__collection_mapping = {
  71. "counters": counter_collection_name,
  72. "windows": window_collection_name,
  73. }
  74. self.__indices_created = False
  75. @property
  76. def base_exceptions(
  77. self,
  78. ) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
  79. return self.lib_errors.PyMongoError # type: ignore
  80. @property
  81. def database(self): # type: ignore
  82. return self.storage.get_database(self.__database_name)
  83. async def create_indices(self) -> None:
  84. if not self.__indices_created:
  85. await asyncio.gather(
  86. self.database[self.__collection_mapping["counters"]].create_index(
  87. "expireAt", expireAfterSeconds=0
  88. ),
  89. self.database[self.__collection_mapping["windows"]].create_index(
  90. "expireAt", expireAfterSeconds=0
  91. ),
  92. )
  93. self.__indices_created = True
  94. async def reset(self) -> int | None:
  95. """
  96. Delete all rate limit keys in the rate limit collections (counters, windows)
  97. """
  98. num_keys = sum(
  99. await asyncio.gather(
  100. self.database[self.__collection_mapping["counters"]].count_documents(
  101. {}
  102. ),
  103. self.database[self.__collection_mapping["windows"]].count_documents({}),
  104. )
  105. )
  106. await asyncio.gather(
  107. self.database[self.__collection_mapping["counters"]].drop(),
  108. self.database[self.__collection_mapping["windows"]].drop(),
  109. )
  110. return cast(int, num_keys)
  111. async def clear(self, key: str) -> None:
  112. """
  113. :param key: the key to clear rate limits for
  114. """
  115. await asyncio.gather(
  116. self.database[self.__collection_mapping["counters"]].find_one_and_delete(
  117. {"_id": key}
  118. ),
  119. self.database[self.__collection_mapping["windows"]].find_one_and_delete(
  120. {"_id": key}
  121. ),
  122. )
  123. async def get_expiry(self, key: str) -> float:
  124. """
  125. :param key: the key to get the expiry for
  126. """
  127. counter = await self.database[self.__collection_mapping["counters"]].find_one(
  128. {"_id": key}
  129. )
  130. return (
  131. (counter["expireAt"] if counter else datetime.datetime.now())
  132. .replace(tzinfo=datetime.timezone.utc)
  133. .timestamp()
  134. )
  135. async def get(self, key: str) -> int:
  136. """
  137. :param key: the key to get the counter value for
  138. """
  139. counter = await self.database[self.__collection_mapping["counters"]].find_one(
  140. {
  141. "_id": key,
  142. "expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
  143. },
  144. projection=["count"],
  145. )
  146. return counter and counter["count"] or 0
  147. async def incr(
  148. self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
  149. ) -> int:
  150. """
  151. increments the counter for a given rate limit key
  152. :param key: the key to increment
  153. :param expiry: amount in seconds for the key to expire in
  154. :param elastic_expiry: whether to keep extending the rate limit
  155. window every hit.
  156. :param amount: the number to increment by
  157. """
  158. await self.create_indices()
  159. expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
  160. seconds=expiry
  161. )
  162. response = await self.database[
  163. self.__collection_mapping["counters"]
  164. ].find_one_and_update(
  165. {"_id": key},
  166. [
  167. {
  168. "$set": {
  169. "count": {
  170. "$cond": {
  171. "if": {"$lt": ["$expireAt", "$$NOW"]},
  172. "then": amount,
  173. "else": {"$add": ["$count", amount]},
  174. }
  175. },
  176. "expireAt": {
  177. "$cond": {
  178. "if": {"$lt": ["$expireAt", "$$NOW"]},
  179. "then": expiration,
  180. "else": (expiration if elastic_expiry else "$expireAt"),
  181. }
  182. },
  183. }
  184. },
  185. ],
  186. upsert=True,
  187. projection=["count"],
  188. return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
  189. )
  190. return int(response["count"])
  191. async def check(self) -> bool:
  192. """
  193. Check if storage is healthy by calling
  194. :meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
  195. """
  196. try:
  197. await self.storage.server_info()
  198. return True
  199. except: # noqa: E722
  200. return False
  201. async def get_moving_window(
  202. self, key: str, limit: int, expiry: int
  203. ) -> tuple[float, int]:
  204. """
  205. returns the starting point and the number of entries in the moving
  206. window
  207. :param str key: rate limit key
  208. :param int expiry: expiry of entry
  209. :return: (start of window, number of acquired entries)
  210. """
  211. timestamp = time.time()
  212. if result := (
  213. await self.database[self.__collection_mapping["windows"]]
  214. .aggregate(
  215. [
  216. {"$match": {"_id": key}},
  217. {
  218. "$project": {
  219. "entries": {
  220. "$filter": {
  221. "input": "$entries",
  222. "as": "entry",
  223. "cond": {"$gte": ["$$entry", timestamp - expiry]},
  224. }
  225. }
  226. }
  227. },
  228. {"$unwind": "$entries"},
  229. {
  230. "$group": {
  231. "_id": "$_id",
  232. "min": {"$min": "$entries"},
  233. "count": {"$sum": 1},
  234. }
  235. },
  236. ]
  237. )
  238. .to_list(length=1)
  239. ):
  240. return result[0]["min"], result[0]["count"]
  241. return timestamp, 0
  242. async def acquire_entry(
  243. self, key: str, limit: int, expiry: int, amount: int = 1
  244. ) -> bool:
  245. """
  246. :param key: rate limit key to acquire an entry in
  247. :param limit: amount of entries allowed
  248. :param expiry: expiry of the entry
  249. :param amount: the number of entries to acquire
  250. """
  251. await self.create_indices()
  252. if amount > limit:
  253. return False
  254. timestamp = time.time()
  255. try:
  256. updates: dict[
  257. str,
  258. dict[str, datetime.datetime | dict[str, list[float] | int]],
  259. ] = {
  260. "$push": {
  261. "entries": {
  262. "$each": [timestamp] * amount,
  263. "$position": 0,
  264. "$slice": limit,
  265. }
  266. },
  267. "$set": {
  268. "expireAt": (
  269. datetime.datetime.now(datetime.timezone.utc)
  270. + datetime.timedelta(seconds=expiry)
  271. )
  272. },
  273. }
  274. await self.database[self.__collection_mapping["windows"]].update_one(
  275. {
  276. "_id": key,
  277. f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
  278. },
  279. updates,
  280. upsert=True,
  281. )
  282. return True
  283. except self.proxy_dependency.module.errors.DuplicateKeyError:
  284. return False
  285. async def acquire_sliding_window_entry(
  286. self, key: str, limit: int, expiry: int, amount: int = 1
  287. ) -> bool:
  288. await self.create_indices()
  289. expiry_ms = expiry * 1000
  290. result = await self.database[
  291. self.__collection_mapping["windows"]
  292. ].find_one_and_update(
  293. {"_id": key},
  294. [
  295. {
  296. "$set": {
  297. "previousCount": {
  298. "$cond": {
  299. "if": {
  300. "$lte": [
  301. {"$subtract": ["$expiresAt", "$$NOW"]},
  302. expiry_ms,
  303. ]
  304. },
  305. "then": {"$ifNull": ["$currentCount", 0]},
  306. "else": {"$ifNull": ["$previousCount", 0]},
  307. }
  308. },
  309. }
  310. },
  311. {
  312. "$set": {
  313. "currentCount": {
  314. "$cond": {
  315. "if": {
  316. "$lte": [
  317. {"$subtract": ["$expiresAt", "$$NOW"]},
  318. expiry_ms,
  319. ]
  320. },
  321. "then": 0,
  322. "else": {"$ifNull": ["$currentCount", 0]},
  323. }
  324. },
  325. "expiresAt": {
  326. "$cond": {
  327. "if": {
  328. "$lte": [
  329. {"$subtract": ["$expiresAt", "$$NOW"]},
  330. expiry_ms,
  331. ]
  332. },
  333. "then": {
  334. "$cond": {
  335. "if": {"$gt": ["$expiresAt", 0]},
  336. "then": {"$add": ["$expiresAt", expiry_ms]},
  337. "else": {"$add": ["$$NOW", 2 * expiry_ms]},
  338. }
  339. },
  340. "else": "$expiresAt",
  341. }
  342. },
  343. }
  344. },
  345. {
  346. "$set": {
  347. "curWeightedCount": {
  348. "$floor": {
  349. "$add": [
  350. {
  351. "$multiply": [
  352. "$previousCount",
  353. {
  354. "$divide": [
  355. {
  356. "$max": [
  357. 0,
  358. {
  359. "$subtract": [
  360. "$expiresAt",
  361. {
  362. "$add": [
  363. "$$NOW",
  364. expiry_ms,
  365. ]
  366. },
  367. ]
  368. },
  369. ]
  370. },
  371. expiry_ms,
  372. ]
  373. },
  374. ]
  375. },
  376. "$currentCount",
  377. ]
  378. }
  379. }
  380. }
  381. },
  382. {
  383. "$set": {
  384. "currentCount": {
  385. "$cond": {
  386. "if": {
  387. "$lte": [
  388. {"$add": ["$curWeightedCount", amount]},
  389. limit,
  390. ]
  391. },
  392. "then": {"$add": ["$currentCount", amount]},
  393. "else": "$currentCount",
  394. }
  395. }
  396. }
  397. },
  398. {
  399. "$set": {
  400. "_acquired": {
  401. "$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
  402. }
  403. }
  404. },
  405. {"$unset": ["curWeightedCount"]},
  406. ],
  407. return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
  408. upsert=True,
  409. )
  410. return cast(bool, result["_acquired"])
  411. async def get_sliding_window(
  412. self, key: str, expiry: int
  413. ) -> tuple[int, float, int, float]:
  414. expiry_ms = expiry * 1000
  415. if result := await self.database[
  416. self.__collection_mapping["windows"]
  417. ].find_one_and_update(
  418. {"_id": key},
  419. [
  420. {
  421. "$set": {
  422. "previousCount": {
  423. "$cond": {
  424. "if": {
  425. "$lte": [
  426. {"$subtract": ["$expiresAt", "$$NOW"]},
  427. expiry_ms,
  428. ]
  429. },
  430. "then": {"$ifNull": ["$currentCount", 0]},
  431. "else": {"$ifNull": ["$previousCount", 0]},
  432. }
  433. },
  434. "currentCount": {
  435. "$cond": {
  436. "if": {
  437. "$lte": [
  438. {"$subtract": ["$expiresAt", "$$NOW"]},
  439. expiry_ms,
  440. ]
  441. },
  442. "then": 0,
  443. "else": {"$ifNull": ["$currentCount", 0]},
  444. }
  445. },
  446. "expiresAt": {
  447. "$cond": {
  448. "if": {
  449. "$lte": [
  450. {"$subtract": ["$expiresAt", "$$NOW"]},
  451. expiry_ms,
  452. ]
  453. },
  454. "then": {"$add": ["$expiresAt", expiry_ms]},
  455. "else": "$expiresAt",
  456. }
  457. },
  458. }
  459. }
  460. ],
  461. return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
  462. projection=["currentCount", "previousCount", "expiresAt"],
  463. ):
  464. expires_at = (
  465. (result["expiresAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
  466. if result.get("expiresAt")
  467. else time.time()
  468. )
  469. current_ttl = max(0, expires_at - time.time())
  470. prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
  471. return (
  472. result["previousCount"],
  473. prev_ttl,
  474. result["currentCount"],
  475. current_ttl,
  476. )
  477. return 0, 0.0, 0, 0.0
  478. def __del__(self) -> None:
  479. self.storage and self.storage.close()