buffered.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from __future__ import annotations
  2. from collections.abc import Callable, Mapping
  3. from dataclasses import dataclass, field
  4. from typing import Any
  5. from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
  6. from ..abc import AnyByteReceiveStream, ByteReceiveStream
  7. @dataclass(eq=False)
  8. class BufferedByteReceiveStream(ByteReceiveStream):
  9. """
  10. Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
  11. receiving capabilities in the form of a byte stream.
  12. """
  13. receive_stream: AnyByteReceiveStream
  14. _buffer: bytearray = field(init=False, default_factory=bytearray)
  15. _closed: bool = field(init=False, default=False)
  16. async def aclose(self) -> None:
  17. await self.receive_stream.aclose()
  18. self._closed = True
  19. @property
  20. def buffer(self) -> bytes:
  21. """The bytes currently in the buffer."""
  22. return bytes(self._buffer)
  23. @property
  24. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  25. return self.receive_stream.extra_attributes
  26. async def receive(self, max_bytes: int = 65536) -> bytes:
  27. if self._closed:
  28. raise ClosedResourceError
  29. if self._buffer:
  30. chunk = bytes(self._buffer[:max_bytes])
  31. del self._buffer[:max_bytes]
  32. return chunk
  33. elif isinstance(self.receive_stream, ByteReceiveStream):
  34. return await self.receive_stream.receive(max_bytes)
  35. else:
  36. # With a bytes-oriented object stream, we need to handle any surplus bytes
  37. # we get from the receive() call
  38. chunk = await self.receive_stream.receive()
  39. if len(chunk) > max_bytes:
  40. # Save the surplus bytes in the buffer
  41. self._buffer.extend(chunk[max_bytes:])
  42. return chunk[:max_bytes]
  43. else:
  44. return chunk
  45. async def receive_exactly(self, nbytes: int) -> bytes:
  46. """
  47. Read exactly the given amount of bytes from the stream.
  48. :param nbytes: the number of bytes to read
  49. :return: the bytes read
  50. :raises ~anyio.IncompleteRead: if the stream was closed before the requested
  51. amount of bytes could be read from the stream
  52. """
  53. while True:
  54. remaining = nbytes - len(self._buffer)
  55. if remaining <= 0:
  56. retval = self._buffer[:nbytes]
  57. del self._buffer[:nbytes]
  58. return bytes(retval)
  59. try:
  60. if isinstance(self.receive_stream, ByteReceiveStream):
  61. chunk = await self.receive_stream.receive(remaining)
  62. else:
  63. chunk = await self.receive_stream.receive()
  64. except EndOfStream as exc:
  65. raise IncompleteRead from exc
  66. self._buffer.extend(chunk)
  67. async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
  68. """
  69. Read from the stream until the delimiter is found or max_bytes have been read.
  70. :param delimiter: the marker to look for in the stream
  71. :param max_bytes: maximum number of bytes that will be read before raising
  72. :exc:`~anyio.DelimiterNotFound`
  73. :return: the bytes read (not including the delimiter)
  74. :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
  75. was found
  76. :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
  77. bytes read up to the maximum allowed
  78. """
  79. delimiter_size = len(delimiter)
  80. offset = 0
  81. while True:
  82. # Check if the delimiter can be found in the current buffer
  83. index = self._buffer.find(delimiter, offset)
  84. if index >= 0:
  85. found = self._buffer[:index]
  86. del self._buffer[: index + len(delimiter) :]
  87. return bytes(found)
  88. # Check if the buffer is already at or over the limit
  89. if len(self._buffer) >= max_bytes:
  90. raise DelimiterNotFound(max_bytes)
  91. # Read more data into the buffer from the socket
  92. try:
  93. data = await self.receive_stream.receive()
  94. except EndOfStream as exc:
  95. raise IncompleteRead from exc
  96. # Move the offset forward and add the new data to the buffer
  97. offset = max(len(self._buffer) - delimiter_size + 1, 0)
  98. self._buffer.extend(data)