123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- from __future__ import annotations
- import logging
- import re
- import threading
- import types
- import typing
- import h2.config # type: ignore[import-untyped]
- import h2.connection # type: ignore[import-untyped]
- import h2.events # type: ignore[import-untyped]
- from .._base_connection import _TYPE_BODY
- from .._collections import HTTPHeaderDict
- from ..connection import HTTPSConnection, _get_default_user_agent
- from ..exceptions import ConnectionError
- from ..response import BaseHTTPResponse
# Type variable standing for the kind of object wrapped by ``_LockedObject``.
T = typing.TypeVar("T")

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Reference to the stock ``HTTPSConnection`` class under a second name.
# NOTE(review): nothing in the visible code reads this; presumably it lets the
# original class be restored after it has been swapped out -- confirm usage.
orig_HTTPSConnection = HTTPSConnection
# Lowercase HTTP token characters only. The pattern is anchored with ``\Z``
# rather than ``$`` because ``$`` also matches just before a trailing newline,
# which would let a name ending in a line feed pass validation.
RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+\Z")

# Matches NUL / CR / LF anywhere, or whitespace at either end of the value.
RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\0\x00\x0a\x0d\r\n]|^[ \r\n\t]|[ \r\n\t]$")


def _is_legal_header_name(name: bytes) -> bool:
    """
    Return True if ``name`` is a valid HTTP/2 (non-pseudo) field name.

    "An implementation that validates fields according to the definitions in Sections
    5.1 and 5.5 of [HTTP] only needs an additional check that field names do not
    include uppercase characters." (https://httpwg.org/specs/rfc9113.html#n-field-validity)

    `http.client._is_legal_header_name` does not validate the field name according to the
    HTTP 1.1 spec, so we do that here, in addition to checking for uppercase characters.

    This does not allow for the `:` character in the header name, so should not
    be used to validate pseudo-headers.

    The whole name must match: an empty name, or a name followed by a trailing
    newline, is rejected.
    """
    return bool(RE_IS_LEGAL_HEADER_NAME.match(name))


def _is_illegal_header_value(value: bytes) -> bool:
    """
    Return True if ``value`` is NOT acceptable as an HTTP/2 field value.

    "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed
    (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. A field
    value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB,
    0x20 or 0x09)." (https://httpwg.org/specs/rfc9113.html#n-field-validity)
    """
    return bool(RE_IS_ILLEGAL_HEADER_VALUE.search(value))
