12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315 |
- import asyncio
- import codecs
- import contextlib
- import functools
- import io
- import re
- import sys
- import traceback
- import warnings
- from hashlib import md5, sha1, sha256
- from http.cookies import CookieError, Morsel, SimpleCookie
- from types import MappingProxyType, TracebackType
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- Iterable,
- List,
- Mapping,
- NamedTuple,
- Optional,
- Tuple,
- Type,
- Union,
- )
- import attr
- from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
- from yarl import URL
- from . import hdrs, helpers, http, multipart, payload
- from .abc import AbstractStreamWriter
- from .client_exceptions import (
- ClientConnectionError,
- ClientOSError,
- ClientResponseError,
- ContentTypeError,
- InvalidURL,
- ServerFingerprintMismatch,
- )
- from .compression_utils import HAS_BROTLI
- from .formdata import FormData
- from .helpers import (
- _SENTINEL,
- BaseTimerContext,
- BasicAuth,
- HeadersMixin,
- TimerNoop,
- basicauth_from_netrc,
- netrc_from_env,
- noop,
- reify,
- set_exception,
- set_result,
- )
- from .http import (
- SERVER_SOFTWARE,
- HttpVersion,
- HttpVersion10,
- HttpVersion11,
- StreamWriter,
- )
- from .log import client_logger
- from .streams import StreamReader
- from .typedefs import (
- DEFAULT_JSON_DECODER,
- JSONDecoder,
- LooseCookies,
- LooseHeaders,
- Query,
- RawHeaders,
- )
- if TYPE_CHECKING:
- import ssl
- from ssl import SSLContext
- else:
- try:
- import ssl
- from ssl import SSLContext
- except ImportError: # pragma: no cover
- ssl = None # type: ignore[assignment]
- SSLContext = object # type: ignore[misc,assignment]
- __all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint")
- if TYPE_CHECKING:
- from .client import ClientSession
- from .connector import Connection
- from .tracing import Trace
# Matches any single character that is not allowed in an HTTP token.
# ClientRequest.__init__ uses it to reject invalid request methods.
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
# Matches a Content-Type that starts with "application/json" or with an
# "application/<something>+json" structured-suffix variant.
json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json")
def _gen_default_accept_encoding() -> str:
    """Return the default ``Accept-Encoding`` header value.

    ``br`` is advertised only when ``HAS_BROTLI`` is true.
    """
    if HAS_BROTLI:
        return "gzip, deflate, br"
    return "gzip, deflate"
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ContentDisposition:
    """Immutable parsed form of a ``Content-Disposition`` header."""

    # Disposition type as returned by the multipart parser.
    type: Optional[str]
    # Read-only view of the header parameters.
    parameters: "MappingProxyType[str, str]"
    # File name derived from the parameters, if any.
    filename: Optional[str]
class _RequestInfo(NamedTuple):
    """Plain tuple layout shared by ``RequestInfo``."""

    # Request URL; ClientRequest fills this with the fragment-less URL.
    url: URL
    # HTTP method of the request.
    method: str
    # Read-only view of the request headers.
    headers: "CIMultiDictProxy[str]"
    # URL as originally requested; ClientRequest passes its original_url.
    real_url: URL
class RequestInfo(_RequestInfo):
    """Request description tuple with an optional ``real_url`` field."""

    def __new__(
        cls,
        url: URL,
        method: str,
        headers: "CIMultiDictProxy[str]",
        real_url: URL = _SENTINEL,  # type: ignore[assignment]
    ) -> "RequestInfo":
        """Build a RequestInfo.

        ``real_url`` may be omitted for backwards compatibility, in which
        case it falls back to ``url``.
        """
        if real_url is _SENTINEL:
            real_url = url
        return tuple.__new__(cls, (url, method, headers, real_url))
class Fingerprint:
    """Expected digest of the peer certificate of a TLS connection."""

    # Hash constructor selected by the byte length of the digest.
    HASHFUNC_BY_DIGESTLEN = {
        16: md5,
        20: sha1,
        32: sha256,
    }

    def __init__(self, fingerprint: bytes) -> None:
        """Store *fingerprint* and pick the hash matching its length.

        Raises:
            ValueError: if the length matches no known digest, or matches
                md5/sha1, which are rejected as insecure.
        """
        hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(len(fingerprint))
        if not hashfunc:
            raise ValueError("fingerprint has invalid length")
        if hashfunc is md5 or hashfunc is sha1:
            raise ValueError("md5 and sha1 are insecure and not supported. Use sha256.")
        self._hashfunc = hashfunc
        self._fingerprint = fingerprint

    @property
    def fingerprint(self) -> bytes:
        """The expected digest bytes given at construction."""
        return self._fingerprint

    def check(self, transport: asyncio.Transport) -> None:
        """Compare the peer certificate digest with the expected one.

        Does nothing when the transport reports no ``sslcontext``.

        Raises:
            ServerFingerprintMismatch: if the digests differ.
        """
        if not transport.get_extra_info("sslcontext"):
            return
        ssl_object = transport.get_extra_info("ssl_object")
        peer_cert = ssl_object.getpeercert(binary_form=True)
        actual = self._hashfunc(peer_cert).digest()
        if actual == self._fingerprint:
            return
        host, port, *_ = transport.get_extra_info("peername")
        raise ServerFingerprintMismatch(self._fingerprint, actual, host, port)
# Types accepted for the ``ssl`` argument; checked by _merge_ssl_params.
if ssl is None:  # pragma: no cover
    # The ssl module could not be imported.
    SSL_ALLOWED_TYPES = (bool, type(None))
else:
    SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
