reader_c.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. """Reader for WebSocket protocol versions 13 and 8."""
  2. import asyncio
  3. import builtins
  4. from collections import deque
  5. from typing import Deque, Final, List, Optional, Set, Tuple, Union
  6. from ..base_protocol import BaseProtocol
  7. from ..compression_utils import ZLibDecompressor
  8. from ..helpers import _EXC_SENTINEL, set_exception
  9. from ..streams import EofStream
  10. from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
  11. from .models import (
  12. WS_DEFLATE_TRAILING,
  13. WebSocketError,
  14. WSCloseCode,
  15. WSMessage,
  16. WSMsgType,
  17. )
  18. ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
  19. # States for the reader, used to parse the WebSocket frame
  20. # integer values are used so they can be cythonized
  21. READ_HEADER = 1
  22. READ_PAYLOAD_LENGTH = 2
  23. READ_PAYLOAD_MASK = 3
  24. READ_PAYLOAD = 4
  25. WS_MSG_TYPE_BINARY = WSMsgType.BINARY
  26. WS_MSG_TYPE_TEXT = WSMsgType.TEXT
  27. # WSMsgType values unpacked so they can by cythonized to ints
  28. OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
  29. OP_CODE_TEXT = WSMsgType.TEXT.value
  30. OP_CODE_BINARY = WSMsgType.BINARY.value
  31. OP_CODE_CLOSE = WSMsgType.CLOSE.value
  32. OP_CODE_PING = WSMsgType.PING.value
  33. OP_CODE_PONG = WSMsgType.PONG.value
  34. EMPTY_FRAME_ERROR = (True, b"")
  35. EMPTY_FRAME = (False, b"")
  36. TUPLE_NEW = tuple.__new__
  37. int_ = int # Prevent Cython from converting to PyInt
  38. class WebSocketDataQueue:
  39. """WebSocketDataQueue resumes and pauses an underlying stream.
  40. It is a destination for WebSocket data.
  41. """
  42. def __init__(
  43. self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
  44. ) -> None:
  45. self._size = 0
  46. self._protocol = protocol
  47. self._limit = limit * 2
  48. self._loop = loop
  49. self._eof = False
  50. self._waiter: Optional[asyncio.Future[None]] = None
  51. self._exception: Union[BaseException, None] = None
  52. self._buffer: Deque[Tuple[WSMessage, int]] = deque()
  53. self._get_buffer = self._buffer.popleft
  54. self._put_buffer = self._buffer.append
  55. def is_eof(self) -> bool:
  56. return self._eof
  57. def exception(self) -> Optional[BaseException]:
  58. return self._exception
  59. def set_exception(
  60. self,
  61. exc: "BaseException",
  62. exc_cause: builtins.BaseException = _EXC_SENTINEL,
  63. ) -> None:
  64. self._eof = True
  65. self._exception = exc
  66. if (waiter := self._waiter) is not None:
  67. self._waiter = None
  68. set_exception(waiter, exc, exc_cause)
  69. def _release_waiter(self) -> None:
  70. if (waiter := self._waiter) is None:
  71. return
  72. self._waiter = None
  73. if not waiter.done():
  74. waiter.set_result(None)
  75. def feed_eof(self) -> None:
  76. self._eof = True
  77. self._release_waiter()
  78. self._exception = None # Break cyclic references
  79. def feed_data(self, data: "WSMessage", size: "int_") -> None:
  80. self._size += size
  81. self._put_buffer((data, size))
  82. self._release_waiter()
  83. if self._size > self._limit and not self._protocol._reading_paused:
  84. self._protocol.pause_reading()
  85. async def read(self) -> WSMessage:
  86. if not self._buffer and not self._eof:
  87. assert not self._waiter
  88. self._waiter = self._loop.create_future()
  89. try:
  90. await self._waiter
  91. except (asyncio.CancelledError, asyncio.TimeoutError):
  92. self._waiter = None
  93. raise
  94. return self._read_from_buffer()
  95. def _read_from_buffer(self) -> WSMessage:
  96. if self._buffer:
  97. data, size = self._get_buffer()
  98. self._size -= size
  99. if self._size < self._limit and self._protocol._reading_paused:
  100. self._protocol.resume_reading()
  101. return data
  102. if self._exception is not None:
  103. raise self._exception
  104. raise EofStream
  105. class WebSocketReader:
  106. def __init__(
  107. self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
  108. ) -> None:
  109. self.queue = queue
  110. self._max_msg_size = max_msg_size
  111. self._exc: Optional[Exception] = None
  112. self._partial = bytearray()
  113. self._state = READ_HEADER
  114. self._opcode: Optional[int] = None
  115. self._frame_fin = False
  116. self._frame_opcode: Optional[int] = None
  117. self._frame_payload: Union[bytes, bytearray] = b""
  118. self._frame_payload_len = 0
  119. self._tail: bytes = b""
  120. self._has_mask = False
  121. self._frame_mask: Optional[bytes] = None
  122. self._payload_length = 0
  123. self._payload_length_flag = 0
  124. self._compressed: Optional[bool] = None
  125. self._decompressobj: Optional[ZLibDecompressor] = None
  126. self._compress = compress
  127. def feed_eof(self) -> None:
  128. self.queue.feed_eof()
  129. # data can be bytearray on Windows because proactor event loop uses bytearray
  130. # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
  131. # coerce data to bytes if it is not
  132. def feed_data(
  133. self, data: Union[bytes, bytearray, memoryview]
  134. ) -> Tuple[bool, bytes]:
  135. if type(data) is not bytes:
  136. data = bytes(data)
  137. if self._exc is not None:
  138. return True, data
  139. try:
  140. self._feed_data(data)
  141. except Exception as exc:
  142. self._exc = exc
  143. set_exception(self.queue, exc)
  144. return EMPTY_FRAME_ERROR
  145. return EMPTY_FRAME
  146. def _feed_data(self, data: bytes) -> None:
  147. msg: WSMessage
  148. for frame in self.parse_frame(data):
  149. fin = frame[0]
  150. opcode = frame[1]
  151. payload = frame[2]
  152. compressed = frame[3]
  153. is_continuation = opcode == OP_CODE_CONTINUATION
  154. if opcode == OP_CODE_TEXT or opcode == OP_CODE_BINARY or is_continuation:
  155. # load text/binary
  156. if not fin:
  157. # got partial frame payload
  158. if not is_continuation:
  159. self._opcode = opcode
  160. self._partial += payload
  161. if self._max_msg_size and len(self._partial) >= self._max_msg_size:
  162. raise WebSocketError(
  163. WSCloseCode.MESSAGE_TOO_BIG,
  164. f"Message size {len(self._partial)} "
  165. f"exceeds limit {self._max_msg_size}",
  166. )
  167. continue
  168. has_partial = bool(self._partial)
  169. if is_continuation:
  170. if self._opcode is None:
  171. raise WebSocketError(
  172. WSCloseCode.PROTOCOL_ERROR,
  173. "Continuation frame for non started message",
  174. )
  175. opcode = self._opcode
  176. self._opcode = None
  177. # previous frame was non finished
  178. # we should get continuation opcode
  179. elif has_partial:
  180. raise WebSocketError(
  181. WSCloseCode.PROTOCOL_ERROR,
  182. "The opcode in non-fin frame is expected "
  183. f"to be zero, got {opcode!r}",
  184. )
  185. assembled_payload: Union[bytes, bytearray]
  186. if has_partial:
  187. assembled_payload = self._partial + payload
  188. self._partial.clear()
  189. else:
  190. assembled_payload = payload
  191. if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
  192. raise WebSocketError(
  193. WSCloseCode.MESSAGE_TOO_BIG,
  194. f"Message size {len(assembled_payload)} "
  195. f"exceeds limit {self._max_msg_size}",
  196. )
  197. # Decompress process must to be done after all packets
  198. # received.
  199. if compressed:
  200. if not self._decompressobj:
  201. self._decompressobj = ZLibDecompressor(
  202. suppress_deflate_header=True
  203. )
  204. payload_merged = self._decompressobj.decompress_sync(
  205. assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
  206. )
  207. if self._decompressobj.unconsumed_tail:
  208. left = len(self._decompressobj.unconsumed_tail)
  209. raise WebSocketError(
  210. WSCloseCode.MESSAGE_TOO_BIG,
  211. f"Decompressed message size {self._max_msg_size + left}"
  212. f" exceeds limit {self._max_msg_size}",
  213. )
  214. elif type(assembled_payload) is bytes:
  215. payload_merged = assembled_payload
  216. else:
  217. payload_merged = bytes(assembled_payload)
  218. if opcode == OP_CODE_TEXT:
  219. try:
  220. text = payload_merged.decode("utf-8")
  221. except UnicodeDecodeError as exc:
  222. raise WebSocketError(
  223. WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
  224. ) from exc
  225. # XXX: The Text and Binary messages here can be a performance
  226. # bottleneck, so we use tuple.__new__ to improve performance.
  227. # This is not type safe, but many tests should fail in
  228. # test_client_ws_functional.py if this is wrong.
  229. self.queue.feed_data(
  230. TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
  231. len(payload_merged),
  232. )
  233. else:
  234. self.queue.feed_data(
  235. TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
  236. len(payload_merged),
  237. )
  238. elif opcode == OP_CODE_CLOSE:
  239. if len(payload) >= 2:
  240. close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
  241. if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
  242. raise WebSocketError(
  243. WSCloseCode.PROTOCOL_ERROR,
  244. f"Invalid close code: {close_code}",
  245. )
  246. try:
  247. close_message = payload[2:].decode("utf-8")
  248. except UnicodeDecodeError as exc:
  249. raise WebSocketError(
  250. WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
  251. ) from exc
  252. msg = TUPLE_NEW(
  253. WSMessage, (WSMsgType.CLOSE, close_code, close_message)
  254. )
  255. elif payload:
  256. raise WebSocketError(
  257. WSCloseCode.PROTOCOL_ERROR,
  258. f"Invalid close frame: {fin} {opcode} {payload!r}",
  259. )
  260. else:
  261. msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
  262. self.queue.feed_data(msg, 0)
  263. elif opcode == OP_CODE_PING:
  264. msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
  265. self.queue.feed_data(msg, len(payload))
  266. elif opcode == OP_CODE_PONG:
  267. msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
  268. self.queue.feed_data(msg, len(payload))
  269. else:
  270. raise WebSocketError(
  271. WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
  272. )
  273. def parse_frame(
  274. self, buf: bytes
  275. ) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]:
  276. """Return the next frame from the socket."""
  277. frames: List[
  278. Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]
  279. ] = []
  280. if self._tail:
  281. buf, self._tail = self._tail + buf, b""
  282. start_pos: int = 0
  283. buf_length = len(buf)
  284. buf_cstr = buf
  285. while True:
  286. # read header
  287. if self._state == READ_HEADER:
  288. if buf_length - start_pos < 2:
  289. break
  290. first_byte = buf_cstr[start_pos]
  291. second_byte = buf_cstr[start_pos + 1]
  292. start_pos += 2
  293. fin = (first_byte >> 7) & 1
  294. rsv1 = (first_byte >> 6) & 1
  295. rsv2 = (first_byte >> 5) & 1
  296. rsv3 = (first_byte >> 4) & 1
  297. opcode = first_byte & 0xF
  298. # frame-fin = %x0 ; more frames of this message follow
  299. # / %x1 ; final frame of this message
  300. # frame-rsv1 = %x0 ;
  301. # 1 bit, MUST be 0 unless negotiated otherwise
  302. # frame-rsv2 = %x0 ;
  303. # 1 bit, MUST be 0 unless negotiated otherwise
  304. # frame-rsv3 = %x0 ;
  305. # 1 bit, MUST be 0 unless negotiated otherwise
  306. #
  307. # Remove rsv1 from this test for deflate development
  308. if rsv2 or rsv3 or (rsv1 and not self._compress):
  309. raise WebSocketError(
  310. WSCloseCode.PROTOCOL_ERROR,
  311. "Received frame with non-zero reserved bits",
  312. )
  313. if opcode > 0x7 and fin == 0:
  314. raise WebSocketError(
  315. WSCloseCode.PROTOCOL_ERROR,
  316. "Received fragmented control frame",
  317. )
  318. has_mask = (second_byte >> 7) & 1
  319. length = second_byte & 0x7F
  320. # Control frames MUST have a payload
  321. # length of 125 bytes or less
  322. if opcode > 0x7 and length > 125:
  323. raise WebSocketError(
  324. WSCloseCode.PROTOCOL_ERROR,
  325. "Control frame payload cannot be larger than 125 bytes",
  326. )
  327. # Set compress status if last package is FIN
  328. # OR set compress status if this is first fragment
  329. # Raise error if not first fragment with rsv1 = 0x1
  330. if self._frame_fin or self._compressed is None:
  331. self._compressed = True if rsv1 else False
  332. elif rsv1:
  333. raise WebSocketError(
  334. WSCloseCode.PROTOCOL_ERROR,
  335. "Received frame with non-zero reserved bits",
  336. )
  337. self._frame_fin = bool(fin)
  338. self._frame_opcode = opcode
  339. self._has_mask = bool(has_mask)
  340. self._payload_length_flag = length
  341. self._state = READ_PAYLOAD_LENGTH
  342. # read payload length
  343. if self._state == READ_PAYLOAD_LENGTH:
  344. length_flag = self._payload_length_flag
  345. if length_flag == 126:
  346. if buf_length - start_pos < 2:
  347. break
  348. first_byte = buf_cstr[start_pos]
  349. second_byte = buf_cstr[start_pos + 1]
  350. start_pos += 2
  351. self._payload_length = first_byte << 8 | second_byte
  352. elif length_flag > 126:
  353. if buf_length - start_pos < 8:
  354. break
  355. data = buf_cstr[start_pos : start_pos + 8]
  356. start_pos += 8
  357. self._payload_length = UNPACK_LEN3(data)[0]
  358. else:
  359. self._payload_length = length_flag
  360. self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
  361. # read payload mask
  362. if self._state == READ_PAYLOAD_MASK:
  363. if buf_length - start_pos < 4:
  364. break
  365. self._frame_mask = buf_cstr[start_pos : start_pos + 4]
  366. start_pos += 4
  367. self._state = READ_PAYLOAD
  368. if self._state == READ_PAYLOAD:
  369. chunk_len = buf_length - start_pos
  370. if self._payload_length >= chunk_len:
  371. end_pos = buf_length
  372. self._payload_length -= chunk_len
  373. else:
  374. end_pos = start_pos + self._payload_length
  375. self._payload_length = 0
  376. if self._frame_payload_len:
  377. if type(self._frame_payload) is not bytearray:
  378. self._frame_payload = bytearray(self._frame_payload)
  379. self._frame_payload += buf_cstr[start_pos:end_pos]
  380. else:
  381. # Fast path for the first frame
  382. self._frame_payload = buf_cstr[start_pos:end_pos]
  383. self._frame_payload_len += end_pos - start_pos
  384. start_pos = end_pos
  385. if self._payload_length != 0:
  386. break
  387. if self._has_mask:
  388. assert self._frame_mask is not None
  389. if type(self._frame_payload) is not bytearray:
  390. self._frame_payload = bytearray(self._frame_payload)
  391. websocket_mask(self._frame_mask, self._frame_payload)
  392. frames.append(
  393. (
  394. self._frame_fin,
  395. self._frame_opcode,
  396. self._frame_payload,
  397. self._compressed,
  398. )
  399. )
  400. self._frame_payload = b""
  401. self._frame_payload_len = 0
  402. self._state = READ_HEADER
  403. # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
  404. self._tail = buf_cstr[start_pos:buf_length] if start_pos < buf_length else b""
  405. return frames