class _LockedObject(typing.Generic[T]):
    """
    Guard a single object with a re-entrant lock.

    Some objects must not be touched by two threads at once. Instead of
    handing such an object out directly, wrap it in this class and reach it
    only through a ``with`` statement: entering acquires the lock and yields
    the guarded object, leaving releases the lock again.
    """

    __slots__ = ("lock", "_obj")

    def __init__(self, obj: T):
        # The object is only ever handed out by ``__enter__``.
        self._obj = obj
        # ``RLock`` so the owning thread may nest ``with`` blocks.
        self.lock = threading.RLock()

    def __enter__(self) -> T:
        """Acquire the lock (blocking) and return the guarded object."""
        self.lock.acquire()
        return self._obj

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release the lock; any in-flight exception keeps propagating."""
        self.lock.release()
class HTTP2Connection(HTTPSConnection):
    """
    An ``HTTPSConnection`` that speaks HTTP/2 through the ``h2`` state machine.

    A request is assembled with :meth:`putrequest`, :meth:`putheader`,
    :meth:`endheaders` and :meth:`send` (or in one go with :meth:`request`),
    and the complete response is read by :meth:`getresponse`. The ``h2``
    connection object is kept behind a ``_LockedObject`` and is only touched
    while its lock is held. Proxies and tunnels are rejected.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """
        Create the connection object (no network I/O happens here).

        :raises NotImplementedError: if ``proxy`` / ``proxy_config`` is passed,
            or if the base class ends up with a tunnel host configured.
        """
        self._h2_conn = self._new_h2_conn()
        # Stream id of the request currently being built/sent, if any.
        self._h2_stream: int | None = None
        # Headers (pseudo-headers first) collected for the pending request.
        self._headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

        if self._tunnel_host is not None:
            raise NotImplementedError("Tunneling isn't supported with HTTP/2")

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Build a fresh client-side ``h2`` connection wrapped in its lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def _flush_h2_data(self, conn: h2.connection.H2Connection) -> None:
        """Write whatever bytes ``h2`` has queued for the peer to the socket."""
        if data_to_send := conn.data_to_send():
            self.sock.sendall(data_to_send)

    def connect(self) -> None:
        """Open the underlying connection, then send what ``h2`` queues on initiation."""
        super().connect()
        with self._h2_conn as conn:
            conn.initiate_connection()
            self._flush_h2_data(conn)

    def putrequest(  # type: ignore[override]
        self,
        method: str,
        url: str,
        **kwargs: typing.Any,
    ) -> None:
        """putrequest
        This deviates from the HTTPConnection method signature since we never need to override
        sending accept-encoding headers or the host header.

        Queues the four request pseudo-headers and reserves the next available
        stream id. Nothing is written to the socket yet.

        :raises NotImplementedError: if ``skip_host`` or ``skip_accept_encoding``
            is passed.
        """
        if "skip_host" in kwargs:
            raise NotImplementedError("`skip_host` isn't supported")
        if "skip_accept_encoding" in kwargs:
            raise NotImplementedError("`skip_accept_encoding` isn't supported")

        # An empty URL means the root path. The normalized value must also be
        # what goes on the wire: `:path` is not allowed to be empty.
        path = url or "/"
        self._request_url = path
        self._validate_path(path)  # type: ignore[attr-defined]

        # A colon in the host means an IPv6 literal, which needs brackets.
        if ":" in self.host:
            authority = f"[{self.host}]:{self.port or 443}"
        else:
            authority = f"{self.host}:{self.port or 443}"

        self._headers.append((b":scheme", b"https"))
        self._headers.append((b":method", method.encode()))
        self._headers.append((b":authority", authority.encode()))
        self._headers.append((b":path", path.encode()))

        with self._h2_conn as conn:
            self._h2_stream = conn.get_next_available_stream_id()

    def putheader(self, header: str | bytes, *values: str | bytes) -> None:
        """
        Queue ``header`` once per entry in ``values`` for the pending request.

        The name is lower-cased before validation. All values are validated
        before any of them is queued, so a rejected call leaves the pending
        header list untouched.

        :raises ValueError: if the name or any value is not legal for HTTP/2.
        """
        # TODO SKIPPABLE_HEADERS from urllib3 are ignored.
        header = header.encode() if isinstance(header, str) else header
        header = header.lower()  # A lot of upstream code uses capitalized headers.
        if not _is_legal_header_name(header):
            raise ValueError(f"Illegal header name {str(header)}")

        encoded_values = []
        for value in values:
            value = value.encode() if isinstance(value, str) else value
            if _is_illegal_header_value(value):
                raise ValueError(f"Illegal header value {str(value)}")
            encoded_values.append(value)

        self._headers.extend((header, value) for value in encoded_values)

    def endheaders(self, message_body: typing.Any = None) -> None:  # type: ignore[override]
        """
        Send the queued headers on the reserved stream.

        The stream is closed immediately when ``message_body`` is ``None``;
        otherwise it stays open for a following :meth:`send`.

        :raises ConnectionError: if :meth:`putrequest` has not been called.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._headers,
                end_stream=(message_body is None),
            )
            self._flush_h2_data(conn)
            self._headers = []  # Reset headers for the next request.

    def send(self, data: typing.Any) -> None:
        """Send data to the server.
        `data` can be: `str`, `bytes`, an iterable, or file-like objects
        that support a .read() method.

        The stream is always ended once the data has been handed over, and the
        closing frame is flushed to the socket before returning.

        NOTE(review): chunks are passed to ``h2`` as-is; nothing here splits
        them to fit the peer's flow-control window or maximum frame size.

        :raises ConnectionError: if :meth:`putrequest` has not been called.
        :raises TypeError: if ``data`` is none of the supported kinds.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            self._flush_h2_data(conn)

            if hasattr(data, "read"):  # file-like objects
                while chunk := data.read(self.blocksize):
                    if isinstance(chunk, str):
                        chunk = chunk.encode()  # pragma: no cover
                    conn.send_data(self._h2_stream, chunk, end_stream=False)
                    self._flush_h2_data(conn)
                conn.end_stream(self._h2_stream)
                # Without this flush the END_STREAM frame would stay queued
                # inside h2 while the caller waits for a response.
                self._flush_h2_data(conn)
                return

            if isinstance(data, str):  # str -> bytes
                data = data.encode()

            try:
                if isinstance(data, bytes):
                    conn.send_data(self._h2_stream, data, end_stream=True)
                else:
                    for chunk in data:
                        conn.send_data(self._h2_stream, chunk, end_stream=False)
                        self._flush_h2_data(conn)
                    conn.end_stream(self._h2_stream)
                # Covers the single `bytes` frame as well as the END_STREAM
                # frame that closes an iterable body.
                self._flush_h2_data(conn)
            except TypeError:
                raise TypeError(
                    "`data` should be str, bytes, iterable, or file. got %r"
                    % type(data)
                )

    def set_tunnel(
        self,
        host: str,
        port: int | None = None,
        headers: typing.Mapping[str, str] | None = None,
        scheme: str = "http",
    ) -> None:
        """Always raise ``NotImplementedError``; tunnels are unsupported here."""
        raise NotImplementedError(
            "HTTP/2 does not support setting up a tunnel through a proxy"
        )

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """
        Read from the socket until the stream ends and return the response.

        The whole body is buffered in memory before the response is returned.

        :raises ConnectionError: if the peer closes the socket before the
            stream has ended, or the stream ends without a ``:status``.
        """
        status: int | None = None
        headers = HTTPHeaderDict()
        data = bytearray()

        with self._h2_conn as conn:
            end_stream = False
            while not end_stream:
                # TODO: Arbitrary read value.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # An empty read means the peer closed the socket; looping
                    # again would spin forever because no event can arrive.
                    raise ConnectionError(
                        "Connection closed before the HTTP/2 response was complete"
                    )

                for event in conn.receive_data(received_data):
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )

                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        # Tell h2 the bytes were consumed so it can reopen
                        # the flow-control window.
                        conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )

                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True

                # Send anything h2 queued while processing the events.
                self._flush_h2_data(conn)

        # An explicit check rather than `assert`, which is stripped under -O.
        if status is None:
            raise ConnectionError("HTTP/2 stream ended without a response status")

        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def request(  # type: ignore[override]
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        *,
        preload_content: bool = True,
        decode_content: bool = True,
        enforce_content_length: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Send an HTTP/2 request

        Queues the pseudo-headers and ``headers``, adds a default
        ``user-agent`` when none was supplied, then sends the headers and, if
        ``body`` is truthy, the body.

        ``preload_content``, ``decode_content``, ``enforce_content_length``
        and any extra keyword arguments (including ``chunked``) are accepted
        but not used by this method.
        """
        # TODO `chunked` is often present from upstream; it is ignored here.
        # raise NotImplementedError("`chunked` isn't supported with HTTP/2")

        if self.sock is not None:
            self.sock.settimeout(self.timeout)

        self.putrequest(method, url)

        for k, v in (headers or {}).items():
            # A chunked `transfer-encoding` header is dropped, not forwarded.
            if k.lower() == "transfer-encoding" and v == "chunked":
                continue
            self.putheader(k, v)

        if not any(name == b"user-agent" for name, _ in self._headers):
            self.putheader(b"user-agent", _get_default_user_agent())

        if body:
            self.endheaders(message_body=body)
            self.send(body)
        else:
            self.endheaders()

    def close(self) -> None:
        """
        Close the HTTP/2 connection and reset all per-connection state.

        Closing the ``h2`` connection and sending its queued bytes is
        best-effort: any failure there is ignored so that the state reset and
        the base-class ``close`` always run.
        """
        with self._h2_conn as conn:
            try:
                conn.close_connection()
                self._flush_h2_data(conn)
            except Exception:
                # Deliberate best-effort: the socket may already be gone (or
                # never opened); closing must still succeed.
                pass

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._headers = []

        super().close()
class HTTP2Response(BaseHTTPResponse):
    """
    Response object for an HTTP/2 exchange whose body is already fully read.

    The payload is supplied up front as ``bytes`` and exposed via :attr:`data`.
    """

    # TODO: This is a woefully incomplete response object, but works for non-streaming.

    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        super().__init__(
            headers=headers,
            status=status,
            # Following CPython, we map HTTP versions to major * 10 + minor integers
            version=20,
            version_string="HTTP/2",
            # No reason phrase in HTTP/2
            reason=None,
            decode_content=decode_content,
            request_url=request_url,
        )
        # Set to zero at construction; the body arrives as `data`.
        self.length_remaining = 0
        self._data = data

    @property
    def data(self) -> bytes:
        """The complete response body handed to the constructor."""
        return self._data

    def get_redirect_location(self) -> None:
        """Always return ``None``."""
        return None

    def close(self) -> None:
        """Do nothing."""
        pass
|