def _merge_ssl_params(
    ssl: Union["SSLContext", bool, Fingerprint],
    verify_ssl: Optional[bool],
    ssl_context: Optional["SSLContext"],
    fingerprint: Optional[bytes],
) -> Union["SSLContext", bool, Fingerprint]:
    """Fold the deprecated TLS arguments into a single ``ssl`` value.

    Each of ``verify_ssl`` (when false), ``ssl_context`` and ``fingerprint``
    emits a DeprecationWarning and replaces ``ssl``, but only while ``ssl``
    is still ``True``.

    Raises:
        ValueError: if a deprecated argument is combined with a
            non-default ``ssl``.
        TypeError: if the resulting value is not one of SSL_ALLOWED_TYPES.
    """
    exclusive_msg = (
        "verify_ssl, ssl_context, fingerprint and ssl "
        "parameters are mutually exclusive"
    )

    if ssl is None:
        # Double check for backwards compatibility
        ssl = True

    if verify_ssl is not None and not verify_ssl:
        warnings.warn(
            "verify_ssl is deprecated, use ssl=False instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if ssl is not True:
            raise ValueError(exclusive_msg)
        ssl = False

    if ssl_context is not None:
        warnings.warn(
            "ssl_context is deprecated, use ssl=context instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if ssl is not True:
            raise ValueError(exclusive_msg)
        ssl = ssl_context

    if fingerprint is not None:
        warnings.warn(
            "fingerprint is deprecated, use ssl=Fingerprint(fingerprint) instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if ssl is not True:
            raise ValueError(exclusive_msg)
        ssl = Fingerprint(fingerprint)

    if isinstance(ssl, SSL_ALLOWED_TYPES):
        return ssl
    raise TypeError(
        f"ssl should be SSLContext, bool, Fingerprint or None, got {ssl!r} instead."
    )
# URL schemes for which ClientRequest reports the connection as SSL.
_SSL_SCHEMES = frozenset(("https", "wss"))
# ConnectionKey is a NamedTuple because it is used as a key in a dict
# and a set in the connector. Since a NamedTuple is a tuple, it uses
# the fast native tuple __hash__ and __eq__ implementation in CPython.
class ConnectionKey(NamedTuple):
    """Key identifying connections that may be reused for a request."""

    # The key must contain information about the proxy / TLS settings used,
    # to prevent reusing the wrong connection from a pool.
    host: str
    port: Optional[int]
    is_ssl: bool
    ssl: Union[SSLContext, bool, Fingerprint]
    proxy: Optional[URL]
    proxy_auth: Optional[BasicAuth]
    proxy_headers_hash: Optional[int]  # hash(CIMultiDict)
def _is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether a response content type satisfies the expected one.

    ``application/json`` is matched with ``json_re``; any other expected
    value is a plain substring test.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return json_re.match(response_content_type) is not None
class ClientRequest:
    """A single outgoing HTTP request and the state needed to send it."""

    GET_METHODS = {
        hdrs.METH_GET,
        hdrs.METH_HEAD,
        hdrs.METH_OPTIONS,
        hdrs.METH_TRACE,
    }
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS | POST_METHODS | {hdrs.METH_DELETE}

    # Headers added automatically unless present or explicitly skipped.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: "*/*",
        hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(),
    }

    # The concrete body type depends on PAYLOAD_REGISTRY, which is dynamic.
    body: Any = b""
    auth = None
    response = None

    __writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    _skip_auto_headers: Optional["CIMultiDict[None]"] = None

    # N.B.
    # Adding a __del__ method that closes self._writer doesn't make sense:
    # _writer is an instance method, so it keeps a reference to self.
    # The finalizer will not be called until the writer has finished.
def __init__(
    self,
    method: str,
    url: URL,
    *,
    params: Query = None,
    headers: Optional[LooseHeaders] = None,
    skip_auto_headers: Optional[Iterable[str]] = None,
    data: Any = None,
    cookies: Optional[LooseCookies] = None,
    auth: Optional[BasicAuth] = None,
    version: http.HttpVersion = http.HttpVersion11,
    compress: Union[str, bool, None] = None,
    chunked: Optional[bool] = None,
    expect100: bool = False,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    response_class: Optional[Type["ClientResponse"]] = None,
    proxy: Optional[URL] = None,
    proxy_auth: Optional[BasicAuth] = None,
    timer: Optional[BaseTimerContext] = None,
    session: Optional["ClientSession"] = None,
    ssl: Union[SSLContext, bool, Fingerprint] = True,
    proxy_headers: Optional[LooseHeaders] = None,
    traces: Optional[List["Trace"]] = None,
    trust_env: bool = False,
    server_hostname: Optional[str] = None,
):
    """Validate the arguments and prepare headers, auth, proxy and body.

    Raises:
        ValueError: if ``method`` contains a non-token character.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    bad_char = _CONTAINS_CONTROL_CHAR_RE.search(method)
    if bad_char is not None:
        raise ValueError(
            f"Method cannot contain non-token characters {method!r} "
            f"(found at least {bad_char.group()!r})"
        )

    # URL forbids subclasses, so an exact type check is sufficient.
    assert type(url) is URL, url
    assert proxy is None or type(proxy) is URL, proxy

    # FIXME: session is None in tests only, need to fix tests
    # assert session is not None
    if TYPE_CHECKING:
        assert session is not None
    self._session = session

    if params:
        url = url.extend_query(params)
    self.original_url = url
    # The stored request URL never carries a fragment.
    if url.raw_fragment:
        self.url = url.with_fragment(None)
    else:
        self.url = url

    self.method = method.upper()
    self.chunked = chunked
    self.compress = compress
    self.loop = loop
    self.length = None
    self.response_class: Type[ClientResponse] = (
        ClientResponse if response_class is None else response_class
    )
    self._timer = TimerNoop() if timer is None else timer
    self._ssl = True if ssl is None else ssl
    self.server_hostname = server_hostname

    if loop.get_debug():
        # Must stay inline: the frame offset is relative to __init__.
        self._source_traceback = traceback.extract_stack(sys._getframe(1))

    # The update_* steps below run in the same order as before; several of
    # them read headers or flags set by an earlier step.
    self.update_version(version)
    self.update_host(url)
    self.update_headers(headers)
    self.update_auto_headers(skip_auto_headers)
    self.update_cookies(cookies)
    self.update_content_encoding(data)
    self.update_auth(auth, trust_env)
    self.update_proxy(proxy, proxy_auth, proxy_headers)

    self.update_body_from_data(data)
    if data is not None or self.method not in self.GET_METHODS:
        self.update_transfer_encoding()
    self.update_expect_continue(expect100)

    self._traces = traces if traces is not None else []
def __reset_writer(self, _: object = None) -> None:
    """Forget the writer task; registered as its done-callback."""
    self.__writer = None
@property
def skip_auto_headers(self) -> CIMultiDict[None]:
    """Headers excluded from auto-generation (a new empty dict if none)."""
    skipped = self._skip_auto_headers
    return skipped if skipped else CIMultiDict()
@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
    """The task currently writing the request body, if any."""
    return self.__writer


@_writer.setter
def _writer(self, writer: "asyncio.Task[None]") -> None:
    """Track *writer* and arrange for it to be forgotten once done."""
    previous = self.__writer
    if previous is not None:
        # Stop the old task from resetting the new one on completion.
        previous.remove_done_callback(self.__reset_writer)
    self.__writer = writer
    writer.add_done_callback(self.__reset_writer)
def is_ssl(self) -> bool:
    """Return True when the request URL scheme is in ``_SSL_SCHEMES``."""
    scheme = self.url.scheme
    return scheme in _SSL_SCHEMES
@property
def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
    """The TLS setting stored for this request at construction."""
    return self._ssl
@property
def connection_key(self) -> ConnectionKey:
    """Build the key the connector uses to look up reusable connections."""
    proxy_headers = self.proxy_headers
    headers_hash: Optional[int] = (
        hash(tuple(proxy_headers.items())) if proxy_headers else None
    )
    url = self.url
    fields = (
        url.raw_host or "",
        url.port,
        url.scheme in _SSL_SCHEMES,
        self._ssl,
        self.proxy,
        self.proxy_auth,
        headers_hash,
    )
    # Instantiate through tuple.__new__ directly, as the original does.
    return tuple.__new__(ConnectionKey, fields)
@property
def host(self) -> str:
    """The raw host of the request URL (asserted to be present)."""
    raw_host = self.url.raw_host
    assert raw_host is not None
    return raw_host
@property
def port(self) -> Optional[int]:
    """The port of the request URL."""
    url = self.url
    return url.port
@property
def request_info(self) -> RequestInfo:
    """Snapshot of URL, method, headers and original URL as a RequestInfo."""
    headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers)
    fields = (self.url, self.method, headers, self.original_url)
    # Created on every request, so a NamedTuple is used for performance.
    # RequestInfo.__new__ is bypassed because its signature differs and
    # exists for backwards compatibility only.
    return tuple.__new__(RequestInfo, fields)
def update_host(self, url: URL) -> None:
    """Validate that *url* has a host and pick up credentials from it.

    Raises:
        InvalidURL: if the URL has no host.
    """
    if not url.raw_host:
        raise InvalidURL(url)

    has_credentials = url.raw_user or url.raw_password
    if has_credentials:
        # Basic auth info embedded in the URL.
        self.auth = helpers.BasicAuth(url.user or "", url.password or "")
def update_version(self, version: Union[http.HttpVersion, str]) -> None:
    """Store the HTTP version, parsing it from a string when needed.

    A string such as ``'1.1'`` becomes ``HttpVersion(1, 1)``; any other
    value is stored unchanged.

    Raises:
        ValueError: if a string version is not ``"<major>.<minor>"`` with
            integer components. This includes strings with no ``"."``,
            which previously escaped as an IndexError.
    """
    if isinstance(version, str):
        try:
            # Unpacking raises ValueError when the "." separator is missing
            # and int() raises it for a non-numeric part, so both cases are
            # reported as the same parse error.
            major, minor = (int(part.strip()) for part in version.split(".", 1))
        except ValueError:
            raise ValueError(
                f"Can not parse http version number: {version}"
            ) from None
        version = http.HttpVersion(major, minor)
    self.version = version
def update_headers(self, headers: Optional[LooseHeaders]) -> None:
    """Reset the request headers and merge in the user-supplied ones.

    The ``Host`` header is always derived from the URL first. A
    user-supplied Host header replaces it; every other header is appended,
    so repeated names are kept.

    ``headers`` may be any mapping or an iterable of ``(name, value)``
    pairs.
    """
    self.headers: CIMultiDict[str] = CIMultiDict()

    # Build the host header
    host = self.url.host_port_subcomponent
    # host_port_subcomponent is None when the URL is a relative URL,
    # but we know we do not have a relative URL here.
    assert host is not None
    self.headers[hdrs.HOST] = host

    if not headers:
        return

    # Accept every Mapping, as update_cookies does, not only dict/MultiDict.
    # Other mappings would otherwise be iterated as bare keys and fail to
    # unpack into (key, value).
    if isinstance(headers, Mapping):
        headers = headers.items()

    for key, value in headers:  # type: ignore[misc]
        # A special case for Host header
        if key in hdrs.HOST_ALL:
            self.headers[key] = value
        else:
            self.headers.add(key, value)
def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None:
    """Add the default headers and User-Agent unless present or skipped."""
    if skip_auto_headers is None:
        # Fast path when there are no headers to skip,
        # which is the most common case.
        used_headers = self.headers
    else:
        self._skip_auto_headers = CIMultiDict(
            (hdr, None) for hdr in sorted(skip_auto_headers)
        )
        # Treat skipped names as already used, without touching self.headers.
        used_headers = self.headers.copy()
        used_headers.extend(self._skip_auto_headers)  # type: ignore[arg-type]

    for name, default in self.DEFAULT_HEADERS.items():
        if name not in used_headers:
            self.headers[name] = default

    if hdrs.USER_AGENT not in used_headers:
        self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE
def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
    """Merge *cookies* with any existing ``Cookie`` header and rewrite it."""
    if not cookies:
        return

    jar = SimpleCookie()
    if hdrs.COOKIE in self.headers:
        # Start from the cookies already present in the headers.
        jar.load(self.headers.get(hdrs.COOKIE, ""))
        del self.headers[hdrs.COOKIE]

    pairs = cookies.items() if isinstance(cookies, Mapping) else cookies
    for name, value in pairs:  # type: ignore[misc]
        if not isinstance(value, Morsel):
            jar[name] = value  # type: ignore[assignment]
            continue
        # Preserve coded_value
        morsel = value.get(value.key, Morsel())
        morsel.set(value.key, value.value, value.coded_value)
        jar[name] = morsel

    self.headers[hdrs.COOKIE] = jar.output(header="", sep=";").strip()
def update_content_encoding(self, data: Any) -> None:
    """Set the ``Content-Encoding`` header from the ``compress`` option.

    Raises:
        ValueError: if ``compress`` is set while a Content-Encoding header
            is already present.
    """
    if not data:
        # Don't compress an empty body.
        self.compress = None
        return

    if self.headers.get(hdrs.CONTENT_ENCODING):
        if self.compress:
            raise ValueError(
                "compress can not be set if Content-Encoding header is set"
            )
        return

    if not self.compress:
        return

    # Any non-string truthy value selects the default algorithm.
    if not isinstance(self.compress, str):
        self.compress = "deflate"
    self.headers[hdrs.CONTENT_ENCODING] = self.compress
    self.chunked = True  # enable chunked, no need to deal with length
def update_transfer_encoding(self) -> None:
    """Reconcile ``chunked`` with Transfer-Encoding and Content-Length.

    Raises:
        ValueError: if ``chunked`` is set together with a chunked
            Transfer-Encoding header or with a Content-Length header.
    """
    headers = self.headers
    transfer_encoding = headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    if "chunked" in transfer_encoding:
        if self.chunked:
            raise ValueError(
                "chunked can not be set "
                'if "Transfer-Encoding: chunked" header is set'
            )
        return

    if self.chunked:
        if hdrs.CONTENT_LENGTH in headers:
            raise ValueError(
                "chunked can not be set if Content-Length header is set"
            )
        headers[hdrs.TRANSFER_ENCODING] = "chunked"
        return

    # Not chunked: make sure the body length is announced.
    if hdrs.CONTENT_LENGTH not in headers:
        headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None:
    """Set the ``Authorization`` header from basic-auth credentials.

    Credentials come from *auth*, else from ``self.auth``, else (when
    ``trust_env`` is true) from the netrc entry for the URL host.

    Raises:
        TypeError: if the credentials are not a ``BasicAuth`` instance.
    """
    if auth is None:
        auth = self.auth
    if auth is None and trust_env and self.url.host is not None:
        netrc_obj = netrc_from_env()
        try:
            auth = basicauth_from_netrc(netrc_obj, self.url.host)
        except LookupError:
            # No usable netrc entry: continue without credentials.
            pass
    if auth is None:
        return

    if not isinstance(auth, helpers.BasicAuth):
        raise TypeError("BasicAuth() tuple is required instead")

    self.headers[hdrs.AUTHORIZATION] = auth.encode()
def update_body_from_data(self, body: Any) -> None:
    """Wrap *body* in a payload and derive length/chunking and headers."""
    if body is None:
        return

    # FormData
    if isinstance(body, FormData):
        body = body()

    try:
        body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
    except payload.LookupError:
        # No registered payload type: fall back to form encoding.
        body = FormData(body)()

    self.body = body
    request_headers = self.headers

    # enable chunked encoding if needed
    if not self.chunked and hdrs.CONTENT_LENGTH not in request_headers:
        size = body.size
        if size is None:
            self.chunked = True
        else:
            request_headers[hdrs.CONTENT_LENGTH] = str(size)

    # copy payload headers
    assert body.headers
    skipped = self._skip_auto_headers
    for name, value in body.headers.items():
        if name in request_headers:
            continue
        if skipped is not None and name in skipped:
            continue
        request_headers[name] = value
def update_expect_continue(self, expect: bool = False) -> None:
    """Arm the ``100 Continue`` waiter when requested or already in headers."""
    headers = self.headers
    if expect:
        headers[hdrs.EXPECT] = "100-continue"
    else:
        expect = (
            hdrs.EXPECT in headers
            and headers[hdrs.EXPECT].lower() == "100-continue"
        )

    if expect:
        self._continue = self.loop.create_future()
def update_proxy(
    self,
    proxy: Optional[URL],
    proxy_auth: Optional[BasicAuth],
    proxy_headers: Optional[LooseHeaders],
) -> None:
    """Store the proxy settings; without a proxy, auth and headers are cleared.

    Raises:
        ValueError: if ``proxy_auth`` is truthy but not a ``BasicAuth``.
    """
    self.proxy = proxy
    if proxy is None:
        self.proxy_auth = None
        self.proxy_headers = None
        return

    if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
        raise ValueError("proxy_auth must be None or BasicAuth() tuple")
    self.proxy_auth = proxy_auth

    needs_wrapping = proxy_headers is not None and not isinstance(
        proxy_headers, (MultiDict, MultiDictProxy)
    )
    if needs_wrapping:
        proxy_headers = CIMultiDict(proxy_headers)
    self.proxy_headers = proxy_headers
async def write_bytes(
    self, writer: AbstractStreamWriter, conn: "Connection"
) -> None:
    """Write the request body to *writer*, then signal end of body.

    Write failures are recorded on the connection protocol through
    ``set_exception`` rather than raised. Cancellation closes the
    connection and is re-raised.
    """
    # 100 response
    if self._continue is not None:
        await writer.drain()
        await self._continue

    protocol = conn.protocol
    assert protocol is not None
    try:
        if isinstance(self.body, payload.Payload):
            await self.body.write(writer)
        else:
            if isinstance(self.body, (bytes, bytearray)):
                # Wrap raw bytes so they are written as a single chunk.
                self.body = (self.body,)

            for chunk in self.body:
                await writer.write(chunk)
    except OSError as underlying_exc:
        # An OSError with no errno that is a TimeoutError is passed on
        # unchanged; everything else is wrapped in ClientOSError.
        is_bare_timeout = underlying_exc.errno is None and isinstance(
            underlying_exc, asyncio.TimeoutError
        )
        if is_bare_timeout:
            reraised_exc = underlying_exc
        else:
            reraised_exc = ClientOSError(
                underlying_exc.errno,
                f"Can not write request body for {self.url !s}",
            )
        set_exception(protocol, reraised_exc, underlying_exc)
    except asyncio.CancelledError:
        # Body hasn't been fully sent, so connection can't be reused.
        conn.close()
        raise
    except Exception as underlying_exc:
        set_exception(
            protocol,
            ClientConnectionError(
                f"Failed to send bytes into the underlying connection {conn !s}",
            ),
            underlying_exc,
        )
    else:
        await writer.write_eof()
        protocol.start_timeout()
async def send(self, conn: "Connection") -> "ClientResponse":
    """Write the request line and headers, start the body, build the response.

    Returns the response object created from ``self.response_class``; it is
    also stored on ``self.response``. The body, if any, is written by a
    ``write_bytes`` task, which is handed to the response unless it has
    already finished.
    """
    # Specify request target:
    # - CONNECT request must send authority form URI
    # - not CONNECT proxy must send absolute form URI
    # - most common is origin form URI
    if self.method == hdrs.METH_CONNECT:
        connect_host = self.url.host_subcomponent
        assert connect_host is not None
        path = f"{connect_host}:{self.url.port}"
    elif self.proxy and not self.is_ssl():
        path = str(self.url)
    else:
        path = self.url.raw_path_qs

    protocol = conn.protocol
    assert protocol is not None
    writer = StreamWriter(
        protocol,
        self.loop,
        on_chunk_sent=(
            functools.partial(self._on_chunk_request_sent, self.method, self.url)
            if self._traces
            else None
        ),
        on_headers_sent=(
            functools.partial(self._on_headers_request_sent, self.method, self.url)
            if self._traces
            else None
        ),
    )

    if self.compress:
        writer.enable_compression(self.compress)  # type: ignore[arg-type]

    # Only a true ``chunked`` may switch the writer to chunked framing.
    # The old ``is not None`` test also fired for an explicit
    # ``chunked=False``, producing a chunk-framed body although
    # update_transfer_encoding() had announced a Content-Length.
    if self.chunked:
        writer.enable_chunking()

    # set default content-type
    if (
        self.method in self.POST_METHODS
        and (
            self._skip_auto_headers is None
            or hdrs.CONTENT_TYPE not in self._skip_auto_headers
        )
        and hdrs.CONTENT_TYPE not in self.headers
    ):
        self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

    v = self.version
    if hdrs.CONNECTION not in self.headers:
        if conn._connector.force_close:
            if v == HttpVersion11:
                self.headers[hdrs.CONNECTION] = "close"
        elif v == HttpVersion10:
            self.headers[hdrs.CONNECTION] = "keep-alive"

    # status + headers
    status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
    await writer.write_headers(status_line, self.headers)

    task: Optional["asyncio.Task[None]"]
    if self.body or self._continue is not None or protocol.writing_paused:
        coro = self.write_bytes(writer, conn)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to write
            # bytes immediately to avoid having to schedule
            # the task on the event loop.
            task = asyncio.Task(coro, loop=self.loop, eager_start=True)
        else:
            task = self.loop.create_task(coro)
        if task.done():
            # Body already fully written: nothing left to track.
            task = None
        else:
            self._writer = task
    else:
        # We have nothing to write because
        # - there is no body
        # - the protocol does not have writing paused
        # - we are not waiting for a 100-continue response
        protocol.start_timeout()
        writer.set_eof()
        task = None

    response_class = self.response_class
    assert response_class is not None
    self.response = response_class(
        self.method,
        self.original_url,
        writer=task,
        continue100=self._continue,
        timer=self._timer,
        request_info=self.request_info,
        traces=self._traces,
        loop=self.loop,
        session=self._session,
    )
    return self.response
async def close(self) -> None:
    """Wait for the body writer task, if one is still running.

    A CancelledError from the writer is swallowed unless (Python 3.11+)
    the current task is itself being cancelled.
    """
    writer = self.__writer
    if writer is None:
        return
    try:
        await writer
    except asyncio.CancelledError:
        if sys.version_info < (3, 11):
            return
        current = asyncio.current_task()
        if current and current.cancelling():
            raise
def terminate(self) -> None:
    """Cancel the body writer task (if the loop is open) and forget it."""
    writer = self.__writer
    if writer is None:
        return
    if not self.loop.is_closed():
        writer.cancel()
    writer.remove_done_callback(self.__reset_writer)
    self.__writer = None
async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
    """Notify every trace, in order, that a body chunk was sent."""
    for tracer in self._traces:
        await tracer.send_request_chunk_sent(method, url, chunk)
async def _on_headers_request_sent(
    self, method: str, url: URL, headers: "CIMultiDict[str]"
) -> None:
    """Notify every trace, in order, that the request headers were sent."""
    for tracer in self._traces:
        await tracer.send_request_headers(method, url, headers)
# Shared ClientConnectionError instance with the message "Connection closed".
_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed")
class ClientResponse(HeadersMixin):
    """Response side of a client HTTP exchange."""

    # Some of these attributes are None when created,
    # but will be set by the start() method.
    # As the end user will likely never see the None values,
    # we cheat the types below.

    # from the Status-Line of the response
    version: Optional[HttpVersion] = None  # HTTP-Version
    status: int = None  # type: ignore[assignment] # Status-Code
    reason: Optional[str] = None  # Reason-Phrase

    content: StreamReader = None  # type: ignore[assignment] # Payload stream
    _body: Optional[bytes] = None
    _headers: CIMultiDictProxy[str] = None  # type: ignore[assignment]
    _history: Tuple["ClientResponse", ...] = ()
    _raw_headers: RawHeaders = None  # type: ignore[assignment]

    _connection: Optional["Connection"] = None  # current connection
    _cookies: Optional[SimpleCookie] = None
    _continue: Optional["asyncio.Future[bool]"] = None
    _source_traceback: Optional[traceback.StackSummary] = None
    _session: Optional["ClientSession"] = None
    # set up by ClientRequest after ClientResponse object creation
    # post-init stage allows to not change ctor signature
    _closed = True  # lets __del__ return early on a response that never started
    _released = False
    _in_context = False

    # Default charset resolver; __init__ replaces it with the session's one.
    _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8"

    __writer: Optional["asyncio.Task[None]"] = None
def __init__(
    self,
    method: str,
    url: URL,
    *,
    writer: "Optional[asyncio.Task[None]]",
    continue100: Optional["asyncio.Future[bool]"],
    timer: BaseTimerContext,
    request_info: RequestInfo,
    traces: List["Trace"],
    loop: asyncio.AbstractEventLoop,
    session: "ClientSession",
) -> None:
    """Record the request context this response belongs to."""
    # URL forbids subclasses, so an exact type check is sufficient.
    assert type(url) is URL

    self.method = method

    self._real_url = url
    # The public URL never carries a fragment.
    if url.raw_fragment:
        self._url = url.with_fragment(None)
    else:
        self._url = url

    if writer is not None:
        self._writer = writer
    if continue100 is not None:
        self._continue = continue100

    self._request_info = request_info
    self._timer = TimerNoop() if timer is None else timer
    self._cache: Dict[str, Any] = {}
    self._traces = traces
    self._loop = loop

    # Save reference to _resolve_charset, so that get_encoding() will still
    # work after the response has finished reading the body.
    # TODO: Fix session=None in tests (see ClientRequest.__init__).
    if session is not None:
        # store a reference to session #1985
        self._session = session
        self._resolve_charset = session._resolve_charset

    if loop.get_debug():
        # Must stay inline: the frame offset is relative to __init__.
        self._source_traceback = traceback.extract_stack(sys._getframe(1))
def __reset_writer(self, _: object = None) -> None:
    """Forget the writer task; registered as its done-callback."""
    self.__writer = None
@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
    """The writer task for streaming data.

    _writer is only provided for backwards compatibility
    for subclasses that may need to access it.
    """
    return self.__writer


@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
    """Set the writer task for streaming data."""
    previous = self.__writer
    if previous is not None:
        previous.remove_done_callback(self.__reset_writer)

    if writer is None or writer.done():
        # No task, or it has already finished: there is nothing to track.
        self.__writer = None
        return

    self.__writer = writer
    writer.add_done_callback(self.__reset_writer)
@property
def cookies(self) -> SimpleCookie:
    """The response cookies, created empty on first access."""
    jar = self._cookies
    if jar is None:
        jar = self._cookies = SimpleCookie()
    return jar


@cookies.setter
def cookies(self, cookies: SimpleCookie) -> None:
    """Replace the stored cookie jar."""
    self._cookies = cookies
@reify
def url(self) -> URL:
    """The response URL with any fragment removed."""
    return self._url
@reify
def url_obj(self) -> URL:
    """Deprecated alias of ``url``; emits a DeprecationWarning."""
    warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2)
    return self._url
@reify
def real_url(self) -> URL:
    """The URL exactly as passed to the constructor (fragment kept)."""
    return self._real_url
@reify
def host(self) -> str:
    """The host of the response URL (asserted to be present)."""
    host = self._url.host
    assert host is not None
    return host
@reify
def headers(self) -> "CIMultiDictProxy[str]":
    """The stored response headers (``_headers``)."""
    return self._headers
@reify
def raw_headers(self) -> RawHeaders:
    """The stored raw response headers (``_raw_headers``)."""
    return self._raw_headers
@reify
def request_info(self) -> RequestInfo:
    """The RequestInfo passed to the constructor."""
    return self._request_info
@reify
def content_disposition(self) -> Optional[ContentDisposition]:
    """Parsed ``Content-Disposition`` header, or None when it is absent."""
    header = self._headers.get(hdrs.CONTENT_DISPOSITION)
    if header is None:
        return None
    kind, raw_params = multipart.parse_content_disposition(header)
    read_only_params = MappingProxyType(raw_params)
    return ContentDisposition(
        kind,
        read_only_params,
        multipart.content_disposition_filename(read_only_params),
    )
def __del__(self, _warnings: Any = warnings) -> None:
    """Release a still-held connection and report the unclosed response."""
    if self._closed:
        return

    connection = self._connection
    if connection is None:
        return

    connection.release()
    self._cleanup_writer()

    if not self._loop.get_debug():
        return

    # In debug mode, warn and report through the loop's exception handler.
    _warnings.warn(f"Unclosed response {self!r}", ResourceWarning, source=self)
    context = {"client_response": self, "message": "Unclosed response"}
    if self._source_traceback:
        context["source_traceback"] = self._source_traceback
    self._loop.call_exception_handler(context)
def __repr__(self) -> str:
    """Return a two-line summary: the status line, then the headers."""
    shown_url = str(self.url)
    reason = self.reason
    if reason:
        # Escape non-ASCII characters so the result is ASCII-safe.
        shown_reason = reason.encode("ascii", "backslashreplace").decode("ascii")
    else:
        shown_reason = "None"
    summary = "<ClientResponse({}) [{} {}]>".format(
        shown_url, self.status, shown_reason
    )
    return summary + "\n" + str(self.headers) + "\n"
@property
def connection(self) -> Optional["Connection"]:
    """Return the connection held in ``self._connection`` (may be ``None``)."""
    return self._connection
@reify
def history(self) -> Tuple["ClientResponse", ...]:
    """A sequence of responses, if redirects occurred."""
    return self._history
@reify
def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]":
    """Parse the ``link`` response headers into a read-only multidict.

    All ``link`` header values are joined with ", " and split again on
    commas that are followed by ``<``. Each entry is stored under its
    ``rel`` parameter (falling back to the raw target when ``rel`` is
    absent) and holds its parameters plus a ``url`` item obtained by
    joining the target onto ``self.url``. An empty proxy is returned when
    there is no ``link`` header content.
    """
    header_text = ", ".join(self.headers.getall("link", []))
    if not header_text:
        return MultiDictProxy(MultiDict())

    parsed: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict()

    for entry in re.split(r",(?=\s*<)", header_text):
        entry_match = re.match(r"\s*<(.*)>(.*)", entry)
        if entry_match is None:  # pragma: no cover
            # the check exists to suppress mypy error
            continue
        target, raw_params = entry_match.groups()

        attrs: MultiDict[Union[str, URL]] = MultiDict()
        # The text before the first ";" is not a parameter, so skip it.
        for raw_param in raw_params.split(";")[1:]:
            param_match = re.match(
                r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", raw_param, re.M
            )
            if param_match is None:  # pragma: no cover
                # the check exists to suppress mypy error
                continue
            # Group 1 is the parameter name, group 3 the unquoted value.
            attrs.add(param_match.group(1), param_match.group(3))

        rel = attrs.get("rel", target)
        attrs.add("url", self.url.join(URL(target)))
        parsed.add(str(rel), MultiDictProxy(attrs))

    return MultiDictProxy(parsed)
async def start(self, connection: "Connection") -> "ClientResponse":
    """Start response processing.

    Binds *connection* to the response, then reads messages from its
    protocol until one arrives whose status code is outside 100-199 or is
    101. The status line, headers, payload and any ``Set-Cookie`` values
    of that message are recorded on the response.

    Returns:
        This response object.

    Raises:
        ClientResponseError: if reading a message raises
            ``http.HttpProcessingError``.
    """
    self._closed = False
    self._protocol = connection.protocol
    self._connection = connection

    with self._timer:
        while True:
            # read response
            try:
                reader = self._protocol
                message, payload = await reader.read()  # type: ignore[union-attr]
            except http.HttpProcessingError as err:
                raise ClientResponseError(
                    self.request_info,
                    self.history,
                    status=err.code,
                    message=err.message,
                    headers=err.headers,
                ) from err

            status_code = message.code
            if status_code == 101 or not 100 <= status_code <= 199:
                break

            # Interim (1xx, non-101) message: signal ``_continue`` once via
            # ``set_result`` and keep reading.
            if self._continue is not None:
                set_result(self._continue, True)
                self._continue = None

    # payload eof handler
    payload.on_eof(self._response_eof)

    # response status
    self.version = message.version
    self.status = message.code
    self.reason = message.reason

    # headers
    self._headers = message.headers  # type is CIMultiDictProxy
    self._raw_headers = message.raw_headers  # type is Tuple[bytes, bytes]

    # payload
    self.content = payload

    # cookies
    set_cookie_values = self.headers.getall(hdrs.SET_COOKIE, ())
    if set_cookie_values:
        jar = SimpleCookie()
        for value in set_cookie_values:
            try:
                jar.load(value)
            except CookieError as err:
                # A value that fails to load is logged and skipped.
                client_logger.warning("Can not load response cookies: %s", err)
        self._cookies = jar
    return self
def _response_eof(self) -> None:
    """Payload EOF callback: close the response and release its connection.

    Does nothing when the response is already closed or when the
    connection's protocol reports ``upgraded``.
    """
    if self._closed:
        return

    # protocol could be None because connection could be detached
    protocol = self._connection and self._connection.protocol
    is_upgraded = protocol is not None and protocol.upgraded
    if is_upgraded:
        return

    self._closed = True
    self._cleanup_writer()
    self._release_connection()
@property
def closed(self) -> bool:
    """Return the value of the ``self._closed`` flag."""
    return self._closed
def close(self) -> None:
    """Mark the response closed and close its connection.

    Content is notified first unless the response was already released.
    Writer cleanup and connection closing are skipped when the loop is
    ``None`` or already closed.
    """
    if not self._released:
        self._notify_content()

    self._closed = True

    loop = self._loop
    if loop is not None and not loop.is_closed():
        self._cleanup_writer()
        conn = self._connection
        if conn is not None:
            conn.close()
            self._connection = None
def release(self) -> Any:
    """Mark the response closed and release its connection.

    Content is notified first unless the response was already released;
    the writer is then cleaned up and the connection released.

    Returns:
        The result of calling ``noop()``.
    """
    already_released = self._released
    if not already_released:
        self._notify_content()

    self._closed = True

    self._cleanup_writer()
    self._release_connection()
    return noop()
@property
def ok(self) -> bool:
    """Return ``True`` when ``status`` is below ``400``, otherwise ``False``.

    This is **not** a test for ``200 OK``: any status under 400 counts
    as ok.
    """
    return self.status < 400
def raise_for_status(self) -> None:
    """Raise ``ClientResponseError`` when the response is not ``ok``.

    Outside of a context manager the response is released before the
    error is raised.

    Raises:
        ClientResponseError: carrying the request info, history, status,
            reason and headers of this response.
    """
    if self.ok:
        return

    # reason should always be not None for a started response
    assert self.reason is not None

    # If we're in a context we can rely on __aexit__() to release as the
    # exception propagates.
    if not self._in_context:
        self.release()

    raise ClientResponseError(
        self.request_info,
        self.history,
        status=self.status,
        message=self.reason,
        headers=self.headers,
    )
def _release_connection(self) -> None:
    """Release the connection now, or once the writer is done.

    With no connection attached this is a no-op. While a writer is still
    set, a done-callback is registered on it that calls this method
    again; otherwise the connection is released and detached.
    """
    if self._connection is None:
        return

    if self.__writer is not None:
        # Retry after the writer finishes.
        self.__writer.add_done_callback(lambda f: self._release_connection())
        return

    self._connection.release()
    self._connection = None
async def _wait_released(self) -> None:
    """Await the writer, if any, then release the connection.

    A ``CancelledError`` raised while awaiting the writer is re-raised
    only on Python 3.11+ when the current task reports ``cancelling()``;
    otherwise it is swallowed and the connection is still released.
    """
    if self.__writer is not None:
        try:
            await self.__writer
        except asyncio.CancelledError:
            if sys.version_info >= (3, 11):
                current = asyncio.current_task()
                if current and current.cancelling():
                    raise
    self._release_connection()
def _cleanup_writer(self) -> None:
    """Cancel the writer, if one is set, and clear ``self._session``."""
    writer = self.__writer
    if writer is not None:
        writer.cancel()
    self._session = None
def _notify_content(self) -> None:
    """Mark the response released, flagging the content as connection-closed.

    When the content is truthy and has no exception set yet,
    ``_CONNECTION_CLOSED_EXCEPTION`` is set on it via ``set_exception``.
    ``self._released`` is set to ``True`` in every case.
    """
    stream = self.content
    if stream and stream.exception() is None:
        set_exception(stream, _CONNECTION_CLOSED_EXCEPTION)
    self._released = True
async def wait_for_close(self) -> None:
    """Await the writer, if any, then release the response.

    A ``CancelledError`` raised while awaiting the writer is re-raised
    only on Python 3.11+ when the current task reports ``cancelling()``;
    otherwise it is swallowed and ``release()`` is still called.
    """
    if self.__writer is not None:
        try:
            await self.__writer
        except asyncio.CancelledError:
            if sys.version_info >= (3, 11):
                current = asyncio.current_task()
                if current and current.cancelling():
                    raise
    self.release()
async def read(self) -> bytes:
    """Read response payload.

    The body is read from ``self.content`` once and cached in
    ``self._body``; every trace is then sent the received body. Any
    exception during that step closes the response before propagating.

    Returns:
        The cached body bytes.

    Raises:
        ClientConnectionError: if the body is already cached and the
            response has been released.
    """
    if self._body is not None:
        if self._released:  # Response explicitly released
            raise ClientConnectionError("Connection closed")
    else:
        try:
            self._body = await self.content.read()
            for trace in self._traces:
                await trace.send_response_chunk_received(
                    self.method, self.url, self._body
                )
        except BaseException:
            self.close()
            raise

    protocol = self._connection and self._connection.protocol
    is_upgraded = protocol is not None and protocol.upgraded
    if not is_upgraded:
        await self._wait_released()  # Underlying connection released
    return self._body
def get_encoding(self) -> str:
    """Work out the text encoding of the response body.

    Resolution order: a ``charset`` parameter of the ``Content-Type``
    header that ``codecs.lookup`` recognises; ``utf-8`` for
    ``application`` types whose subtype is ``json`` or ``rdap``; otherwise
    the result of ``self._resolve_charset`` applied to the body.

    Raises:
        RuntimeError: if the fallback is needed before the body was read.
    """
    content_type = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
    mimetype = helpers.parse_mimetype(content_type)

    charset = mimetype.parameters.get("charset")
    if charset:
        try:
            return codecs.lookup(charset).name
        except (LookupError, ValueError):
            # Unusable charset name: fall through to the other strategies.
            pass

    if mimetype.type == "application" and mimetype.subtype in ("json", "rdap"):
        # RFC 7159 states that the default encoding is UTF-8.
        # RFC 7483 defines application/rdap+json
        return "utf-8"

    if self._body is None:
        raise RuntimeError(
            "Cannot compute fallback encoding of a not yet read body"
        )

    return self._resolve_charset(self, self._body)
async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str:
    """Read response payload and decode.

    Args:
        encoding: codec used for decoding; when ``None`` the result of
            ``get_encoding()`` is used.
        errors: error handling scheme passed to ``bytes.decode``.

    Returns:
        The decoded body.
    """
    if self._body is None:
        await self.read()

    charset = encoding if encoding is not None else self.get_encoding()
    return self._body.decode(charset, errors=errors)  # type: ignore[union-attr]
async def json(
    self,
    *,
    encoding: Optional[str] = None,
    loads: JSONDecoder = DEFAULT_JSON_DECODER,
    content_type: Optional[str] = "application/json",
) -> Any:
    """Read the response body and decode it as JSON.

    Args:
        encoding: codec used to decode the body; when ``None`` the result
            of ``get_encoding()`` is used.
        loads: callable that turns the decoded text into Python objects.
        content_type: expected content type; a falsy value skips the check.

    Returns:
        The decoded document, or ``None`` when the body is empty after
        stripping whitespace.

    Raises:
        ContentTypeError: if the ``Content-Type`` header does not match
            *content_type*.
    """
    if self._body is None:
        await self.read()

    if content_type:
        actual_type = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
        if not _is_expected_content_type(actual_type, content_type):
            raise ContentTypeError(
                self.request_info,
                self.history,
                status=self.status,
                message=(
                    "Attempt to decode JSON with unexpected mimetype: %s"
                    % actual_type
                ),
                headers=self.headers,
            )

    payload = self._body.strip()  # type: ignore[union-attr]
    if not payload:
        return None

    charset = encoding if encoding is not None else self.get_encoding()
    return loads(payload.decode(charset))
async def __aenter__(self) -> "ClientResponse":
    """Enter the async context: set ``_in_context`` and return the response."""
    self._in_context = True
    return self
async def __aexit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> None:
    """Leave the async context: release the response and wait for close.

    The exception arguments are ignored; the same cleanup runs whether or
    not the body of the ``async with`` raised.
    """
    self._in_context = False
    # similar to _RequestContextManager, we do not need to check
    # for exceptions, response object can close connection
    # if state is broken
    self.release()
    await self.wait_for_close()
|