123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468 |
- """Reader for WebSocket protocol versions 13 and 8."""
- import asyncio
- import builtins
- from collections import deque
- from typing import Deque, Final, List, Optional, Set, Tuple, Union
- from ..base_protocol import BaseProtocol
- from ..compression_utils import ZLibDecompressor
- from ..helpers import _EXC_SENTINEL, set_exception
- from ..streams import EofStream
- from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
- from .models import (
- WS_DEFLATE_TRAILING,
- WebSocketError,
- WSCloseCode,
- WSMessage,
- WSMsgType,
- )
# Close codes defined by WSCloseCode, as plain ints. When a close frame is
# received, a close code below 3000 is rejected unless it is in this set.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

# Message types attached to the TEXT / BINARY messages pushed to the queue.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can be cythonized to ints
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Pre-built ``(error, data)`` results returned by WebSocketReader.feed_data:
# the first item is True when parsing failed; the data part is empty.
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Alias of tuple.__new__, used to build WSMessage instances directly from a
# tuple of field values (assumes WSMessage is a tuple subclass such as a
# NamedTuple - confirm in .models).
TUPLE_NEW = tuple.__new__

int_ = int  # Prevent Cython from converting to PyInt
class WebSocketDataQueue:
    """WebSocketDataQueue resumes and pauses an underlying stream.

    It is a destination for WebSocket data: parsed messages are buffered
    together with their size. Once the buffered size grows past twice the
    ``limit`` passed to the constructor, reading on ``protocol`` is paused.
    Reading is resumed when consumers bring the size back below that
    threshold.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._loop = loop
        self._protocol = protocol
        # Pause/resume threshold: twice the requested limit.
        self._limit = limit * 2
        # Sum of the sizes of all buffered messages.
        self._size = 0
        self._eof = False
        self._exception: Union[BaseException, None] = None
        # Future that a pending read() is waiting on, if any.
        self._waiter: Optional[asyncio.Future[None]] = None
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        # Bound deque methods cached on the instance (presumably to save
        # attribute lookups on the hot path).
        self._put_buffer = self._buffer.append
        self._get_buffer = self._buffer.popleft

    def is_eof(self) -> bool:
        """Return True once EOF or an exception has been signalled."""
        return self._eof

    def exception(self) -> Optional[BaseException]:
        """Return the exception stored by set_exception(), if still set."""
        return self._exception

    def set_exception(
        self,
        exc: "BaseException",
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Finish the queue with ``exc`` and fail a pending reader.

        A reader currently blocked in read() gets ``exc`` (with
        ``exc_cause`` forwarded to the helper). Later reads raise it once the
        buffer is drained, unless feed_eof() has cleared it in the meantime.
        """
        self._exception = exc
        self._eof = True
        pending = self._waiter
        if pending is None:
            return
        self._waiter = None
        set_exception(pending, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Resolve the pending read() future, if there is one."""
        pending = self._waiter
        if pending is not None:
            self._waiter = None
            if not pending.done():
                pending.set_result(None)

    def feed_eof(self) -> None:
        """Signal end of stream and wake a pending reader."""
        self._eof = True
        self._release_waiter()
        # Drop the stored exception to break cyclic references; reads on an
        # empty buffer will then raise EofStream instead.
        self._exception = None

    def feed_data(self, data: "WSMessage", size: "int_") -> None:
        """Buffer ``data`` (accounted as ``size`` bytes) and wake a reader.

        Pauses reading on the protocol once the buffered size exceeds the
        threshold and reading is not already paused.
        """
        self._put_buffer((data, size))
        self._size += size
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting until one is available.

        When the buffer is empty and the queue is finished, the stored
        exception (or EofStream) is raised instead.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()

        assert not self._waiter
        pending = self._loop.create_future()
        self._waiter = pending
        try:
            await pending
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # Forget the future so a later read() can install a new one.
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message, resuming reading when below the threshold.

        Raises the stored exception, or EofStream, if the buffer is empty.
        """
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream

        data, size = self._get_buffer()
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data
class WebSocketReader:
    """Incremental WebSocket frame parser (RFC 6455).

    Raw bytes are passed to :meth:`feed_data`. Frames are decoded, fragmented
    text/binary messages are reassembled (and inflated when per-message
    compression was used), and the resulting ``WSMessage`` tuples are pushed
    onto ``queue``. After the first parsing error the error is reported
    through the queue and all further input is ignored.
    """

    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        # Maximum assembled message size in bytes; 0 disables the check.
        self._max_msg_size = max_msg_size

        # First exception raised while parsing; once set, input is ignored.
        self._exc: Optional[Exception] = None
        # Payload collected from the non-final fragments of a data message.
        self._partial = bytearray()
        # Frame parser state, one of the READ_* constants.
        self._state = READ_HEADER

        # Opcode of the fragmented data message in progress, or None when no
        # fragmented message has been started.
        self._opcode: Optional[int] = None
        # Header fields and payload of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload: Union[bytes, bytearray] = b""
        self._frame_payload_len = 0

        # Unconsumed bytes (incomplete header, extended length or mask) that
        # are prepended to the next parse_frame() input.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        # Payload bytes still to be read for the current frame.
        self._payload_length = 0
        # 7-bit length field from the header (126/127 = extended length).
        self._payload_length_flag = 0
        # Per-message compression flag (RSV1 on the first frame of the
        # current data message); None until the first frame header.
        self._compressed: Optional[bool] = None
        self._decompressobj: Optional[ZLibDecompressor] = None
        # Whether RSV1 (per-message compression) is permitted at all.
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end of stream to the queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        """Feed raw bytes received from the transport.

        Returns:
            ``(True, data)`` without parsing if a previous call already
            failed; ``EMPTY_FRAME_ERROR`` if parsing fails now (the exception
            is stored and set on the queue); ``EMPTY_FRAME`` otherwise.
        """
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _feed_data(self, data: bytes) -> None:
        """Parse ``data`` and push every completed message onto the queue.

        Non-final text/binary fragments are accumulated in ``self._partial``
        until the final fragment arrives. Close/ping/pong frames are pushed
        immediately, including when they are interleaved between fragments.

        Raises:
            WebSocketError: on a protocol violation, on a message larger than
                ``max_msg_size``, or on invalid UTF-8 in a text/close payload.
        """
        msg: WSMessage
        for frame in self.parse_frame(data):
            fin = frame[0]
            opcode = frame[1]
            payload = frame[2]
            compressed = frame[3]

            is_continuation = opcode == OP_CODE_CONTINUATION
            if opcode == OP_CODE_TEXT or opcode == OP_CODE_BINARY or is_continuation:
                # load text/binary
                # Validate fragmentation before touching any state (RFC 6455
                # section 5.4). A continuation needs a message in progress,
                # and a new text/binary frame is not allowed while one is in
                # progress. ``self._opcode`` (rather than ``self._partial``)
                # tracks this, so an empty first fragment is handled too.
                if is_continuation:
                    if self._opcode is None:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Continuation frame for non started message",
                        )
                elif self._opcode is not None:
                    # previous frame was non finished
                    # we should get continuation opcode
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "The opcode in non-fin frame is expected "
                        f"to be zero, got {opcode!r}",
                    )

                if not fin:
                    # got partial frame payload
                    if not is_continuation:
                        # First fragment: remember the message type.
                        self._opcode = opcode
                    self._partial += payload
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            f"Message size {len(self._partial)} "
                            f"exceeds limit {self._max_msg_size}",
                        )
                    continue

                has_partial = bool(self._partial)
                if is_continuation:
                    # Final fragment: the message type is that of the first
                    # fragment, and the fragmented message is now complete.
                    opcode = self._opcode
                    self._opcode = None

                assembled_payload: Union[bytes, bytearray]
                if has_partial:
                    assembled_payload = self._partial + payload
                    self._partial.clear()
                else:
                    assembled_payload = payload

                if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(assembled_payload)} "
                        f"exceeds limit {self._max_msg_size}",
                    )

                # Decompress process must to be done after all packets
                # received.
                if compressed:
                    if not self._decompressobj:
                        self._decompressobj = ZLibDecompressor(
                            suppress_deflate_header=True
                        )
                    # max_length caps the inflated size; leftover input in
                    # unconsumed_tail means the message is over the limit.
                    payload_merged = self._decompressobj.decompress_sync(
                        assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
                    )
                    if self._decompressobj.unconsumed_tail:
                        left = len(self._decompressobj.unconsumed_tail)
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            f"Decompressed message size {self._max_msg_size + left}"
                            f" exceeds limit {self._max_msg_size}",
                        )
                elif type(assembled_payload) is bytes:
                    payload_merged = assembled_payload
                else:
                    payload_merged = bytes(assembled_payload)

                if opcode == OP_CODE_TEXT:
                    try:
                        text = payload_merged.decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc

                    # XXX: The Text and Binary messages here can be a performance
                    # bottleneck, so we use tuple.__new__ to improve performance.
                    # This is not type safe, but many tests should fail in
                    # test_client_ws_functional.py if this is wrong.
                    self.queue.feed_data(
                        TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
                        len(payload_merged),
                    )
                else:
                    self.queue.feed_data(
                        TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
                        len(payload_merged),
                    )
            elif opcode == OP_CODE_CLOSE:
                if len(payload) >= 2:
                    # Close payload: 2-byte close code followed by a UTF-8
                    # reason.
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            f"Invalid close code: {close_code}",
                        )
                    try:
                        close_message = payload[2:].decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc
                    msg = TUPLE_NEW(
                        WSMessage, (WSMsgType.CLOSE, close_code, close_message)
                    )
                elif payload:
                    # A 1-byte payload cannot hold a close code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close frame: {fin} {opcode} {payload!r}",
                    )
                else:
                    msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))

                self.queue.feed_data(msg, 0)
            elif opcode == OP_CODE_PING:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
                self.queue.feed_data(msg, len(payload))
            elif opcode == OP_CODE_PONG:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
                self.queue.feed_data(msg, len(payload))
            else:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]:
        """Parse every frame that can be completed with ``buf``.

        ``buf`` is appended to any bytes left over from the previous call.
        Parsing is resumable: the state machine (``self._state``) and the
        partially read payload live on the instance, so a frame may be split
        across any number of calls. Bytes that do not yet form a complete
        header, extended length or mask are kept in ``self._tail``.

        Returns:
            One ``(fin, opcode, payload, compressed)`` tuple per frame
            completed during this call; masked payloads are unmasked in place
            by ``websocket_mask``.

        Raises:
            WebSocketError: on reserved bits that are not allowed, or on a
                fragmented or oversized control frame.
        """
        frames: List[
            Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]
        ] = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        start_pos: int = 0
        buf_length = len(buf)
        buf_cstr = buf

        while True:
            # read header
            if self._state == READ_HEADER:
                if buf_length - start_pos < 2:
                    break
                first_byte = buf_cstr[start_pos]
                second_byte = buf_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                #
                # RSV1 is permitted only when compression is enabled.
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Per-message compression (RFC 7692 section 6.1): RSV1 may be
                # set only on the first frame of a data message, and the whole
                # message inherits it. Control frames and continuation frames
                # must not set RSV1. They also must not reset the state,
                # because control frames may be interleaved between the
                # fragments of a compressed message (RFC 6455 section 5.4).
                if opcode > 0x7 or opcode == OP_CODE_CONTINUATION:
                    if rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )
                    if self._compressed is None:
                        self._compressed = False
                # Set compress status if last package is FIN
                # OR set compress status if this is first fragment
                # Raise error if not first fragment with rsv1 = 0x1
                elif self._frame_fin or self._compressed is None:
                    self._compressed = True if rsv1 else False
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_length_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                length_flag = self._payload_length_flag
                if length_flag == 126:
                    # 16-bit extended payload length, big-endian.
                    if buf_length - start_pos < 2:
                        break
                    first_byte = buf_cstr[start_pos]
                    second_byte = buf_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_length = first_byte << 8 | second_byte
                elif length_flag > 126:
                    # 8-byte extended payload length.
                    if buf_length - start_pos < 8:
                        break
                    data = buf_cstr[start_pos : start_pos + 8]
                    start_pos += 8
                    self._payload_length = UNPACK_LEN3(data)[0]
                else:
                    self._payload_length = length_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if buf_length - start_pos < 4:
                    break
                self._frame_mask = buf_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                # Consume as much of the remaining payload as this buffer has.
                chunk_len = buf_length - start_pos
                if self._payload_length >= chunk_len:
                    end_pos = buf_length
                    self._payload_length -= chunk_len
                else:
                    end_pos = start_pos + self._payload_length
                    self._payload_length = 0

                if self._frame_payload_len:
                    if type(self._frame_payload) is not bytearray:
                        self._frame_payload = bytearray(self._frame_payload)
                    self._frame_payload += buf_cstr[start_pos:end_pos]
                else:
                    # Fast path for the first chunk of this frame's payload
                    self._frame_payload = buf_cstr[start_pos:end_pos]

                self._frame_payload_len += end_pos - start_pos
                start_pos = end_pos

                if self._payload_length != 0:
                    # Payload continues in a later buffer.
                    break

                if self._has_mask:
                    assert self._frame_mask is not None
                    if type(self._frame_payload) is not bytearray:
                        self._frame_payload = bytearray(self._frame_payload)
                    websocket_mask(self._frame_mask, self._frame_payload)

                frames.append(
                    (
                        self._frame_fin,
                        self._frame_opcode,
                        self._frame_payload,
                        self._compressed,
                    )
                )
                self._frame_payload = b""
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = buf_cstr[start_pos:buf_length] if start_pos < buf_length else b""

        return frames
|