connection.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. from __future__ import annotations
  2. import logging
  3. import re
  4. import threading
  5. import types
  6. import typing
  7. import h2.config # type: ignore[import-untyped]
  8. import h2.connection # type: ignore[import-untyped]
  9. import h2.events # type: ignore[import-untyped]
  10. from .._base_connection import _TYPE_BODY
  11. from .._collections import HTTPHeaderDict
  12. from ..connection import HTTPSConnection, _get_default_user_agent
  13. from ..exceptions import ConnectionError
  14. from ..response import BaseHTTPResponse
# Second name for the HTTPSConnection class imported above.
# NOTE(review): presumably kept so the original class stays reachable if the
# imported name is rebound elsewhere; no use is visible in this file - confirm.
orig_HTTPSConnection = HTTPSConnection

# Type variable for the object wrapped by the generic `_LockedObject`.
T = typing.TypeVar("T")

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
  18. RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+$")
  19. RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\0\x00\x0a\x0d\r\n]|^[ \r\n\t]|[ \r\n\t]$")
  20. def _is_legal_header_name(name: bytes) -> bool:
  21. """
  22. "An implementation that validates fields according to the definitions in Sections
  23. 5.1 and 5.5 of [HTTP] only needs an additional check that field names do not
  24. include uppercase characters." (https://httpwg.org/specs/rfc9113.html#n-field-validity)
  25. `http.client._is_legal_header_name` does not validate the field name according to the
  26. HTTP 1.1 spec, so we do that here, in addition to checking for uppercase characters.
  27. This does not allow for the `:` character in the header name, so should not
  28. be used to validate pseudo-headers.
  29. """
  30. return bool(RE_IS_LEGAL_HEADER_NAME.match(name))
  31. def _is_illegal_header_value(value: bytes) -> bool:
  32. """
  33. "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed
  34. (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. A field
  35. value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB,
  36. 0x20 or 0x09)." (https://httpwg.org/specs/rfc9113.html#n-field-validity)
  37. """
  38. return bool(RE_IS_ILLEGAL_HEADER_VALUE.search(value))
  39. class _LockedObject(typing.Generic[T]):
  40. """
  41. A wrapper class that hides a specific object behind a lock.
  42. The goal here is to provide a simple way to protect access to an object
  43. that cannot safely be simultaneously accessed from multiple threads. The
  44. intended use of this class is simple: take hold of it with a context
  45. manager, which returns the protected object.
  46. """
  47. __slots__ = (
  48. "lock",
  49. "_obj",
  50. )
  51. def __init__(self, obj: T):
  52. self.lock = threading.RLock()
  53. self._obj = obj
  54. def __enter__(self) -> T:
  55. self.lock.acquire()
  56. return self._obj
  57. def __exit__(
  58. self,
  59. exc_type: type[BaseException] | None,
  60. exc_val: BaseException | None,
  61. exc_tb: types.TracebackType | None,
  62. ) -> None:
  63. self.lock.release()
  64. class HTTP2Connection(HTTPSConnection):
  65. def __init__(
  66. self, host: str, port: int | None = None, **kwargs: typing.Any
  67. ) -> None:
  68. self._h2_conn = self._new_h2_conn()
  69. self._h2_stream: int | None = None
  70. self._headers: list[tuple[bytes, bytes]] = []
  71. if "proxy" in kwargs or "proxy_config" in kwargs: # Defensive:
  72. raise NotImplementedError("Proxies aren't supported with HTTP/2")
  73. super().__init__(host, port, **kwargs)
  74. if self._tunnel_host is not None:
  75. raise NotImplementedError("Tunneling isn't supported with HTTP/2")
  76. def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
  77. config = h2.config.H2Configuration(client_side=True)
  78. return _LockedObject(h2.connection.H2Connection(config=config))
  79. def connect(self) -> None:
  80. super().connect()
  81. with self._h2_conn as conn:
  82. conn.initiate_connection()
  83. if data_to_send := conn.data_to_send():
  84. self.sock.sendall(data_to_send)
  85. def putrequest( # type: ignore[override]
  86. self,
  87. method: str,
  88. url: str,
  89. **kwargs: typing.Any,
  90. ) -> None:
  91. """putrequest
  92. This deviates from the HTTPConnection method signature since we never need to override
  93. sending accept-encoding headers or the host header.
  94. """
  95. if "skip_host" in kwargs:
  96. raise NotImplementedError("`skip_host` isn't supported")
  97. if "skip_accept_encoding" in kwargs:
  98. raise NotImplementedError("`skip_accept_encoding` isn't supported")
  99. self._request_url = url or "/"
  100. self._validate_path(url) # type: ignore[attr-defined]
  101. if ":" in self.host:
  102. authority = f"[{self.host}]:{self.port or 443}"
  103. else:
  104. authority = f"{self.host}:{self.port or 443}"
  105. self._headers.append((b":scheme", b"https"))
  106. self._headers.append((b":method", method.encode()))
  107. self._headers.append((b":authority", authority.encode()))
  108. self._headers.append((b":path", url.encode()))
  109. with self._h2_conn as conn:
  110. self._h2_stream = conn.get_next_available_stream_id()
  111. def putheader(self, header: str | bytes, *values: str | bytes) -> None:
  112. # TODO SKIPPABLE_HEADERS from urllib3 are ignored.
  113. header = header.encode() if isinstance(header, str) else header
  114. header = header.lower() # A lot of upstream code uses capitalized headers.
  115. if not _is_legal_header_name(header):
  116. raise ValueError(f"Illegal header name {str(header)}")
  117. for value in values:
  118. value = value.encode() if isinstance(value, str) else value
  119. if _is_illegal_header_value(value):
  120. raise ValueError(f"Illegal header value {str(value)}")
  121. self._headers.append((header, value))
  122. def endheaders(self, message_body: typing.Any = None) -> None: # type: ignore[override]
  123. if self._h2_stream is None:
  124. raise ConnectionError("Must call `putrequest` first.")
  125. with self._h2_conn as conn:
  126. conn.send_headers(
  127. stream_id=self._h2_stream,
  128. headers=self._headers,
  129. end_stream=(message_body is None),
  130. )
  131. if data_to_send := conn.data_to_send():
  132. self.sock.sendall(data_to_send)
  133. self._headers = [] # Reset headers for the next request.
  134. def send(self, data: typing.Any) -> None:
  135. """Send data to the server.
  136. `data` can be: `str`, `bytes`, an iterable, or file-like objects
  137. that support a .read() method.
  138. """
  139. if self._h2_stream is None:
  140. raise ConnectionError("Must call `putrequest` first.")
  141. with self._h2_conn as conn:
  142. if data_to_send := conn.data_to_send():
  143. self.sock.sendall(data_to_send)
  144. if hasattr(data, "read"): # file-like objects
  145. while True:
  146. chunk = data.read(self.blocksize)
  147. if not chunk:
  148. break
  149. if isinstance(chunk, str):
  150. chunk = chunk.encode() # pragma: no cover
  151. conn.send_data(self._h2_stream, chunk, end_stream=False)
  152. if data_to_send := conn.data_to_send():
  153. self.sock.sendall(data_to_send)
  154. conn.end_stream(self._h2_stream)
  155. return
  156. if isinstance(data, str): # str -> bytes
  157. data = data.encode()
  158. try:
  159. if isinstance(data, bytes):
  160. conn.send_data(self._h2_stream, data, end_stream=True)
  161. if data_to_send := conn.data_to_send():
  162. self.sock.sendall(data_to_send)
  163. else:
  164. for chunk in data:
  165. conn.send_data(self._h2_stream, chunk, end_stream=False)
  166. if data_to_send := conn.data_to_send():
  167. self.sock.sendall(data_to_send)
  168. conn.end_stream(self._h2_stream)
  169. except TypeError:
  170. raise TypeError(
  171. "`data` should be str, bytes, iterable, or file. got %r"
  172. % type(data)
  173. )
  174. def set_tunnel(
  175. self,
  176. host: str,
  177. port: int | None = None,
  178. headers: typing.Mapping[str, str] | None = None,
  179. scheme: str = "http",
  180. ) -> None:
  181. raise NotImplementedError(
  182. "HTTP/2 does not support setting up a tunnel through a proxy"
  183. )
  184. def getresponse( # type: ignore[override]
  185. self,
  186. ) -> HTTP2Response:
  187. status = None
  188. data = bytearray()
  189. with self._h2_conn as conn:
  190. end_stream = False
  191. while not end_stream:
  192. # TODO: Arbitrary read value.
  193. if received_data := self.sock.recv(65535):
  194. events = conn.receive_data(received_data)
  195. for event in events:
  196. if isinstance(event, h2.events.ResponseReceived):
  197. headers = HTTPHeaderDict()
  198. for header, value in event.headers:
  199. if header == b":status":
  200. status = int(value.decode())
  201. else:
  202. headers.add(
  203. header.decode("ascii"), value.decode("ascii")
  204. )
  205. elif isinstance(event, h2.events.DataReceived):
  206. data += event.data
  207. conn.acknowledge_received_data(
  208. event.flow_controlled_length, event.stream_id
  209. )
  210. elif isinstance(event, h2.events.StreamEnded):
  211. end_stream = True
  212. if data_to_send := conn.data_to_send():
  213. self.sock.sendall(data_to_send)
  214. assert status is not None
  215. return HTTP2Response(
  216. status=status,
  217. headers=headers,
  218. request_url=self._request_url,
  219. data=bytes(data),
  220. )
  221. def request( # type: ignore[override]
  222. self,
  223. method: str,
  224. url: str,
  225. body: _TYPE_BODY | None = None,
  226. headers: typing.Mapping[str, str] | None = None,
  227. *,
  228. preload_content: bool = True,
  229. decode_content: bool = True,
  230. enforce_content_length: bool = True,
  231. **kwargs: typing.Any,
  232. ) -> None:
  233. """Send an HTTP/2 request"""
  234. if "chunked" in kwargs:
  235. # TODO this is often present from upstream.
  236. # raise NotImplementedError("`chunked` isn't supported with HTTP/2")
  237. pass
  238. if self.sock is not None:
  239. self.sock.settimeout(self.timeout)
  240. self.putrequest(method, url)
  241. headers = headers or {}
  242. for k, v in headers.items():
  243. if k.lower() == "transfer-encoding" and v == "chunked":
  244. continue
  245. else:
  246. self.putheader(k, v)
  247. if b"user-agent" not in dict(self._headers):
  248. self.putheader(b"user-agent", _get_default_user_agent())
  249. if body:
  250. self.endheaders(message_body=body)
  251. self.send(body)
  252. else:
  253. self.endheaders()
  254. def close(self) -> None:
  255. with self._h2_conn as conn:
  256. try:
  257. conn.close_connection()
  258. if data := conn.data_to_send():
  259. self.sock.sendall(data)
  260. except Exception:
  261. pass
  262. # Reset all our HTTP/2 connection state.
  263. self._h2_conn = self._new_h2_conn()
  264. self._h2_stream = None
  265. self._headers = []
  266. super().close()
  267. class HTTP2Response(BaseHTTPResponse):
  268. # TODO: This is a woefully incomplete response object, but works for non-streaming.
  269. def __init__(
  270. self,
  271. status: int,
  272. headers: HTTPHeaderDict,
  273. request_url: str,
  274. data: bytes,
  275. decode_content: bool = False, # TODO: support decoding
  276. ) -> None:
  277. super().__init__(
  278. status=status,
  279. headers=headers,
  280. # Following CPython, we map HTTP versions to major * 10 + minor integers
  281. version=20,
  282. version_string="HTTP/2",
  283. # No reason phrase in HTTP/2
  284. reason=None,
  285. decode_content=decode_content,
  286. request_url=request_url,
  287. )
  288. self._data = data
  289. self.length_remaining = 0
  290. @property
  291. def data(self) -> bytes:
  292. return self._data
  293. def get_redirect_location(self) -> None:
  294. return None
  295. def close(self) -> None:
  296. pass