etcd.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from __future__ import annotations
  2. import asyncio
  3. import time
  4. import urllib.parse
  5. from typing import TYPE_CHECKING
  6. from deprecated.sphinx import deprecated
  7. from limits.aio.storage.base import Storage
  8. from limits.errors import ConcurrentUpdateError
  9. if TYPE_CHECKING:
  10. import aetcd
  11. @deprecated(version="4.4")
  12. class EtcdStorage(Storage):
  13. """
  14. Rate limit storage with etcd as backend.
  15. Depends on :pypi:`aetcd`.
  16. """
  17. STORAGE_SCHEME = ["async+etcd"]
  18. """The async storage scheme for etcd"""
  19. DEPENDENCIES = ["aetcd"]
  20. PREFIX = "limits"
  21. MAX_RETRIES = 5
  22. def __init__(
  23. self,
  24. uri: str,
  25. max_retries: int = MAX_RETRIES,
  26. wrap_exceptions: bool = False,
  27. **options: str,
  28. ) -> None:
  29. """
  30. :param uri: etcd location of the form
  31. ``async+etcd://host:port``,
  32. :param max_retries: Maximum number of attempts to retry
  33. in the case of concurrent updates to a rate limit key
  34. :param wrap_exceptions: Whether to wrap storage exceptions in
  35. :exc:`limits.errors.StorageError` before raising it.
  36. :param options: all remaining keyword arguments are passed
  37. directly to the constructor of :class:`aetcd.client.Client`
  38. :raise ConfigurationError: when :pypi:`aetcd` is not available
  39. """
  40. parsed = urllib.parse.urlparse(uri)
  41. self.lib = self.dependencies["aetcd"].module
  42. self.storage: aetcd.Client = self.lib.Client(
  43. host=parsed.hostname, port=parsed.port, **options
  44. )
  45. self.max_retries = max_retries
  46. super().__init__(uri, wrap_exceptions=wrap_exceptions)
  47. @property
  48. def base_exceptions(
  49. self,
  50. ) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
  51. return self.lib.ClientError # type: ignore[no-any-return]
  52. def prefixed_key(self, key: str) -> bytes:
  53. return f"{self.PREFIX}/{key}".encode()
  54. async def incr(
  55. self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
  56. ) -> int:
  57. retries = 0
  58. etcd_key = self.prefixed_key(key)
  59. while retries < self.max_retries:
  60. now = time.time()
  61. lease = await self.storage.lease(expiry)
  62. window_end = now + expiry
  63. create_attempt = await self.storage.transaction(
  64. compare=[self.storage.transactions.create(etcd_key) == b"0"],
  65. success=[
  66. self.storage.transactions.put(
  67. etcd_key, f"{amount}:{window_end}".encode(), lease=lease.id
  68. )
  69. ],
  70. failure=[self.storage.transactions.get(etcd_key)],
  71. )
  72. if create_attempt[0]:
  73. return amount
  74. else:
  75. cur = create_attempt[1][0][0][1]
  76. cur_value, window_end = cur.value.split(b":")
  77. window_end = float(window_end)
  78. if window_end <= now:
  79. await asyncio.gather(
  80. self.storage.revoke_lease(cur.lease),
  81. self.storage.delete(etcd_key),
  82. )
  83. else:
  84. if elastic_expiry:
  85. await self.storage.refresh_lease(cur.lease)
  86. window_end = now + expiry
  87. new = int(cur_value) + amount
  88. if (
  89. await self.storage.transaction(
  90. compare=[
  91. self.storage.transactions.value(etcd_key) == cur.value
  92. ],
  93. success=[
  94. self.storage.transactions.put(
  95. etcd_key,
  96. f"{new}:{window_end}".encode(),
  97. lease=cur.lease,
  98. )
  99. ],
  100. failure=[],
  101. )
  102. )[0]:
  103. return new
  104. retries += 1
  105. raise ConcurrentUpdateError(key, retries)
  106. async def get(self, key: str) -> int:
  107. cur = await self.storage.get(self.prefixed_key(key))
  108. if cur:
  109. amount, expiry = cur.value.split(b":")
  110. if float(expiry) > time.time():
  111. return int(amount)
  112. return 0
  113. async def get_expiry(self, key: str) -> float:
  114. cur = await self.storage.get(self.prefixed_key(key))
  115. if cur:
  116. window_end = float(cur.value.split(b":")[1])
  117. return window_end
  118. return time.time()
  119. async def check(self) -> bool:
  120. try:
  121. await self.storage.status()
  122. return True
  123. except: # noqa
  124. return False
  125. async def reset(self) -> int | None:
  126. return (await self.storage.delete_prefix(f"{self.PREFIX}/".encode())).deleted
  127. async def clear(self, key: str) -> None:
  128. await self.storage.delete(self.prefixed_key(key))