mongodb.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import datetime
  2. import logging
  3. import typing as _t
  4. from cachelib.base import BaseCache
  5. from cachelib.serializers import BaseSerializer
  6. class MongoDbCache(BaseCache):
  7. """
  8. Implementation of cachelib.BaseCache that uses mongodb collection
  9. as the backend.
  10. Limitations: maximum MongoDB document size is 16mb
  11. :param client: mongodb client or connection string
  12. :param db: mongodb database name
  13. :param collection: mongodb collection name
  14. :param default_timeout: Set the timeout in seconds after which cache entries
  15. expire
  16. :param key_prefix: A prefix that should be added to all keys.
  17. """
  18. serializer = BaseSerializer()
  19. def __init__(
  20. self,
  21. client: _t.Any = None,
  22. db: _t.Optional[str] = "cache-db",
  23. collection: _t.Optional[str] = "cache-collection",
  24. default_timeout: int = 300,
  25. key_prefix: _t.Optional[str] = None,
  26. **kwargs: _t.Any
  27. ):
  28. super().__init__(default_timeout)
  29. try:
  30. import pymongo # type: ignore
  31. except ImportError:
  32. logging.warning("no pymongo module found")
  33. if client is None or isinstance(client, str):
  34. client = pymongo.MongoClient(host=client)
  35. self.client = client[db][collection]
  36. index_info = self.client.index_information()
  37. all_keys = {
  38. subkey[0] for value in index_info.values() for subkey in value["key"]
  39. }
  40. if "id" not in all_keys:
  41. self.client.create_index("id", unique=True)
  42. self.key_prefix = key_prefix or ""
  43. self.collection = collection
  44. def _utcnow(self) -> _t.Any:
  45. """Return a tz-aware UTC datetime representing the current time"""
  46. return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
  47. def _expire_records(self) -> _t.Any:
  48. res = self.client.delete_many({"expiration": {"$lte": self._utcnow()}})
  49. return res
  50. def get(self, key: str) -> _t.Any:
  51. """
  52. Get a cache item
  53. :param key: The cache key of the item to fetch
  54. :return: cache value if not expired, else None
  55. """
  56. self._expire_records()
  57. record = self.client.find_one({"id": self.key_prefix + key})
  58. value = None
  59. if record:
  60. value = self.serializer.loads(record["val"])
  61. return value
  62. def delete(self, key: str) -> bool:
  63. """
  64. Deletes an item from the cache. This is a no-op if the item doesn't
  65. exist
  66. :param key: Key of the item to delete.
  67. :return: True if the key existed and was deleted
  68. """
  69. res = self.client.delete_one({"id": self.key_prefix + key})
  70. deleted = bool(res.deleted_count > 0)
  71. return deleted
  72. def _set(
  73. self,
  74. key: str,
  75. value: _t.Any,
  76. timeout: _t.Optional[int] = None,
  77. overwrite: _t.Optional[bool] = True,
  78. ) -> _t.Any:
  79. """
  80. Store a cache item, with the option to not overwrite existing items
  81. :param key: Cache key to use
  82. :param value: a serializable object
  83. :param timeout: The timeout in seconds for the cached item, to override
  84. the default
  85. :param overwrite: If true, overwrite any existing cache item with key.
  86. If false, the new value will only be stored if no
  87. non-expired cache item exists with key.
  88. :return: True if the new item was stored.
  89. """
  90. timeout = self._normalize_timeout(timeout)
  91. now = self._utcnow()
  92. if not overwrite:
  93. # fail if a non-expired item with this key
  94. # already exists
  95. if self.has(key):
  96. return False
  97. dump = self.serializer.dumps(value)
  98. record = {"id": self.key_prefix + key, "val": dump}
  99. if timeout > 0:
  100. record["expiration"] = now + datetime.timedelta(seconds=timeout)
  101. self.client.update_one({"id": self.key_prefix + key}, {"$set": record}, True)
  102. return True
  103. def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
  104. self._expire_records()
  105. return self._set(key, value, timeout=timeout, overwrite=True)
  106. def set_many(
  107. self, mapping: _t.Dict[str, _t.Any], timeout: _t.Optional[int] = None
  108. ) -> _t.List[_t.Any]:
  109. self._expire_records()
  110. from pymongo import UpdateOne
  111. operations = []
  112. now = self._utcnow()
  113. timeout = self._normalize_timeout(timeout)
  114. for key, val in mapping.items():
  115. dump = self.serializer.dumps(val)
  116. record = {"id": self.key_prefix + key, "val": dump}
  117. if timeout > 0:
  118. record["expiration"] = now + datetime.timedelta(seconds=timeout)
  119. operations.append(
  120. UpdateOne({"id": self.key_prefix + key}, {"$set": record}, upsert=True),
  121. )
  122. result = self.client.bulk_write(operations)
  123. keys = list(mapping.keys())
  124. if result.bulk_api_result["nUpserted"] != len(keys):
  125. query = self.client.find(
  126. {"id": {"$in": [self.key_prefix + key for key in keys]}}
  127. )
  128. keys = []
  129. for item in query:
  130. keys.append(item["id"])
  131. return keys
  132. def get_many(self, *keys: str) -> _t.List[_t.Any]:
  133. results = self.get_dict(*keys)
  134. values = []
  135. for key in keys:
  136. values.append(results.get(key, None))
  137. return values
  138. def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
  139. self._expire_records()
  140. query = self.client.find(
  141. {"id": {"$in": [self.key_prefix + key for key in keys]}}
  142. )
  143. results = dict.fromkeys(keys, None)
  144. for item in query:
  145. value = self.serializer.loads(item["val"])
  146. results[item["id"][len(self.key_prefix) :]] = value
  147. return results
  148. def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
  149. self._expire_records()
  150. return self._set(key, value, timeout=timeout, overwrite=False)
  151. def has(self, key: str) -> bool:
  152. self._expire_records()
  153. record = self.get(key)
  154. return record is not None
  155. def delete_many(self, *keys: str) -> _t.List[_t.Any]:
  156. self._expire_records()
  157. res = list(keys)
  158. filter = {"id": {"$in": [self.key_prefix + key for key in keys]}}
  159. result = self.client.delete_many(filter)
  160. if result.deleted_count != len(keys):
  161. existing_keys = [
  162. item["id"][len(self.key_prefix) :] for item in self.client.find(filter)
  163. ]
  164. res = [item for item in keys if item not in existing_keys]
  165. return res
  166. def clear(self) -> bool:
  167. self.client.drop()
  168. return True