# streams.py

  1. import asyncio
  2. import collections
  3. import warnings
  4. from typing import (
  5. Awaitable,
  6. Callable,
  7. Deque,
  8. Final,
  9. Generic,
  10. List,
  11. Optional,
  12. Tuple,
  13. TypeVar,
  14. )
  15. from .base_protocol import BaseProtocol
  16. from .helpers import (
  17. _EXC_SENTINEL,
  18. BaseTimerContext,
  19. TimerNoop,
  20. set_exception,
  21. set_result,
  22. )
  23. from .log import internal_logger
  24. __all__ = (
  25. "EMPTY_PAYLOAD",
  26. "EofStream",
  27. "StreamReader",
  28. "DataQueue",
  29. )
  30. _T = TypeVar("_T")
  31. class EofStream(Exception):
  32. """eof stream indication."""
  33. class AsyncStreamIterator(Generic[_T]):
  34. __slots__ = ("read_func",)
  35. def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
  36. self.read_func = read_func
  37. def __aiter__(self) -> "AsyncStreamIterator[_T]":
  38. return self
  39. async def __anext__(self) -> _T:
  40. try:
  41. rv = await self.read_func()
  42. except EofStream:
  43. raise StopAsyncIteration
  44. if rv == b"":
  45. raise StopAsyncIteration
  46. return rv
  47. class ChunkTupleAsyncStreamIterator:
  48. __slots__ = ("_stream",)
  49. def __init__(self, stream: "StreamReader") -> None:
  50. self._stream = stream
  51. def __aiter__(self) -> "ChunkTupleAsyncStreamIterator":
  52. return self
  53. async def __anext__(self) -> Tuple[bytes, bool]:
  54. rv = await self._stream.readchunk()
  55. if rv == (b"", False):
  56. raise StopAsyncIteration
  57. return rv
  58. class AsyncStreamReaderMixin:
  59. __slots__ = ()
  60. def __aiter__(self) -> AsyncStreamIterator[bytes]:
  61. return AsyncStreamIterator(self.readline) # type: ignore[attr-defined]
  62. def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
  63. """Returns an asynchronous iterator that yields chunks of size n."""
  64. return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined]
  65. def iter_any(self) -> AsyncStreamIterator[bytes]:
  66. """Yield all available data as soon as it is received."""
  67. return AsyncStreamIterator(self.readany) # type: ignore[attr-defined]
  68. def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
  69. """Yield chunks of data as they are received by the server.
  70. The yielded objects are tuples
  71. of (bytes, bool) as returned by the StreamReader.readchunk method.
  72. """
  73. return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type]
  74. class StreamReader(AsyncStreamReaderMixin):
  75. """An enhancement of asyncio.StreamReader.
  76. Supports asynchronous iteration by line, chunk or as available::
  77. async for line in reader:
  78. ...
  79. async for chunk in reader.iter_chunked(1024):
  80. ...
  81. async for slice in reader.iter_any():
  82. ...
  83. """
  84. __slots__ = (
  85. "_protocol",
  86. "_low_water",
  87. "_high_water",
  88. "_loop",
  89. "_size",
  90. "_cursor",
  91. "_http_chunk_splits",
  92. "_buffer",
  93. "_buffer_offset",
  94. "_eof",
  95. "_waiter",
  96. "_eof_waiter",
  97. "_exception",
  98. "_timer",
  99. "_eof_callbacks",
  100. "_eof_counter",
  101. "total_bytes",
  102. )
  103. def __init__(
  104. self,
  105. protocol: BaseProtocol,
  106. limit: int,
  107. *,
  108. timer: Optional[BaseTimerContext] = None,
  109. loop: Optional[asyncio.AbstractEventLoop] = None,
  110. ) -> None:
  111. self._protocol = protocol
  112. self._low_water = limit
  113. self._high_water = limit * 2
  114. if loop is None:
  115. loop = asyncio.get_event_loop()
  116. self._loop = loop
  117. self._size = 0
  118. self._cursor = 0
  119. self._http_chunk_splits: Optional[List[int]] = None
  120. self._buffer: Deque[bytes] = collections.deque()
  121. self._buffer_offset = 0
  122. self._eof = False
  123. self._waiter: Optional[asyncio.Future[None]] = None
  124. self._eof_waiter: Optional[asyncio.Future[None]] = None
  125. self._exception: Optional[BaseException] = None
  126. self._timer = TimerNoop() if timer is None else timer
  127. self._eof_callbacks: List[Callable[[], None]] = []
  128. self._eof_counter = 0
  129. self.total_bytes = 0
  130. def __repr__(self) -> str:
  131. info = [self.__class__.__name__]
  132. if self._size:
  133. info.append("%d bytes" % self._size)
  134. if self._eof:
  135. info.append("eof")
  136. if self._low_water != 2**16: # default limit
  137. info.append("low=%d high=%d" % (self._low_water, self._high_water))
  138. if self._waiter:
  139. info.append("w=%r" % self._waiter)
  140. if self._exception:
  141. info.append("e=%r" % self._exception)
  142. return "<%s>" % " ".join(info)
  143. def get_read_buffer_limits(self) -> Tuple[int, int]:
  144. return (self._low_water, self._high_water)
  145. def exception(self) -> Optional[BaseException]:
  146. return self._exception
  147. def set_exception(
  148. self,
  149. exc: BaseException,
  150. exc_cause: BaseException = _EXC_SENTINEL,
  151. ) -> None:
  152. self._exception = exc
  153. self._eof_callbacks.clear()
  154. waiter = self._waiter
  155. if waiter is not None:
  156. self._waiter = None
  157. set_exception(waiter, exc, exc_cause)
  158. waiter = self._eof_waiter
  159. if waiter is not None:
  160. self._eof_waiter = None
  161. set_exception(waiter, exc, exc_cause)
  162. def on_eof(self, callback: Callable[[], None]) -> None:
  163. if self._eof:
  164. try:
  165. callback()
  166. except Exception:
  167. internal_logger.exception("Exception in eof callback")
  168. else:
  169. self._eof_callbacks.append(callback)
  170. def feed_eof(self) -> None:
  171. self._eof = True
  172. waiter = self._waiter
  173. if waiter is not None:
  174. self._waiter = None
  175. set_result(waiter, None)
  176. waiter = self._eof_waiter
  177. if waiter is not None:
  178. self._eof_waiter = None
  179. set_result(waiter, None)
  180. if self._protocol._reading_paused:
  181. self._protocol.resume_reading()
  182. for cb in self._eof_callbacks:
  183. try:
  184. cb()
  185. except Exception:
  186. internal_logger.exception("Exception in eof callback")
  187. self._eof_callbacks.clear()
  188. def is_eof(self) -> bool:
  189. """Return True if 'feed_eof' was called."""
  190. return self._eof
  191. def at_eof(self) -> bool:
  192. """Return True if the buffer is empty and 'feed_eof' was called."""
  193. return self._eof and not self._buffer
  194. async def wait_eof(self) -> None:
  195. if self._eof:
  196. return
  197. assert self._eof_waiter is None
  198. self._eof_waiter = self._loop.create_future()
  199. try:
  200. await self._eof_waiter
  201. finally:
  202. self._eof_waiter = None
  203. def unread_data(self, data: bytes) -> None:
  204. """rollback reading some data from stream, inserting it to buffer head."""
  205. warnings.warn(
  206. "unread_data() is deprecated "
  207. "and will be removed in future releases (#3260)",
  208. DeprecationWarning,
  209. stacklevel=2,
  210. )
  211. if not data:
  212. return
  213. if self._buffer_offset:
  214. self._buffer[0] = self._buffer[0][self._buffer_offset :]
  215. self._buffer_offset = 0
  216. self._size += len(data)
  217. self._cursor -= len(data)
  218. self._buffer.appendleft(data)
  219. self._eof_counter = 0
  220. # TODO: size is ignored, remove the param later
  221. def feed_data(self, data: bytes, size: int = 0) -> None:
  222. assert not self._eof, "feed_data after feed_eof"
  223. if not data:
  224. return
  225. data_len = len(data)
  226. self._size += data_len
  227. self._buffer.append(data)
  228. self.total_bytes += data_len
  229. waiter = self._waiter
  230. if waiter is not None:
  231. self._waiter = None
  232. set_result(waiter, None)
  233. if self._size > self._high_water and not self._protocol._reading_paused:
  234. self._protocol.pause_reading()
  235. def begin_http_chunk_receiving(self) -> None:
  236. if self._http_chunk_splits is None:
  237. if self.total_bytes:
  238. raise RuntimeError(
  239. "Called begin_http_chunk_receiving when some data was already fed"
  240. )
  241. self._http_chunk_splits = []
  242. def end_http_chunk_receiving(self) -> None:
  243. if self._http_chunk_splits is None:
  244. raise RuntimeError(
  245. "Called end_chunk_receiving without calling "
  246. "begin_chunk_receiving first"
  247. )
  248. # self._http_chunk_splits contains logical byte offsets from start of
  249. # the body transfer. Each offset is the offset of the end of a chunk.
  250. # "Logical" means bytes, accessible for a user.
  251. # If no chunks containing logical data were received, current position
  252. # is difinitely zero.
  253. pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0
  254. if self.total_bytes == pos:
  255. # We should not add empty chunks here. So we check for that.
  256. # Note, when chunked + gzip is used, we can receive a chunk
  257. # of compressed data, but that data may not be enough for gzip FSM
  258. # to yield any uncompressed data. That's why current position may
  259. # not change after receiving a chunk.
  260. return
  261. self._http_chunk_splits.append(self.total_bytes)
  262. # wake up readchunk when end of http chunk received
  263. waiter = self._waiter
  264. if waiter is not None:
  265. self._waiter = None
  266. set_result(waiter, None)
  267. async def _wait(self, func_name: str) -> None:
  268. if not self._protocol.connected:
  269. raise RuntimeError("Connection closed.")
  270. # StreamReader uses a future to link the protocol feed_data() method
  271. # to a read coroutine. Running two read coroutines at the same time
  272. # would have an unexpected behaviour. It would not possible to know
  273. # which coroutine would get the next data.
  274. if self._waiter is not None:
  275. raise RuntimeError(
  276. "%s() called while another coroutine is "
  277. "already waiting for incoming data" % func_name
  278. )
  279. waiter = self._waiter = self._loop.create_future()
  280. try:
  281. with self._timer:
  282. await waiter
  283. finally:
  284. self._waiter = None
  285. async def readline(self) -> bytes:
  286. return await self.readuntil()
  287. async def readuntil(self, separator: bytes = b"\n") -> bytes:
  288. seplen = len(separator)
  289. if seplen == 0:
  290. raise ValueError("Separator should be at least one-byte string")
  291. if self._exception is not None:
  292. raise self._exception
  293. chunk = b""
  294. chunk_size = 0
  295. not_enough = True
  296. while not_enough:
  297. while self._buffer and not_enough:
  298. offset = self._buffer_offset
  299. ichar = self._buffer[0].find(separator, offset) + 1
  300. # Read from current offset to found separator or to the end.
  301. data = self._read_nowait_chunk(
  302. ichar - offset + seplen - 1 if ichar else -1
  303. )
  304. chunk += data
  305. chunk_size += len(data)
  306. if ichar:
  307. not_enough = False
  308. if chunk_size > self._high_water:
  309. raise ValueError("Chunk too big")
  310. if self._eof:
  311. break
  312. if not_enough:
  313. await self._wait("readuntil")
  314. return chunk
  315. async def read(self, n: int = -1) -> bytes:
  316. if self._exception is not None:
  317. raise self._exception
  318. # migration problem; with DataQueue you have to catch
  319. # EofStream exception, so common way is to run payload.read() inside
  320. # infinite loop. what can cause real infinite loop with StreamReader
  321. # lets keep this code one major release.
  322. if __debug__:
  323. if self._eof and not self._buffer:
  324. self._eof_counter = getattr(self, "_eof_counter", 0) + 1
  325. if self._eof_counter > 5:
  326. internal_logger.warning(
  327. "Multiple access to StreamReader in eof state, "
  328. "might be infinite loop.",
  329. stack_info=True,
  330. )
  331. if not n:
  332. return b""
  333. if n < 0:
  334. # This used to just loop creating a new waiter hoping to
  335. # collect everything in self._buffer, but that would
  336. # deadlock if the subprocess sends more than self.limit
  337. # bytes. So just call self.readany() until EOF.
  338. blocks = []
  339. while True:
  340. block = await self.readany()
  341. if not block:
  342. break
  343. blocks.append(block)
  344. return b"".join(blocks)
  345. # TODO: should be `if` instead of `while`
  346. # because waiter maybe triggered on chunk end,
  347. # without feeding any data
  348. while not self._buffer and not self._eof:
  349. await self._wait("read")
  350. return self._read_nowait(n)
  351. async def readany(self) -> bytes:
  352. if self._exception is not None:
  353. raise self._exception
  354. # TODO: should be `if` instead of `while`
  355. # because waiter maybe triggered on chunk end,
  356. # without feeding any data
  357. while not self._buffer and not self._eof:
  358. await self._wait("readany")
  359. return self._read_nowait(-1)
  360. async def readchunk(self) -> Tuple[bytes, bool]:
  361. """Returns a tuple of (data, end_of_http_chunk).
  362. When chunked transfer
  363. encoding is used, end_of_http_chunk is a boolean indicating if the end
  364. of the data corresponds to the end of a HTTP chunk , otherwise it is
  365. always False.
  366. """
  367. while True:
  368. if self._exception is not None:
  369. raise self._exception
  370. while self._http_chunk_splits:
  371. pos = self._http_chunk_splits.pop(0)
  372. if pos == self._cursor:
  373. return (b"", True)
  374. if pos > self._cursor:
  375. return (self._read_nowait(pos - self._cursor), True)
  376. internal_logger.warning(
  377. "Skipping HTTP chunk end due to data "
  378. "consumption beyond chunk boundary"
  379. )
  380. if self._buffer:
  381. return (self._read_nowait_chunk(-1), False)
  382. # return (self._read_nowait(-1), False)
  383. if self._eof:
  384. # Special case for signifying EOF.
  385. # (b'', True) is not a final return value actually.
  386. return (b"", False)
  387. await self._wait("readchunk")
  388. async def readexactly(self, n: int) -> bytes:
  389. if self._exception is not None:
  390. raise self._exception
  391. blocks: List[bytes] = []
  392. while n > 0:
  393. block = await self.read(n)
  394. if not block:
  395. partial = b"".join(blocks)
  396. raise asyncio.IncompleteReadError(partial, len(partial) + n)
  397. blocks.append(block)
  398. n -= len(block)
  399. return b"".join(blocks)
  400. def read_nowait(self, n: int = -1) -> bytes:
  401. # default was changed to be consistent with .read(-1)
  402. #
  403. # I believe the most users don't know about the method and
  404. # they are not affected.
  405. if self._exception is not None:
  406. raise self._exception
  407. if self._waiter and not self._waiter.done():
  408. raise RuntimeError(
  409. "Called while some coroutine is waiting for incoming data."
  410. )
  411. return self._read_nowait(n)
  412. def _read_nowait_chunk(self, n: int) -> bytes:
  413. first_buffer = self._buffer[0]
  414. offset = self._buffer_offset
  415. if n != -1 and len(first_buffer) - offset > n:
  416. data = first_buffer[offset : offset + n]
  417. self._buffer_offset += n
  418. elif offset:
  419. self._buffer.popleft()
  420. data = first_buffer[offset:]
  421. self._buffer_offset = 0
  422. else:
  423. data = self._buffer.popleft()
  424. data_len = len(data)
  425. self._size -= data_len
  426. self._cursor += data_len
  427. chunk_splits = self._http_chunk_splits
  428. # Prevent memory leak: drop useless chunk splits
  429. while chunk_splits and chunk_splits[0] < self._cursor:
  430. chunk_splits.pop(0)
  431. if self._size < self._low_water and self._protocol._reading_paused:
  432. self._protocol.resume_reading()
  433. return data
  434. def _read_nowait(self, n: int) -> bytes:
  435. """Read not more than n bytes, or whole buffer if n == -1"""
  436. self._timer.assert_timeout()
  437. chunks = []
  438. while self._buffer:
  439. chunk = self._read_nowait_chunk(n)
  440. chunks.append(chunk)
  441. if n != -1:
  442. n -= len(chunk)
  443. if n == 0:
  444. break
  445. return b"".join(chunks) if chunks else b""
  446. class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
  447. __slots__ = ("_read_eof_chunk",)
  448. def __init__(self) -> None:
  449. self._read_eof_chunk = False
  450. self.total_bytes = 0
  451. def __repr__(self) -> str:
  452. return "<%s>" % self.__class__.__name__
  453. def exception(self) -> Optional[BaseException]:
  454. return None
  455. def set_exception(
  456. self,
  457. exc: BaseException,
  458. exc_cause: BaseException = _EXC_SENTINEL,
  459. ) -> None:
  460. pass
  461. def on_eof(self, callback: Callable[[], None]) -> None:
  462. try:
  463. callback()
  464. except Exception:
  465. internal_logger.exception("Exception in eof callback")
  466. def feed_eof(self) -> None:
  467. pass
  468. def is_eof(self) -> bool:
  469. return True
  470. def at_eof(self) -> bool:
  471. return True
  472. async def wait_eof(self) -> None:
  473. return
  474. def feed_data(self, data: bytes, n: int = 0) -> None:
  475. pass
  476. async def readline(self) -> bytes:
  477. return b""
  478. async def read(self, n: int = -1) -> bytes:
  479. return b""
  480. # TODO add async def readuntil
  481. async def readany(self) -> bytes:
  482. return b""
  483. async def readchunk(self) -> Tuple[bytes, bool]:
  484. if not self._read_eof_chunk:
  485. self._read_eof_chunk = True
  486. return (b"", False)
  487. return (b"", True)
  488. async def readexactly(self, n: int) -> bytes:
  489. raise asyncio.IncompleteReadError(b"", n)
  490. def read_nowait(self, n: int = -1) -> bytes:
  491. return b""
  492. EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()
  493. class DataQueue(Generic[_T]):
  494. """DataQueue is a general-purpose blocking queue with one reader."""
  495. def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
  496. self._loop = loop
  497. self._eof = False
  498. self._waiter: Optional[asyncio.Future[None]] = None
  499. self._exception: Optional[BaseException] = None
  500. self._buffer: Deque[Tuple[_T, int]] = collections.deque()
  501. def __len__(self) -> int:
  502. return len(self._buffer)
  503. def is_eof(self) -> bool:
  504. return self._eof
  505. def at_eof(self) -> bool:
  506. return self._eof and not self._buffer
  507. def exception(self) -> Optional[BaseException]:
  508. return self._exception
  509. def set_exception(
  510. self,
  511. exc: BaseException,
  512. exc_cause: BaseException = _EXC_SENTINEL,
  513. ) -> None:
  514. self._eof = True
  515. self._exception = exc
  516. if (waiter := self._waiter) is not None:
  517. self._waiter = None
  518. set_exception(waiter, exc, exc_cause)
  519. def feed_data(self, data: _T, size: int = 0) -> None:
  520. self._buffer.append((data, size))
  521. if (waiter := self._waiter) is not None:
  522. self._waiter = None
  523. set_result(waiter, None)
  524. def feed_eof(self) -> None:
  525. self._eof = True
  526. if (waiter := self._waiter) is not None:
  527. self._waiter = None
  528. set_result(waiter, None)
  529. async def read(self) -> _T:
  530. if not self._buffer and not self._eof:
  531. assert not self._waiter
  532. self._waiter = self._loop.create_future()
  533. try:
  534. await self._waiter
  535. except (asyncio.CancelledError, asyncio.TimeoutError):
  536. self._waiter = None
  537. raise
  538. if self._buffer:
  539. data, _ = self._buffer.popleft()
  540. return data
  541. if self._exception is not None:
  542. raise self._exception
  543. raise EofStream
  544. def __aiter__(self) -> AsyncStreamIterator[_T]:
  545. return AsyncStreamIterator(self.read)
  546. class FlowControlDataQueue(DataQueue[_T]):
  547. """FlowControlDataQueue resumes and pauses an underlying stream.
  548. It is a destination for parsed data.
  549. This class is deprecated and will be removed in version 4.0.
  550. """
  551. def __init__(
  552. self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
  553. ) -> None:
  554. super().__init__(loop=loop)
  555. self._size = 0
  556. self._protocol = protocol
  557. self._limit = limit * 2
  558. def feed_data(self, data: _T, size: int = 0) -> None:
  559. super().feed_data(data, size)
  560. self._size += size
  561. if self._size > self._limit and not self._protocol._reading_paused:
  562. self._protocol.pause_reading()
  563. async def read(self) -> _T:
  564. if not self._buffer and not self._eof:
  565. assert not self._waiter
  566. self._waiter = self._loop.create_future()
  567. try:
  568. await self._waiter
  569. except (asyncio.CancelledError, asyncio.TimeoutError):
  570. self._waiter = None
  571. raise
  572. if self._buffer:
  573. data, size = self._buffer.popleft()
  574. self._size -= size
  575. if self._size < self._limit and self._protocol._reading_paused:
  576. self._protocol.resume_reading()
  577. return data
  578. if self._exception is not None:
  579. raise self._exception
  580. raise EofStream