tls.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from __future__ import annotations
  2. import logging
  3. import re
  4. import ssl
  5. import sys
  6. from collections.abc import Callable, Mapping
  7. from dataclasses import dataclass
  8. from functools import wraps
  9. from typing import Any, TypeVar
  10. from .. import (
  11. BrokenResourceError,
  12. EndOfStream,
  13. aclose_forcefully,
  14. get_cancelled_exc_class,
  15. to_thread,
  16. )
  17. from .._core._typedattr import TypedAttributeSet, typed_attribute
  18. from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
  19. if sys.version_info >= (3, 11):
  20. from typing import TypeVarTuple, Unpack
  21. else:
  22. from typing_extensions import TypeVarTuple, Unpack
  23. T_Retval = TypeVar("T_Retval")
  24. PosArgsT = TypeVarTuple("PosArgsT")
  25. _PCTRTT = tuple[tuple[str, str], ...]
  26. _PCTRTTT = tuple[_PCTRTT, ...]
  class TLSAttribute(TypedAttributeSet):
      """
      Contains Transport Layer Security related attributes.

      Each attribute is a key that can be passed to ``extra()`` on a
      :class:`TLSStream`, e.g. ``stream.extra(TLSAttribute.tls_version)``.
      """

      #: the selected ALPN protocol
      alpn_protocol: str | None = typed_attribute()
      #: the channel binding for type ``tls-unique``
      channel_binding_tls_unique: bytes = typed_attribute()
      #: the selected cipher
      cipher: tuple[str, str, int] = typed_attribute()
      #: the peer certificate in dictionary form (see
      #: :meth:`ssl.SSLSocket.getpeercert` for more information)
      peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
      #: the peer certificate in binary form
      peer_certificate_binary: bytes | None = typed_attribute()
      #: ``True`` if this is the server side of the connection
      server_side: bool = typed_attribute()
      #: ciphers shared by the client during the TLS handshake (``None`` if this is the
      #: client side)
      shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
      #: the :class:`~ssl.SSLObject` used for encryption
      ssl_object: ssl.SSLObject = typed_attribute()
      #: ``True`` if this stream does (and expects) a closing TLS handshake when the
      #: stream is being closed
      standard_compatible: bool = typed_attribute()
      #: the TLS protocol version (e.g. ``TLSv1.2``)
      tls_version: str = typed_attribute()
  @dataclass(eq=False)
  class TLSStream(ByteStream):
      """
      A stream wrapper that encrypts all sent data and decrypts received data.

      This class has no public initializer; use :meth:`wrap` instead.
      All extra attributes from :class:`~TLSAttribute` are supported.

      :var AnyByteStream transport_stream: the wrapped stream

      """

      transport_stream: AnyByteStream
      standard_compatible: bool
      _ssl_object: ssl.SSLObject
      # Ciphertext received from ``transport_stream`` is written here for
      # ``_ssl_object`` to consume
      _read_bio: ssl.MemoryBIO
      # Ciphertext produced by ``_ssl_object`` accumulates here until it is sent over
      # ``transport_stream``
      _write_bio: ssl.MemoryBIO

      @classmethod
      async def wrap(
          cls,
          transport_stream: AnyByteStream,
          *,
          server_side: bool | None = None,
          hostname: str | None = None,
          ssl_context: ssl.SSLContext | None = None,
          standard_compatible: bool = True,
      ) -> TLSStream:
          """
          Wrap an existing stream with Transport Layer Security.

          This performs a TLS handshake with the peer.

          :param transport_stream: a bytes-transporting stream to wrap
          :param server_side: ``True`` if this is the server side of the connection,
              ``False`` if this is the client side (if omitted, will be set to ``False``
              if ``hostname`` has been provided, ``True`` otherwise). Selects the
              purpose of the default context (when ``ssl_context`` is not provided) and
              is passed to :meth:`~ssl.SSLContext.wrap_bio`.
          :param hostname: host name of the peer (if host name checking is desired)
          :param ssl_context: the SSLContext object to use (if not provided, a secure
              default will be created)
          :param standard_compatible: if ``False``, skip the closing handshake when
              closing the connection, and don't raise an exception if the peer does the
              same
          :return: the wrapping stream, with the TLS handshake already completed
          :raises ~ssl.SSLError: if the TLS handshake fails

          """
          if server_side is None:
              server_side = not hostname

          if ssl_context is None:
              purpose = (
                  ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
              )
              ssl_context = ssl.create_default_context(purpose)

              # Re-enable detection of unexpected EOFs if it was disabled by Python
              if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                  ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

          bio_in = ssl.MemoryBIO()
          bio_out = ssl.MemoryBIO()

          # External SSLContext implementations may do blocking I/O in wrap_bio(),
          # but the standard library implementation won't. The exact type check means
          # subclasses (which may override wrap_bio()) are also run in a worker thread.
          if type(ssl_context) is ssl.SSLContext:
              ssl_object = ssl_context.wrap_bio(
                  bio_in, bio_out, server_side=server_side, server_hostname=hostname
              )
          else:
              ssl_object = await to_thread.run_sync(
                  ssl_context.wrap_bio,
                  bio_in,
                  bio_out,
                  server_side,
                  hostname,
                  None,
              )

          wrapper = cls(
              transport_stream=transport_stream,
              standard_compatible=standard_compatible,
              _ssl_object=ssl_object,
              _read_bio=bio_in,
              _write_bio=bio_out,
          )
          await wrapper._call_sslobject_method(ssl_object.do_handshake)
          return wrapper

      async def _call_sslobject_method(
          self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
      ) -> T_Retval:
          """
          Call ``func(*args)`` (a method of the SSL object), performing whatever
          transport I/O it asks for until the call completes.

          On :exc:`~ssl.SSLWantReadError`, pending outgoing data is flushed and more data
          is received from the transport stream; on :exc:`~ssl.SSLWantWriteError`,
          pending outgoing data is sent. After a successful call, any pending outgoing
          data is flushed before the result is returned.

          :return: the return value of ``func``
          :raises BrokenResourceError: if the transport stream raises :exc:`OSError`, on
              :exc:`~ssl.SSLSyscallError`, or on an unexpected EOF when
              ``standard_compatible`` is ``True``
          :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is
              ``False``
          :raises ~ssl.SSLError: on any other SSL error

          """
          while True:
              try:
                  result = func(*args)
              except ssl.SSLWantReadError:
                  try:
                      # Flush any pending writes first
                      if self._write_bio.pending:
                          await self.transport_stream.send(self._write_bio.read())

                      data = await self.transport_stream.receive()
                  except EndOfStream:
                      # Signal EOF to the SSL object; the retried call then reports how
                      # the connection ended
                      self._read_bio.write_eof()
                  except OSError as exc:
                      self._read_bio.write_eof()
                      self._write_bio.write_eof()
                      raise BrokenResourceError from exc
                  else:
                      self._read_bio.write(data)
              except ssl.SSLWantWriteError:
                  await self.transport_stream.send(self._write_bio.read())
              # SSLSyscallError is a subclass of SSLError, so it must be handled first
              except ssl.SSLSyscallError as exc:
                  self._read_bio.write_eof()
                  self._write_bio.write_eof()
                  raise BrokenResourceError from exc
              except ssl.SSLError as exc:
                  self._read_bio.write_eof()
                  self._write_bio.write_eof()
                  if isinstance(exc, ssl.SSLEOFError) or (
                      exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                  ):
                      if self.standard_compatible:
                          raise BrokenResourceError from exc
                      else:
                          raise EndOfStream from None

                  raise
              else:
                  # Flush any pending writes first
                  if self._write_bio.pending:
                      await self.transport_stream.send(self._write_bio.read())

                  return result

      async def unwrap(self) -> tuple[AnyByteStream, bytes]:
          """
          Does the TLS closing handshake.

          Afterwards both memory BIOs are marked as having reached EOF.

          :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

          """
          await self._call_sslobject_method(self._ssl_object.unwrap)
          self._read_bio.write_eof()
          self._write_bio.write_eof()
          return self.transport_stream, self._read_bio.read()

      async def aclose(self) -> None:
          """
          Close the stream and the underlying transport stream.

          If ``standard_compatible`` is ``True``, the TLS closing handshake is performed
          first; if that raises anything (including cancellation), the transport stream
          is closed forcefully and the exception is re-raised.

          """
          if self.standard_compatible:
              try:
                  await self.unwrap()
              except BaseException:
                  await aclose_forcefully(self.transport_stream)
                  raise

          await self.transport_stream.aclose()

      async def receive(self, max_bytes: int = 65536) -> bytes:
          """
          Receive and decrypt up to ``max_bytes`` bytes of application data.

          :raises EndOfStream: if the SSL object returns no data

          """
          data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
          if not data:
              raise EndOfStream

          return data

      async def send(self, item: bytes) -> None:
          """Encrypt ``item`` and send the resulting data over the transport stream."""
          await self._call_sslobject_method(self._ssl_object.write, item)

      async def send_eof(self) -> None:
          """
          Not supported: this method always raises :exc:`NotImplementedError`.

          If the session's TLS version is parsed as older than 1.3, the error message
          says that at least TLSv1.3 is required.

          """
          tls_version = self.extra(TLSAttribute.tls_version)
          match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
          if match:
              major, minor = int(match.group(1)), int(match.group(2) or 0)
              if (major, minor) < (1, 3):
                  raise NotImplementedError(
                      f"send_eof() requires at least TLSv1.3; current "
                      f"session uses {tls_version}"
                  )

          raise NotImplementedError(
              "send_eof() has not yet been implemented for TLS streams"
          )

      @property
      def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
          # The transport stream's attributes come first so that the TLS-specific
          # entries below take precedence on any key collision
          return {
              **self.transport_stream.extra_attributes,
              TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
              TLSAttribute.channel_binding_tls_unique: (
                  self._ssl_object.get_channel_binding
              ),
              TLSAttribute.cipher: self._ssl_object.cipher,
              TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
              TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                  True
              ),
              TLSAttribute.server_side: lambda: self._ssl_object.server_side,
              # Only the server side exposes the client's shared ciphers
              TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
              if self._ssl_object.server_side
              else None,
              TLSAttribute.standard_compatible: lambda: self.standard_compatible,
              TLSAttribute.ssl_object: lambda: self._ssl_object,
              TLSAttribute.tls_version: self._ssl_object.version,
          }
  @dataclass(eq=False)
  class TLSListener(Listener[TLSStream]):
      """
      A convenience listener that wraps another listener and auto-negotiates a TLS session
      on every accepted connection.

      If the TLS handshake times out or raises an exception,
      :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
      deemed necessary.

      Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

      :param Listener listener: the listener to wrap
      :param ssl_context: the SSL context object
      :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
      :param handshake_timeout: time limit for the TLS handshake
          (passed to :func:`~anyio.fail_after`)
      """

      listener: Listener[Any]
      ssl_context: ssl.SSLContext
      standard_compatible: bool = True
      handshake_timeout: float = 30

      @staticmethod
      async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
          """
          Handle an exception raised during the TLS handshake.

          This method does 3 things:

          #. Forcefully closes the original stream
          #. Logs the exception (unless it was a cancellation exception) using the
             ``anyio.streams.tls`` logger
          #. Reraises the exception if it was a base exception or a cancellation exception

          Because ``exc`` is re-raised explicitly, this method does not need to be called
          from inside an ``except`` block.

          :param exc: the exception
          :param stream: the original stream

          """
          await aclose_forcefully(stream)
          cancelled_exc_class = get_cancelled_exc_class()

          # Log all except cancellation exceptions
          if not isinstance(exc, cancelled_exc_class):
              # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
              # any asyncio implementation, so we explicitly pass the exception to log
              # (https://github.com/python/cpython/issues/108668). Trio does not have this
              # issue because it works around the CPython bug.
              logging.getLogger(__name__).exception(
                  "Error during TLS handshake", exc_info=exc
              )

          # Only reraise base exceptions and cancellation exceptions. Re-raise ``exc``
          # explicitly: a bare ``raise`` relies on the implicitly "currently handled"
          # exception, which may not be ``exc`` here (see the CPython issue above) and
          # does not exist at all if this method is called outside an ``except`` block.
          if not isinstance(exc, Exception) or isinstance(exc, cancelled_exc_class):
              raise exc

      async def serve(
          self,
          handler: Callable[[TLSStream], Any],
          task_group: TaskGroup | None = None,
      ) -> None:
          """
          Serve connections from the wrapped listener, running the TLS handshake on each.

          Each accepted stream is wrapped with :meth:`TLSStream.wrap` under a
          ``handshake_timeout`` time limit. On success, ``handler`` is called with the
          TLS stream; on any exception, :meth:`handle_handshake_error` is called instead.

          :param handler: callable invoked with each successfully wrapped stream
          :param task_group: passed through to the wrapped listener's ``serve()``

          """

          @wraps(handler)
          async def handler_wrapper(stream: AnyByteStream) -> None:
              from .. import fail_after

              try:
                  with fail_after(self.handshake_timeout):
                      wrapped_stream = await TLSStream.wrap(
                          stream,
                          ssl_context=self.ssl_context,
                          standard_compatible=self.standard_compatible,
                      )
              except BaseException as exc:
                  await self.handle_handshake_error(exc, stream)
              else:
                  await handler(wrapped_stream)

          await self.listener.serve(handler_wrapper, task_group)

      async def aclose(self) -> None:
          """Close the wrapped listener."""
          await self.listener.aclose()

      @property
      def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
          # Only standard_compatible is exposed; no TLS session exists at listener level
          return {
              TLSAttribute.standard_compatible: lambda: self.standard_compatible,
          }