helpers.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. """Helpers for WebSocket protocol versions 13 and 8."""
  2. import functools
  3. import re
  4. from struct import Struct
  5. from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
  6. from ..helpers import NO_EXTENSIONS
  7. from .models import WSHandshakeError
  8. UNPACK_LEN3 = Struct("!Q").unpack_from
  9. UNPACK_CLOSE_CODE = Struct("!H").unpack
  10. PACK_LEN1 = Struct("!BB").pack
  11. PACK_LEN2 = Struct("!BBH").pack
  12. PACK_LEN3 = Struct("!BBQ").pack
  13. PACK_CLOSE_CODE = Struct("!H").pack
  14. PACK_RANDBITS = Struct("!L").pack
  15. MSG_SIZE: Final[int] = 2**14
  16. MASK_LEN: Final[int] = 4
  17. WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  18. # Used by _websocket_mask_python
  19. @functools.lru_cache
  20. def _xor_table() -> List[bytes]:
  21. return [bytes(a ^ b for a in range(256)) for b in range(256)]
  22. def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
  23. """Websocket masking function.
  24. `mask` is a `bytes` object of length 4; `data` is a `bytearray`
  25. object of any length. The contents of `data` are masked with `mask`,
  26. as specified in section 5.3 of RFC 6455.
  27. Note that this function mutates the `data` argument.
  28. This pure-python implementation may be replaced by an optimized
  29. version when available.
  30. """
  31. assert isinstance(data, bytearray), data
  32. assert len(mask) == 4, mask
  33. if data:
  34. _XOR_TABLE = _xor_table()
  35. a, b, c, d = (_XOR_TABLE[n] for n in mask)
  36. data[::4] = data[::4].translate(a)
  37. data[1::4] = data[1::4].translate(b)
  38. data[2::4] = data[2::4].translate(c)
  39. data[3::4] = data[3::4].translate(d)
  40. if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
  41. websocket_mask = _websocket_mask_python
  42. else:
  43. try:
  44. from .mask import _websocket_mask_cython # type: ignore[import-not-found]
  45. websocket_mask = _websocket_mask_cython
  46. except ImportError: # pragma: no cover
  47. websocket_mask = _websocket_mask_python
  48. _WS_EXT_RE: Final[Pattern[str]] = re.compile(
  49. r"^(?:;\s*(?:"
  50. r"(server_no_context_takeover)|"
  51. r"(client_no_context_takeover)|"
  52. r"(server_max_window_bits(?:=(\d+))?)|"
  53. r"(client_max_window_bits(?:=(\d+))?)))*$"
  54. )
  55. _WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
  56. def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
  57. if not extstr:
  58. return 0, False
  59. compress = 0
  60. notakeover = False
  61. for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
  62. defext = ext.group(1)
  63. # Return compress = 15 when get `permessage-deflate`
  64. if not defext:
  65. compress = 15
  66. break
  67. match = _WS_EXT_RE.match(defext)
  68. if match:
  69. compress = 15
  70. if isserver:
  71. # Server never fail to detect compress handshake.
  72. # Server does not need to send max wbit to client
  73. if match.group(4):
  74. compress = int(match.group(4))
  75. # Group3 must match if group4 matches
  76. # Compress wbit 8 does not support in zlib
  77. # If compress level not support,
  78. # CONTINUE to next extension
  79. if compress > 15 or compress < 9:
  80. compress = 0
  81. continue
  82. if match.group(1):
  83. notakeover = True
  84. # Ignore regex group 5 & 6 for client_max_window_bits
  85. break
  86. else:
  87. if match.group(6):
  88. compress = int(match.group(6))
  89. # Group5 must match if group6 matches
  90. # Compress wbit 8 does not support in zlib
  91. # If compress level not support,
  92. # FAIL the parse progress
  93. if compress > 15 or compress < 9:
  94. raise WSHandshakeError("Invalid window size")
  95. if match.group(2):
  96. notakeover = True
  97. # Ignore regex group 5 & 6 for client_max_window_bits
  98. break
  99. # Return Fail if client side and not match
  100. elif not isserver:
  101. raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
  102. return compress, notakeover
  103. def ws_ext_gen(
  104. compress: int = 15, isserver: bool = False, server_notakeover: bool = False
  105. ) -> str:
  106. # client_notakeover=False not used for server
  107. # compress wbit 8 does not support in zlib
  108. if compress < 9 or compress > 15:
  109. raise ValueError(
  110. "Compress wbits must between 9 and 15, zlib does not support wbits=8"
  111. )
  112. enabledext = ["permessage-deflate"]
  113. if not isserver:
  114. enabledext.append("client_max_window_bits")
  115. if compress < 15:
  116. enabledext.append("server_max_window_bits=" + str(compress))
  117. if server_notakeover:
  118. enabledext.append("server_no_context_takeover")
  119. # if client_notakeover:
  120. # enabledext.append('client_no_context_takeover')
  121. return "; ".join(enabledext)