http_writer.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
"""HTTP stream writer and header serialization helpers."""
  2. import asyncio
  3. import sys
  4. import zlib
  5. from typing import ( # noqa
  6. Any,
  7. Awaitable,
  8. Callable,
  9. Iterable,
  10. List,
  11. NamedTuple,
  12. Optional,
  13. Union,
  14. )
  15. from multidict import CIMultiDict
  16. from .abc import AbstractStreamWriter
  17. from .base_protocol import BaseProtocol
  18. from .client_exceptions import ClientConnectionResetError
  19. from .compression_utils import ZLibCompressor
  20. from .helpers import NO_EXTENSIONS
# Public names exported by a star-import of this module.
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
  22. MIN_PAYLOAD_FOR_WRITELINES = 2048
  23. IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
  24. IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
  25. SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
  26. # writelines is not safe for use
  27. # on Python 3.12+ until 3.12.9
  28. # on Python 3.13+ until 3.13.2
  29. # and on older versions it not any faster than write
  30. # CVE-2024-12254: https://github.com/python/cpython/pull/127656
  31. class HttpVersion(NamedTuple):
  32. major: int
  33. minor: int
  34. HttpVersion10 = HttpVersion(1, 0)
  35. HttpVersion11 = HttpVersion(1, 1)
  36. _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
  37. _T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
  38. class StreamWriter(AbstractStreamWriter):
  39. length: Optional[int] = None
  40. chunked: bool = False
  41. _eof: bool = False
  42. _compress: Optional[ZLibCompressor] = None
  43. def __init__(
  44. self,
  45. protocol: BaseProtocol,
  46. loop: asyncio.AbstractEventLoop,
  47. on_chunk_sent: _T_OnChunkSent = None,
  48. on_headers_sent: _T_OnHeadersSent = None,
  49. ) -> None:
  50. self._protocol = protocol
  51. self.loop = loop
  52. self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
  53. self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
  54. @property
  55. def transport(self) -> Optional[asyncio.Transport]:
  56. return self._protocol.transport
  57. @property
  58. def protocol(self) -> BaseProtocol:
  59. return self._protocol
  60. def enable_chunking(self) -> None:
  61. self.chunked = True
  62. def enable_compression(
  63. self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
  64. ) -> None:
  65. self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
  66. def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
  67. size = len(chunk)
  68. self.buffer_size += size
  69. self.output_size += size
  70. transport = self._protocol.transport
  71. if transport is None or transport.is_closing():
  72. raise ClientConnectionResetError("Cannot write to closing transport")
  73. transport.write(chunk)
  74. def _writelines(self, chunks: Iterable[bytes]) -> None:
  75. size = 0
  76. for chunk in chunks:
  77. size += len(chunk)
  78. self.buffer_size += size
  79. self.output_size += size
  80. transport = self._protocol.transport
  81. if transport is None or transport.is_closing():
  82. raise ClientConnectionResetError("Cannot write to closing transport")
  83. if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
  84. transport.write(b"".join(chunks))
  85. else:
  86. transport.writelines(chunks)
  87. async def write(
  88. self,
  89. chunk: Union[bytes, bytearray, memoryview],
  90. *,
  91. drain: bool = True,
  92. LIMIT: int = 0x10000,
  93. ) -> None:
  94. """Writes chunk of data to a stream.
  95. write_eof() indicates end of stream.
  96. writer can't be used after write_eof() method being called.
  97. write() return drain future.
  98. """
  99. if self._on_chunk_sent is not None:
  100. await self._on_chunk_sent(chunk)
  101. if isinstance(chunk, memoryview):
  102. if chunk.nbytes != len(chunk):
  103. # just reshape it
  104. chunk = chunk.cast("c")
  105. if self._compress is not None:
  106. chunk = await self._compress.compress(chunk)
  107. if not chunk:
  108. return
  109. if self.length is not None:
  110. chunk_len = len(chunk)
  111. if self.length >= chunk_len:
  112. self.length = self.length - chunk_len
  113. else:
  114. chunk = chunk[: self.length]
  115. self.length = 0
  116. if not chunk:
  117. return
  118. if chunk:
  119. if self.chunked:
  120. self._writelines(
  121. (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
  122. )
  123. else:
  124. self._write(chunk)
  125. if self.buffer_size > LIMIT and drain:
  126. self.buffer_size = 0
  127. await self.drain()
  128. async def write_headers(
  129. self, status_line: str, headers: "CIMultiDict[str]"
  130. ) -> None:
  131. """Write request/response status and headers."""
  132. if self._on_headers_sent is not None:
  133. await self._on_headers_sent(headers)
  134. # status + headers
  135. buf = _serialize_headers(status_line, headers)
  136. self._write(buf)
  137. def set_eof(self) -> None:
  138. """Indicate that the message is complete."""
  139. self._eof = True
  140. async def write_eof(self, chunk: bytes = b"") -> None:
  141. if self._eof:
  142. return
  143. if chunk and self._on_chunk_sent is not None:
  144. await self._on_chunk_sent(chunk)
  145. if self._compress:
  146. chunks: List[bytes] = []
  147. chunks_len = 0
  148. if chunk and (compressed_chunk := await self._compress.compress(chunk)):
  149. chunks_len = len(compressed_chunk)
  150. chunks.append(compressed_chunk)
  151. flush_chunk = self._compress.flush()
  152. chunks_len += len(flush_chunk)
  153. chunks.append(flush_chunk)
  154. assert chunks_len
  155. if self.chunked:
  156. chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
  157. self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
  158. elif len(chunks) > 1:
  159. self._writelines(chunks)
  160. else:
  161. self._write(chunks[0])
  162. elif self.chunked:
  163. if chunk:
  164. chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
  165. self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
  166. else:
  167. self._write(b"0\r\n\r\n")
  168. elif chunk:
  169. self._write(chunk)
  170. await self.drain()
  171. self._eof = True
  172. async def drain(self) -> None:
  173. """Flush the write buffer.
  174. The intended use is to write
  175. await w.write(data)
  176. await w.drain()
  177. """
  178. protocol = self._protocol
  179. if protocol.transport is not None and protocol._paused:
  180. await protocol._drain_helper()
  181. def _safe_header(string: str) -> str:
  182. if "\r" in string or "\n" in string:
  183. raise ValueError(
  184. "Newline or carriage return detected in headers. "
  185. "Potential header injection attack."
  186. )
  187. return string
  188. def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
  189. headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
  190. line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
  191. return line.encode("utf-8")
  192. _serialize_headers = _py_serialize_headers
  193. try:
  194. import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
  195. _c_serialize_headers = _http_writer._serialize_headers
  196. if not NO_EXTENSIONS:
  197. _serialize_headers = _c_serialize_headers
  198. except ImportError:
  199. pass