simple.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import typing as _t
  2. from time import time
  3. from cachelib.base import BaseCache
  4. from cachelib.serializers import SimpleSerializer
  5. class SimpleCache(BaseCache):
  6. """Simple memory cache for single process environments. This class exists
  7. mainly for the development server and is not 100% thread safe. It tries
  8. to use as many atomic operations as possible and no locks for simplicity
  9. but it could happen under heavy load that keys are added multiple times.
  10. :param threshold: the maximum number of items the cache stores before
  11. it starts deleting some.
  12. :param default_timeout: the default timeout that is used if no timeout is
  13. specified on :meth:`~BaseCache.set`. A timeout of
  14. 0 indicates that the cache never expires.
  15. """
  16. serializer = SimpleSerializer()
  17. def __init__(
  18. self,
  19. threshold: int = 500,
  20. default_timeout: int = 300,
  21. ):
  22. BaseCache.__init__(self, default_timeout)
  23. self._cache: _t.Dict[str, _t.Any] = {}
  24. self._threshold = threshold or 500 # threshold = 0
  25. def _over_threshold(self) -> bool:
  26. return len(self._cache) > self._threshold
  27. def _remove_expired(self, now: float) -> None:
  28. toremove = [k for k, (expires, _) in self._cache.items() if expires < now]
  29. for k in toremove:
  30. self._cache.pop(k, None)
  31. def _remove_older(self) -> None:
  32. k_ordered = (
  33. k for k, v in sorted(self._cache.items(), key=lambda item: item[1][0])
  34. )
  35. for k in k_ordered:
  36. self._cache.pop(k, None)
  37. if not self._over_threshold():
  38. break
  39. def _prune(self) -> None:
  40. if self._over_threshold():
  41. now = time()
  42. self._remove_expired(now)
  43. # remove older items if still over threshold
  44. if self._over_threshold():
  45. self._remove_older()
  46. def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
  47. timeout = BaseCache._normalize_timeout(self, timeout)
  48. if timeout > 0:
  49. timeout = int(time()) + timeout
  50. return timeout
  51. def get(self, key: str) -> _t.Any:
  52. try:
  53. expires, value = self._cache[key]
  54. if expires == 0 or expires > time():
  55. return self.serializer.loads(value)
  56. except KeyError:
  57. return None
  58. def set(
  59. self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
  60. ) -> _t.Optional[bool]:
  61. expires = self._normalize_timeout(timeout)
  62. self._prune()
  63. self._cache[key] = (expires, self.serializer.dumps(value))
  64. return True
  65. def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
  66. expires = self._normalize_timeout(timeout)
  67. self._prune()
  68. item = (expires, self.serializer.dumps(value))
  69. if key in self._cache:
  70. return False
  71. self._cache.setdefault(key, item)
  72. return True
  73. def delete(self, key: str) -> bool:
  74. return self._cache.pop(key, None) is not None
  75. def has(self, key: str) -> bool:
  76. try:
  77. expires, value = self._cache[key]
  78. return bool(expires == 0 or expires > time())
  79. except KeyError:
  80. return False
  81. def clear(self) -> bool:
  82. self._cache.clear()
  83. return not bool(self._cache)