redis.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import typing as _t
  2. from cachelib.base import BaseCache
  3. from cachelib.serializers import RedisSerializer
  class RedisCache(BaseCache):
      """Uses the Redis key-value store as a cache backend.

      The first argument can be either a string denoting address of the Redis
      server or an object resembling an instance of a redis.Redis class.

      Note: Python Redis API already takes care of encoding unicode strings on
      the fly.

      :param host: address of the Redis server or an object which API is
                   compatible with the official Python Redis client (redis-py).
      :param port: port number on which Redis server listens for connections.
      :param password: password authentication for the Redis server.
      :param db: db (zero-based numeric index) on Redis Server to connect.
      :param default_timeout: the default timeout that is used if no timeout is
                              specified on :meth:`~BaseCache.set`. A timeout of
                              0 indicates that the cache never expires.
      :param key_prefix: A prefix that should be added to all keys. May also be
                         a zero-argument callable returning the prefix; it is
                         evaluated again for every cache operation.
      :raises ValueError: if ``host`` is None, or if a truthy
                          ``decode_responses`` keyword argument is passed.
      :raises RuntimeError: if ``host`` is a string and the ``redis`` module
                            cannot be imported.

      Any additional keyword arguments will be passed to ``redis.Redis``.
      """

      # __init__ points both attributes at the same client object; keeping two
      # names presumably lets subclasses route reads and writes separately.
      _read_client: _t.Any = None
      _write_client: _t.Any = None
      serializer = RedisSerializer()

      def __init__(
          self,
          host: _t.Any = "localhost",
          port: int = 6379,
          password: _t.Optional[str] = None,
          db: int = 0,
          default_timeout: int = 300,
          key_prefix: _t.Optional[_t.Union[str, _t.Callable[[], str]]] = None,
          **kwargs: _t.Any,
      ):
          BaseCache.__init__(self, default_timeout)
          if host is None:
              raise ValueError("RedisCache host parameter may not be None")
          if isinstance(host, str):
              # Imported lazily: redis-py is only needed when this class has to
              # build its own client from an address.
              try:
                  import redis
              except ImportError as err:
                  raise RuntimeError("no redis module found") from err
              if kwargs.get("decode_responses", None):
                  raise ValueError("decode_responses is not supported by RedisCache.")
              self._write_client = self._read_client = redis.Redis(
                  host=host, port=port, password=password, db=db, **kwargs
              )
          else:
              # Any non-string host is used directly as a ready-made client.
              self._read_client = self._write_client = host
          self.key_prefix = key_prefix or ""

      def _get_prefix(self) -> str:
          """Return the current key prefix, calling ``key_prefix`` if callable."""
          return (
              self.key_prefix if isinstance(self.key_prefix, str) else self.key_prefix()
          )

      def _normalize_timeout(self, timeout: _t.Optional[int]) -> int:
          """Normalize ``timeout`` for use with Redis expiry commands.

          The value is first passed through ``BaseCache._normalize_timeout``
          (presumably substituting ``default_timeout`` for None). A result of 0
          ("never expire") is then mapped to -1, the sentinel this class uses
          to mean "store without an expiry".

          :param timeout: timeout to normalize.
          :returns: the normalized timeout in seconds, or -1 for no expiry.
          """
          timeout = BaseCache._normalize_timeout(self, timeout)
          if timeout == 0:
              timeout = -1
          return timeout

      def get(self, key: str) -> _t.Any:
          """Fetch ``key`` and deserialize the stored value."""
          return self.serializer.loads(
              self._read_client.get(f"{self._get_prefix()}{key}")
          )

      def get_many(self, *keys: str) -> _t.List[_t.Any]:
          """Fetch several keys with one MGET; results follow ``keys`` order."""
          if not keys:
              # MGET with no arguments is rejected by Redis; nothing to fetch.
              return []
          # Resolve the prefix once so a callable prefix is applied
          # consistently across the whole batch.
          prefix = self._get_prefix()
          prefixed_keys = [f"{prefix}{key}" for key in keys]
          return [self.serializer.loads(x) for x in self._read_client.mget(prefixed_keys)]

      def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
          """Store ``value`` under ``key``, with an expiry unless timeout means "never".

          :returns: whatever the client's ``set``/``setex`` call returns.
          """
          timeout = self._normalize_timeout(timeout)
          dump = self.serializer.dumps(value)
          name = f"{self._get_prefix()}{key}"
          if timeout == -1:
              result = self._write_client.set(name=name, value=dump)
          else:
              result = self._write_client.setex(name=name, value=dump, time=timeout)
          return result

      def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any:
          """Store ``value`` only if ``key`` does not already exist.

          :returns: True if the key was created, False otherwise.
          """
          timeout = self._normalize_timeout(timeout)
          dump = self.serializer.dumps(value)
          # SET with NX (and EX when an expiry is wanted) creates the key and
          # its TTL in one atomic command. A separate SETNX + EXPIRE pair could
          # leave a key with no expiry if the second call never ran.
          created = self._write_client.set(
              name=f"{self._get_prefix()}{key}",
              value=dump,
              ex=None if timeout == -1 else timeout,
              nx=True,
          )
          # SET NX replies None when the key already existed.
          return bool(created)

      def set_many(
          self, mapping: _t.Dict[str, _t.Any], timeout: _t.Optional[int] = None
      ) -> _t.List[_t.Any]:
          """Store every item of ``mapping``; return the keys that were set."""
          timeout = self._normalize_timeout(timeout)
          prefix = self._get_prefix()
          # Use transaction=False to batch without calling redis MULTI
          # which is not supported by twemproxy
          pipe = self._write_client.pipeline(transaction=False)
          for key, value in mapping.items():
              dump = self.serializer.dumps(value)
              if timeout == -1:
                  pipe.set(name=f"{prefix}{key}", value=dump)
              else:
                  pipe.setex(name=f"{prefix}{key}", value=dump, time=timeout)
          results = pipe.execute()
          return [k for k, was_set in zip(mapping.keys(), results) if was_set]

      def delete(self, key: str) -> bool:
          """Delete ``key``; return True if Redis reported a removal."""
          return bool(self._write_client.delete(f"{self._get_prefix()}{key}"))

      def delete_many(self, *keys: str) -> _t.List[_t.Any]:
          """Delete several keys in one command.

          :returns: the caller's (unprefixed) keys that no longer exist
                    afterwards.
          """
          if not keys:
              return []
          prefix = self._get_prefix()
          self._write_client.delete(*[f"{prefix}{key}" for key in keys])
          # has() applies the prefix itself, so it must receive the raw keys.
          # Passing prefixed keys would probe a doubly-prefixed name and report
          # every key as deleted.
          return [key for key in keys if not self.has(key)]

      def has(self, key: str) -> bool:
          """Return True if ``key`` exists in Redis."""
          return bool(self._read_client.exists(f"{self._get_prefix()}{key}"))

      def clear(self) -> bool:
          """Remove this cache's keys.

          With a prefix, only keys matching ``prefix*`` are deleted. Without a
          prefix, the whole selected database is flushed.

          :returns: True if Redis reported deleting something or the flush
                    succeeded.
          """
          status = 0
          if self.key_prefix:
              # NOTE(review): KEYS walks the entire keyspace in one blocking
              # call; SCAN may be preferable on large production databases.
              keys = self._read_client.keys(self._get_prefix() + "*")
              if keys:
                  status = self._write_client.delete(*keys)
          else:
              status = self._write_client.flushdb()
          return bool(status)

      def inc(self, key: str, delta: int = 1) -> _t.Any:
          """Atomically increment ``key`` by ``delta``; return the new value."""
          return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=delta)

      def dec(self, key: str, delta: int = 1) -> _t.Any:
          """Atomically decrement ``key`` by ``delta``; return the new value."""
          return self._write_client.incr(name=f"{self._get_prefix()}{key}", amount=-delta)