memory.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. from __future__ import annotations
  2. import warnings
  3. from collections import OrderedDict, deque
  4. from dataclasses import dataclass, field
  5. from types import TracebackType
  6. from typing import Generic, NamedTuple, TypeVar
  7. from .. import (
  8. BrokenResourceError,
  9. ClosedResourceError,
  10. EndOfStream,
  11. WouldBlock,
  12. )
  13. from .._core._testing import TaskInfo, get_current_task
  14. from ..abc import Event, ObjectReceiveStream, ObjectSendStream
  15. from ..lowlevel import checkpoint
# Invariant item type; parameterizes the shared stream state and the receiver slot.
T_Item = TypeVar("T_Item")
# Covariant item type used by MemoryObjectReceiveStream, which only hands items out.
T_co = TypeVar("T_co", covariant=True)
# Contravariant item type used by MemoryObjectSendStream, which only takes items in.
T_contra = TypeVar("T_contra", contravariant=True)
  19. class MemoryObjectStreamStatistics(NamedTuple):
  20. current_buffer_used: int #: number of items stored in the buffer
  21. #: maximum number of items that can be stored on this stream (or :data:`math.inf`)
  22. max_buffer_size: float
  23. open_send_streams: int #: number of unclosed clones of the send stream
  24. open_receive_streams: int #: number of unclosed clones of the receive stream
  25. #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
  26. tasks_waiting_send: int
  27. #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
  28. tasks_waiting_receive: int
  29. @dataclass(eq=False)
  30. class MemoryObjectItemReceiver(Generic[T_Item]):
  31. task_info: TaskInfo = field(init=False, default_factory=get_current_task)
  32. item: T_Item = field(init=False)
  33. def __repr__(self) -> str:
  34. # When item is not defined, we get following error with default __repr__:
  35. # AttributeError: 'MemoryObjectItemReceiver' object has no attribute 'item'
  36. item = getattr(self, "item", None)
  37. return f"{self.__class__.__name__}(task_info={self.task_info}, item={item!r})"
  38. @dataclass(eq=False)
  39. class MemoryObjectStreamState(Generic[T_Item]):
  40. max_buffer_size: float = field()
  41. buffer: deque[T_Item] = field(init=False, default_factory=deque)
  42. open_send_channels: int = field(init=False, default=0)
  43. open_receive_channels: int = field(init=False, default=0)
  44. waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field(
  45. init=False, default_factory=OrderedDict
  46. )
  47. waiting_senders: OrderedDict[Event, T_Item] = field(
  48. init=False, default_factory=OrderedDict
  49. )
  50. def statistics(self) -> MemoryObjectStreamStatistics:
  51. return MemoryObjectStreamStatistics(
  52. len(self.buffer),
  53. self.max_buffer_size,
  54. self.open_send_channels,
  55. self.open_receive_channels,
  56. len(self.waiting_senders),
  57. len(self.waiting_receivers),
  58. )
  59. @dataclass(eq=False)
  60. class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
  61. _state: MemoryObjectStreamState[T_co]
  62. _closed: bool = field(init=False, default=False)
  63. def __post_init__(self) -> None:
  64. self._state.open_receive_channels += 1
  65. def receive_nowait(self) -> T_co:
  66. """
  67. Receive the next item if it can be done without waiting.
  68. :return: the received item
  69. :raises ~anyio.ClosedResourceError: if this send stream has been closed
  70. :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
  71. closed from the sending end
  72. :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
  73. waiting to send
  74. """
  75. if self._closed:
  76. raise ClosedResourceError
  77. if self._state.waiting_senders:
  78. # Get the item from the next sender
  79. send_event, item = self._state.waiting_senders.popitem(last=False)
  80. self._state.buffer.append(item)
  81. send_event.set()
  82. if self._state.buffer:
  83. return self._state.buffer.popleft()
  84. elif not self._state.open_send_channels:
  85. raise EndOfStream
  86. raise WouldBlock
  87. async def receive(self) -> T_co:
  88. await checkpoint()
  89. try:
  90. return self.receive_nowait()
  91. except WouldBlock:
  92. # Add ourselves in the queue
  93. receive_event = Event()
  94. receiver = MemoryObjectItemReceiver[T_co]()
  95. self._state.waiting_receivers[receive_event] = receiver
  96. try:
  97. await receive_event.wait()
  98. finally:
  99. self._state.waiting_receivers.pop(receive_event, None)
  100. try:
  101. return receiver.item
  102. except AttributeError:
  103. raise EndOfStream from None
  104. def clone(self) -> MemoryObjectReceiveStream[T_co]:
  105. """
  106. Create a clone of this receive stream.
  107. Each clone can be closed separately. Only when all clones have been closed will
  108. the receiving end of the memory stream be considered closed by the sending ends.
  109. :return: the cloned stream
  110. """
  111. if self._closed:
  112. raise ClosedResourceError
  113. return MemoryObjectReceiveStream(_state=self._state)
  114. def close(self) -> None:
  115. """
  116. Close the stream.
  117. This works the exact same way as :meth:`aclose`, but is provided as a special
  118. case for the benefit of synchronous callbacks.
  119. """
  120. if not self._closed:
  121. self._closed = True
  122. self._state.open_receive_channels -= 1
  123. if self._state.open_receive_channels == 0:
  124. send_events = list(self._state.waiting_senders.keys())
  125. for event in send_events:
  126. event.set()
  127. async def aclose(self) -> None:
  128. self.close()
  129. def statistics(self) -> MemoryObjectStreamStatistics:
  130. """
  131. Return statistics about the current state of this stream.
  132. .. versionadded:: 3.0
  133. """
  134. return self._state.statistics()
  135. def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
  136. return self
  137. def __exit__(
  138. self,
  139. exc_type: type[BaseException] | None,
  140. exc_val: BaseException | None,
  141. exc_tb: TracebackType | None,
  142. ) -> None:
  143. self.close()
  144. def __del__(self) -> None:
  145. if not self._closed:
  146. warnings.warn(
  147. f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
  148. ResourceWarning,
  149. source=self,
  150. )
  151. @dataclass(eq=False)
  152. class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
  153. _state: MemoryObjectStreamState[T_contra]
  154. _closed: bool = field(init=False, default=False)
  155. def __post_init__(self) -> None:
  156. self._state.open_send_channels += 1
  157. def send_nowait(self, item: T_contra) -> None:
  158. """
  159. Send an item immediately if it can be done without waiting.
  160. :param item: the item to send
  161. :raises ~anyio.ClosedResourceError: if this send stream has been closed
  162. :raises ~anyio.BrokenResourceError: if the stream has been closed from the
  163. receiving end
  164. :raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
  165. to receive
  166. """
  167. if self._closed:
  168. raise ClosedResourceError
  169. if not self._state.open_receive_channels:
  170. raise BrokenResourceError
  171. while self._state.waiting_receivers:
  172. receive_event, receiver = self._state.waiting_receivers.popitem(last=False)
  173. if not receiver.task_info.has_pending_cancellation():
  174. receiver.item = item
  175. receive_event.set()
  176. return
  177. if len(self._state.buffer) < self._state.max_buffer_size:
  178. self._state.buffer.append(item)
  179. else:
  180. raise WouldBlock
  181. async def send(self, item: T_contra) -> None:
  182. """
  183. Send an item to the stream.
  184. If the buffer is full, this method blocks until there is again room in the
  185. buffer or the item can be sent directly to a receiver.
  186. :param item: the item to send
  187. :raises ~anyio.ClosedResourceError: if this send stream has been closed
  188. :raises ~anyio.BrokenResourceError: if the stream has been closed from the
  189. receiving end
  190. """
  191. await checkpoint()
  192. try:
  193. self.send_nowait(item)
  194. except WouldBlock:
  195. # Wait until there's someone on the receiving end
  196. send_event = Event()
  197. self._state.waiting_senders[send_event] = item
  198. try:
  199. await send_event.wait()
  200. except BaseException:
  201. self._state.waiting_senders.pop(send_event, None)
  202. raise
  203. if send_event in self._state.waiting_senders:
  204. del self._state.waiting_senders[send_event]
  205. raise BrokenResourceError from None
  206. def clone(self) -> MemoryObjectSendStream[T_contra]:
  207. """
  208. Create a clone of this send stream.
  209. Each clone can be closed separately. Only when all clones have been closed will
  210. the sending end of the memory stream be considered closed by the receiving ends.
  211. :return: the cloned stream
  212. """
  213. if self._closed:
  214. raise ClosedResourceError
  215. return MemoryObjectSendStream(_state=self._state)
  216. def close(self) -> None:
  217. """
  218. Close the stream.
  219. This works the exact same way as :meth:`aclose`, but is provided as a special
  220. case for the benefit of synchronous callbacks.
  221. """
  222. if not self._closed:
  223. self._closed = True
  224. self._state.open_send_channels -= 1
  225. if self._state.open_send_channels == 0:
  226. receive_events = list(self._state.waiting_receivers.keys())
  227. self._state.waiting_receivers.clear()
  228. for event in receive_events:
  229. event.set()
  230. async def aclose(self) -> None:
  231. self.close()
  232. def statistics(self) -> MemoryObjectStreamStatistics:
  233. """
  234. Return statistics about the current state of this stream.
  235. .. versionadded:: 3.0
  236. """
  237. return self._state.statistics()
  238. def __enter__(self) -> MemoryObjectSendStream[T_contra]:
  239. return self
  240. def __exit__(
  241. self,
  242. exc_type: type[BaseException] | None,
  243. exc_val: BaseException | None,
  244. exc_tb: TracebackType | None,
  245. ) -> None:
  246. self.close()
  247. def __del__(self) -> None:
  248. if not self._closed:
  249. warnings.warn(
  250. f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
  251. ResourceWarning,
  252. source=self,
  253. )