writer.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """WebSocket protocol versions 13 and 8."""
  2. import asyncio
  3. import random
  4. import zlib
  5. from functools import partial
  6. from typing import Any, Final, Optional, Union
  7. from ..base_protocol import BaseProtocol
  8. from ..client_exceptions import ClientConnectionResetError
  9. from ..compression_utils import ZLibCompressor
  10. from .helpers import (
  11. MASK_LEN,
  12. MSG_SIZE,
  13. PACK_CLOSE_CODE,
  14. PACK_LEN1,
  15. PACK_LEN2,
  16. PACK_LEN3,
  17. PACK_RANDBITS,
  18. websocket_mask,
  19. )
  20. from .models import WS_DEFLATE_TRAILING, WSMsgType
  21. DEFAULT_LIMIT: Final[int] = 2**16
  22. # For websockets, keeping latency low is extremely important as implementations
  23. # generally expect to be able to send and receive messages quickly. We use a
  24. # larger chunk size than the default to reduce the number of executor calls
  25. # since the executor is a significant source of latency and overhead when
  26. # the chunks are small. A size of 5KiB was chosen because it is also the
  27. # same value python-zlib-ng choose to use as the threshold to release the GIL.
  28. WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
  29. class WebSocketWriter:
  30. """WebSocket writer.
  31. The writer is responsible for sending messages to the client. It is
  32. created by the protocol when a connection is established. The writer
  33. should avoid implementing any application logic and should only be
  34. concerned with the low-level details of the WebSocket protocol.
  35. """
  36. def __init__(
  37. self,
  38. protocol: BaseProtocol,
  39. transport: asyncio.Transport,
  40. *,
  41. use_mask: bool = False,
  42. limit: int = DEFAULT_LIMIT,
  43. random: random.Random = random.Random(),
  44. compress: int = 0,
  45. notakeover: bool = False,
  46. ) -> None:
  47. """Initialize a WebSocket writer."""
  48. self.protocol = protocol
  49. self.transport = transport
  50. self.use_mask = use_mask
  51. self.get_random_bits = partial(random.getrandbits, 32)
  52. self.compress = compress
  53. self.notakeover = notakeover
  54. self._closing = False
  55. self._limit = limit
  56. self._output_size = 0
  57. self._compressobj: Any = None # actually compressobj
  58. async def send_frame(
  59. self, message: bytes, opcode: int, compress: Optional[int] = None
  60. ) -> None:
  61. """Send a frame over the websocket with message as its payload."""
  62. if self._closing and not (opcode & WSMsgType.CLOSE):
  63. raise ClientConnectionResetError("Cannot write to closing transport")
  64. # RSV are the reserved bits in the frame header. They are used to
  65. # indicate that the frame is using an extension.
  66. # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
  67. rsv = 0
  68. # Only compress larger packets (disabled)
  69. # Does small packet needs to be compressed?
  70. # if self.compress and opcode < 8 and len(message) > 124:
  71. if (compress or self.compress) and opcode < 8:
  72. # RSV1 (rsv = 0x40) is set for compressed frames
  73. # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
  74. rsv = 0x40
  75. if compress:
  76. # Do not set self._compress if compressing is for this frame
  77. compressobj = self._make_compress_obj(compress)
  78. else: # self.compress
  79. if not self._compressobj:
  80. self._compressobj = self._make_compress_obj(self.compress)
  81. compressobj = self._compressobj
  82. message = (
  83. await compressobj.compress(message)
  84. + compressobj.flush(
  85. zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
  86. )
  87. ).removesuffix(WS_DEFLATE_TRAILING)
  88. # Its critical that we do not return control to the event
  89. # loop until we have finished sending all the compressed
  90. # data. Otherwise we could end up mixing compressed frames
  91. # if there are multiple coroutines compressing data.
  92. msg_length = len(message)
  93. use_mask = self.use_mask
  94. mask_bit = 0x80 if use_mask else 0
  95. # Depending on the message length, the header is assembled differently.
  96. # The first byte is reserved for the opcode and the RSV bits.
  97. first_byte = 0x80 | rsv | opcode
  98. if msg_length < 126:
  99. header = PACK_LEN1(first_byte, msg_length | mask_bit)
  100. header_len = 2
  101. elif msg_length < 65536:
  102. header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
  103. header_len = 4
  104. else:
  105. header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
  106. header_len = 10
  107. if self.transport.is_closing():
  108. raise ClientConnectionResetError("Cannot write to closing transport")
  109. # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
  110. # If we are using a mask, we need to generate it randomly
  111. # and apply it to the message before sending it. A mask is
  112. # a 32-bit value that is applied to the message using a
  113. # bitwise XOR operation. It is used to prevent certain types
  114. # of attacks on the websocket protocol. The mask is only used
  115. # when aiohttp is acting as a client. Servers do not use a mask.
  116. if use_mask:
  117. mask = PACK_RANDBITS(self.get_random_bits())
  118. message = bytearray(message)
  119. websocket_mask(mask, message)
  120. self.transport.write(header + mask + message)
  121. self._output_size += MASK_LEN
  122. elif msg_length > MSG_SIZE:
  123. self.transport.write(header)
  124. self.transport.write(message)
  125. else:
  126. self.transport.write(header + message)
  127. self._output_size += header_len + msg_length
  128. # It is safe to return control to the event loop when using compression
  129. # after this point as we have already sent or buffered all the data.
  130. # Once we have written output_size up to the limit, we call the
  131. # drain helper which waits for the transport to be ready to accept
  132. # more data. This is a flow control mechanism to prevent the buffer
  133. # from growing too large. The drain helper will return right away
  134. # if the writer is not paused.
  135. if self._output_size > self._limit:
  136. self._output_size = 0
  137. if self.protocol._paused:
  138. await self.protocol._drain_helper()
  139. def _make_compress_obj(self, compress: int) -> ZLibCompressor:
  140. return ZLibCompressor(
  141. level=zlib.Z_BEST_SPEED,
  142. wbits=-compress,
  143. max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
  144. )
  145. async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
  146. """Close the websocket, sending the specified code and message."""
  147. if isinstance(message, str):
  148. message = message.encode("utf-8")
  149. try:
  150. await self.send_frame(
  151. PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
  152. )
  153. finally:
  154. self._closing = True