client_reqrep.py 43 KB


  1. import asyncio
  2. import codecs
  3. import contextlib
  4. import functools
  5. import io
  6. import re
  7. import sys
  8. import traceback
  9. import warnings
  10. from hashlib import md5, sha1, sha256
  11. from http.cookies import CookieError, Morsel, SimpleCookie
  12. from types import MappingProxyType, TracebackType
  13. from typing import (
  14. TYPE_CHECKING,
  15. Any,
  16. Callable,
  17. Dict,
  18. Iterable,
  19. List,
  20. Mapping,
  21. NamedTuple,
  22. Optional,
  23. Tuple,
  24. Type,
  25. Union,
  26. )
  27. import attr
  28. from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
  29. from yarl import URL
  30. from . import hdrs, helpers, http, multipart, payload
  31. from .abc import AbstractStreamWriter
  32. from .client_exceptions import (
  33. ClientConnectionError,
  34. ClientOSError,
  35. ClientResponseError,
  36. ContentTypeError,
  37. InvalidURL,
  38. ServerFingerprintMismatch,
  39. )
  40. from .compression_utils import HAS_BROTLI
  41. from .formdata import FormData
  42. from .helpers import (
  43. _SENTINEL,
  44. BaseTimerContext,
  45. BasicAuth,
  46. HeadersMixin,
  47. TimerNoop,
  48. basicauth_from_netrc,
  49. netrc_from_env,
  50. noop,
  51. reify,
  52. set_exception,
  53. set_result,
  54. )
  55. from .http import (
  56. SERVER_SOFTWARE,
  57. HttpVersion,
  58. HttpVersion10,
  59. HttpVersion11,
  60. StreamWriter,
  61. )
  62. from .log import client_logger
  63. from .streams import StreamReader
  64. from .typedefs import (
  65. DEFAULT_JSON_DECODER,
  66. JSONDecoder,
  67. LooseCookies,
  68. LooseHeaders,
  69. Query,
  70. RawHeaders,
  71. )
  72. if TYPE_CHECKING:
  73. import ssl
  74. from ssl import SSLContext
  75. else:
  76. try:
  77. import ssl
  78. from ssl import SSLContext
  79. except ImportError: # pragma: no cover
  80. ssl = None # type: ignore[assignment]
  81. SSLContext = object # type: ignore[misc,assignment]
  82. __all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint")
  83. if TYPE_CHECKING:
  84. from .client import ClientSession
  85. from .connector import Connection
  86. from .tracing import Trace
# Matches any character that is NOT an HTTP token character ("tchar"); used by
# ClientRequest.__init__ to reject request methods containing such characters.
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
# Matches "application/json" and structured-syntax "+json" subtypes such as
# "application/<something>+json" at the start of a Content-Type value.
json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json")
  89. def _gen_default_accept_encoding() -> str:
  90. return "gzip, deflate, br" if HAS_BROTLI else "gzip, deflate"
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ContentDisposition:
    """Immutable, parsed form of a ``Content-Disposition`` header.

    Built by ``ClientResponse.content_disposition`` from the raw header value.
    """

    # Disposition type as returned by multipart.parse_content_disposition.
    type: Optional[str]
    # Header parameters, wrapped in a read-only MappingProxyType.
    parameters: "MappingProxyType[str, str]"
    # File name derived from ``parameters`` (None when none could be derived).
    filename: Optional[str]
class _RequestInfo(NamedTuple):
    """Field layout of :class:`RequestInfo`.

    The field order is significant: callers build instances positionally via
    ``tuple.__new__`` (see ``RequestInfo.__new__`` and
    ``ClientRequest.request_info``).
    """

    url: URL  # request URL; ClientRequest passes its fragment-free URL
    method: str  # HTTP method
    headers: "CIMultiDictProxy[str]"  # read-only view of the request headers
    real_url: URL  # URL as originally given; ClientRequest passes original_url
  101. class RequestInfo(_RequestInfo):
  102. def __new__(
  103. cls,
  104. url: URL,
  105. method: str,
  106. headers: "CIMultiDictProxy[str]",
  107. real_url: URL = _SENTINEL, # type: ignore[assignment]
  108. ) -> "RequestInfo":
  109. """Create a new RequestInfo instance.
  110. For backwards compatibility, the real_url parameter is optional.
  111. """
  112. return tuple.__new__(
  113. cls, (url, method, headers, url if real_url is _SENTINEL else real_url)
  114. )
  115. class Fingerprint:
  116. HASHFUNC_BY_DIGESTLEN = {
  117. 16: md5,
  118. 20: sha1,
  119. 32: sha256,
  120. }
  121. def __init__(self, fingerprint: bytes) -> None:
  122. digestlen = len(fingerprint)
  123. hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(digestlen)
  124. if not hashfunc:
  125. raise ValueError("fingerprint has invalid length")
  126. elif hashfunc is md5 or hashfunc is sha1:
  127. raise ValueError("md5 and sha1 are insecure and not supported. Use sha256.")
  128. self._hashfunc = hashfunc
  129. self._fingerprint = fingerprint
  130. @property
  131. def fingerprint(self) -> bytes:
  132. return self._fingerprint
  133. def check(self, transport: asyncio.Transport) -> None:
  134. if not transport.get_extra_info("sslcontext"):
  135. return
  136. sslobj = transport.get_extra_info("ssl_object")
  137. cert = sslobj.getpeercert(binary_form=True)
  138. got = self._hashfunc(cert).digest()
  139. if got != self._fingerprint:
  140. host, port, *_ = transport.get_extra_info("peername")
  141. raise ServerFingerprintMismatch(self._fingerprint, got, host, port)
  142. if ssl is not None:
  143. SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
  144. else: # pragma: no cover
  145. SSL_ALLOWED_TYPES = (bool, type(None))
  146. def _merge_ssl_params(
  147. ssl: Union["SSLContext", bool, Fingerprint],
  148. verify_ssl: Optional[bool],
  149. ssl_context: Optional["SSLContext"],
  150. fingerprint: Optional[bytes],
  151. ) -> Union["SSLContext", bool, Fingerprint]:
  152. if ssl is None:
  153. ssl = True # Double check for backwards compatibility
  154. if verify_ssl is not None and not verify_ssl:
  155. warnings.warn(
  156. "verify_ssl is deprecated, use ssl=False instead",
  157. DeprecationWarning,
  158. stacklevel=3,
  159. )
  160. if ssl is not True:
  161. raise ValueError(
  162. "verify_ssl, ssl_context, fingerprint and ssl "
  163. "parameters are mutually exclusive"
  164. )
  165. else:
  166. ssl = False
  167. if ssl_context is not None:
  168. warnings.warn(
  169. "ssl_context is deprecated, use ssl=context instead",
  170. DeprecationWarning,
  171. stacklevel=3,
  172. )
  173. if ssl is not True:
  174. raise ValueError(
  175. "verify_ssl, ssl_context, fingerprint and ssl "
  176. "parameters are mutually exclusive"
  177. )
  178. else:
  179. ssl = ssl_context
  180. if fingerprint is not None:
  181. warnings.warn(
  182. "fingerprint is deprecated, use ssl=Fingerprint(fingerprint) instead",
  183. DeprecationWarning,
  184. stacklevel=3,
  185. )
  186. if ssl is not True:
  187. raise ValueError(
  188. "verify_ssl, ssl_context, fingerprint and ssl "
  189. "parameters are mutually exclusive"
  190. )
  191. else:
  192. ssl = Fingerprint(fingerprint)
  193. if not isinstance(ssl, SSL_ALLOWED_TYPES):
  194. raise TypeError(
  195. "ssl should be SSLContext, bool, Fingerprint or None, "
  196. "got {!r} instead.".format(ssl)
  197. )
  198. return ssl
  199. _SSL_SCHEMES = frozenset(("https", "wss"))
# ConnectionKey is a NamedTuple because it is used as a key in a dict
# and a set in the connector. Since a NamedTuple is a tuple it uses
# the fast native tuple __hash__ and __eq__ implementation in CPython.
class ConnectionKey(NamedTuple):
    # the key should contain an information about used proxy / TLS
    # to prevent reusing wrong connections from a pool
    #
    # Field order matters: ClientRequest.connection_key builds instances
    # positionally with tuple.__new__.
    host: str  # raw host of the request URL ("" if missing)
    port: Optional[int]  # URL port
    is_ssl: bool  # True when the URL scheme is in _SSL_SCHEMES
    ssl: Union[SSLContext, bool, Fingerprint]  # the request's ssl setting
    proxy: Optional[URL]  # proxy URL, if any
    proxy_auth: Optional[BasicAuth]  # proxy credentials, if any
    proxy_headers_hash: Optional[int]  # hash(CIMultiDict)
  213. def _is_expected_content_type(
  214. response_content_type: str, expected_content_type: str
  215. ) -> bool:
  216. if expected_content_type == "application/json":
  217. return json_re.match(response_content_type) is not None
  218. return expected_content_type in response_content_type
  219. class ClientRequest:
  220. GET_METHODS = {
  221. hdrs.METH_GET,
  222. hdrs.METH_HEAD,
  223. hdrs.METH_OPTIONS,
  224. hdrs.METH_TRACE,
  225. }
  226. POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
  227. ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})
  228. DEFAULT_HEADERS = {
  229. hdrs.ACCEPT: "*/*",
  230. hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(),
  231. }
  232. # Type of body depends on PAYLOAD_REGISTRY, which is dynamic.
  233. body: Any = b""
  234. auth = None
  235. response = None
  236. __writer = None # async task for streaming data
  237. _continue = None # waiter future for '100 Continue' response
  238. _skip_auto_headers: Optional["CIMultiDict[None]"] = None
  239. # N.B.
  240. # Adding __del__ method with self._writer closing doesn't make sense
  241. # because _writer is instance method, thus it keeps a reference to self.
  242. # Until writer has finished finalizer will not be called.
  243. def __init__(
  244. self,
  245. method: str,
  246. url: URL,
  247. *,
  248. params: Query = None,
  249. headers: Optional[LooseHeaders] = None,
  250. skip_auto_headers: Optional[Iterable[str]] = None,
  251. data: Any = None,
  252. cookies: Optional[LooseCookies] = None,
  253. auth: Optional[BasicAuth] = None,
  254. version: http.HttpVersion = http.HttpVersion11,
  255. compress: Union[str, bool, None] = None,
  256. chunked: Optional[bool] = None,
  257. expect100: bool = False,
  258. loop: Optional[asyncio.AbstractEventLoop] = None,
  259. response_class: Optional[Type["ClientResponse"]] = None,
  260. proxy: Optional[URL] = None,
  261. proxy_auth: Optional[BasicAuth] = None,
  262. timer: Optional[BaseTimerContext] = None,
  263. session: Optional["ClientSession"] = None,
  264. ssl: Union[SSLContext, bool, Fingerprint] = True,
  265. proxy_headers: Optional[LooseHeaders] = None,
  266. traces: Optional[List["Trace"]] = None,
  267. trust_env: bool = False,
  268. server_hostname: Optional[str] = None,
  269. ):
  270. if loop is None:
  271. loop = asyncio.get_event_loop()
  272. if match := _CONTAINS_CONTROL_CHAR_RE.search(method):
  273. raise ValueError(
  274. f"Method cannot contain non-token characters {method!r} "
  275. f"(found at least {match.group()!r})"
  276. )
  277. # URL forbids subclasses, so a simple type check is enough.
  278. assert type(url) is URL, url
  279. if proxy is not None:
  280. assert type(proxy) is URL, proxy
  281. # FIXME: session is None in tests only, need to fix tests
  282. # assert session is not None
  283. if TYPE_CHECKING:
  284. assert session is not None
  285. self._session = session
  286. if params:
  287. url = url.extend_query(params)
  288. self.original_url = url
  289. self.url = url.with_fragment(None) if url.raw_fragment else url
  290. self.method = method.upper()
  291. self.chunked = chunked
  292. self.compress = compress
  293. self.loop = loop
  294. self.length = None
  295. if response_class is None:
  296. real_response_class = ClientResponse
  297. else:
  298. real_response_class = response_class
  299. self.response_class: Type[ClientResponse] = real_response_class
  300. self._timer = timer if timer is not None else TimerNoop()
  301. self._ssl = ssl if ssl is not None else True
  302. self.server_hostname = server_hostname
  303. if loop.get_debug():
  304. self._source_traceback = traceback.extract_stack(sys._getframe(1))
  305. self.update_version(version)
  306. self.update_host(url)
  307. self.update_headers(headers)
  308. self.update_auto_headers(skip_auto_headers)
  309. self.update_cookies(cookies)
  310. self.update_content_encoding(data)
  311. self.update_auth(auth, trust_env)
  312. self.update_proxy(proxy, proxy_auth, proxy_headers)
  313. self.update_body_from_data(data)
  314. if data is not None or self.method not in self.GET_METHODS:
  315. self.update_transfer_encoding()
  316. self.update_expect_continue(expect100)
  317. self._traces = [] if traces is None else traces
  318. def __reset_writer(self, _: object = None) -> None:
  319. self.__writer = None
  320. @property
  321. def skip_auto_headers(self) -> CIMultiDict[None]:
  322. return self._skip_auto_headers or CIMultiDict()
  323. @property
  324. def _writer(self) -> Optional["asyncio.Task[None]"]:
  325. return self.__writer
  326. @_writer.setter
  327. def _writer(self, writer: "asyncio.Task[None]") -> None:
  328. if self.__writer is not None:
  329. self.__writer.remove_done_callback(self.__reset_writer)
  330. self.__writer = writer
  331. writer.add_done_callback(self.__reset_writer)
  332. def is_ssl(self) -> bool:
  333. return self.url.scheme in _SSL_SCHEMES
  334. @property
  335. def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
  336. return self._ssl
  337. @property
  338. def connection_key(self) -> ConnectionKey:
  339. if proxy_headers := self.proxy_headers:
  340. h: Optional[int] = hash(tuple(proxy_headers.items()))
  341. else:
  342. h = None
  343. url = self.url
  344. return tuple.__new__(
  345. ConnectionKey,
  346. (
  347. url.raw_host or "",
  348. url.port,
  349. url.scheme in _SSL_SCHEMES,
  350. self._ssl,
  351. self.proxy,
  352. self.proxy_auth,
  353. h,
  354. ),
  355. )
  356. @property
  357. def host(self) -> str:
  358. ret = self.url.raw_host
  359. assert ret is not None
  360. return ret
  361. @property
  362. def port(self) -> Optional[int]:
  363. return self.url.port
  364. @property
  365. def request_info(self) -> RequestInfo:
  366. headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers)
  367. # These are created on every request, so we use a NamedTuple
  368. # for performance reasons. We don't use the RequestInfo.__new__
  369. # method because it has a different signature which is provided
  370. # for backwards compatibility only.
  371. return tuple.__new__(
  372. RequestInfo, (self.url, self.method, headers, self.original_url)
  373. )
  374. def update_host(self, url: URL) -> None:
  375. """Update destination host, port and connection type (ssl)."""
  376. # get host/port
  377. if not url.raw_host:
  378. raise InvalidURL(url)
  379. # basic auth info
  380. if url.raw_user or url.raw_password:
  381. self.auth = helpers.BasicAuth(url.user or "", url.password or "")
  382. def update_version(self, version: Union[http.HttpVersion, str]) -> None:
  383. """Convert request version to two elements tuple.
  384. parser HTTP version '1.1' => (1, 1)
  385. """
  386. if isinstance(version, str):
  387. v = [part.strip() for part in version.split(".", 1)]
  388. try:
  389. version = http.HttpVersion(int(v[0]), int(v[1]))
  390. except ValueError:
  391. raise ValueError(
  392. f"Can not parse http version number: {version}"
  393. ) from None
  394. self.version = version
  395. def update_headers(self, headers: Optional[LooseHeaders]) -> None:
  396. """Update request headers."""
  397. self.headers: CIMultiDict[str] = CIMultiDict()
  398. # Build the host header
  399. host = self.url.host_port_subcomponent
  400. # host_port_subcomponent is None when the URL is a relative URL.
  401. # but we know we do not have a relative URL here.
  402. assert host is not None
  403. self.headers[hdrs.HOST] = host
  404. if not headers:
  405. return
  406. if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
  407. headers = headers.items()
  408. for key, value in headers: # type: ignore[misc]
  409. # A special case for Host header
  410. if key in hdrs.HOST_ALL:
  411. self.headers[key] = value
  412. else:
  413. self.headers.add(key, value)
  414. def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None:
  415. if skip_auto_headers is not None:
  416. self._skip_auto_headers = CIMultiDict(
  417. (hdr, None) for hdr in sorted(skip_auto_headers)
  418. )
  419. used_headers = self.headers.copy()
  420. used_headers.extend(self._skip_auto_headers) # type: ignore[arg-type]
  421. else:
  422. # Fast path when there are no headers to skip
  423. # which is the most common case.
  424. used_headers = self.headers
  425. for hdr, val in self.DEFAULT_HEADERS.items():
  426. if hdr not in used_headers:
  427. self.headers[hdr] = val
  428. if hdrs.USER_AGENT not in used_headers:
  429. self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE
  430. def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
  431. """Update request cookies header."""
  432. if not cookies:
  433. return
  434. c = SimpleCookie()
  435. if hdrs.COOKIE in self.headers:
  436. c.load(self.headers.get(hdrs.COOKIE, ""))
  437. del self.headers[hdrs.COOKIE]
  438. if isinstance(cookies, Mapping):
  439. iter_cookies = cookies.items()
  440. else:
  441. iter_cookies = cookies # type: ignore[assignment]
  442. for name, value in iter_cookies:
  443. if isinstance(value, Morsel):
  444. # Preserve coded_value
  445. mrsl_val = value.get(value.key, Morsel())
  446. mrsl_val.set(value.key, value.value, value.coded_value)
  447. c[name] = mrsl_val
  448. else:
  449. c[name] = value # type: ignore[assignment]
  450. self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip()
  451. def update_content_encoding(self, data: Any) -> None:
  452. """Set request content encoding."""
  453. if not data:
  454. # Don't compress an empty body.
  455. self.compress = None
  456. return
  457. if self.headers.get(hdrs.CONTENT_ENCODING):
  458. if self.compress:
  459. raise ValueError(
  460. "compress can not be set if Content-Encoding header is set"
  461. )
  462. elif self.compress:
  463. if not isinstance(self.compress, str):
  464. self.compress = "deflate"
  465. self.headers[hdrs.CONTENT_ENCODING] = self.compress
  466. self.chunked = True # enable chunked, no need to deal with length
  467. def update_transfer_encoding(self) -> None:
  468. """Analyze transfer-encoding header."""
  469. te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower()
  470. if "chunked" in te:
  471. if self.chunked:
  472. raise ValueError(
  473. "chunked can not be set "
  474. 'if "Transfer-Encoding: chunked" header is set'
  475. )
  476. elif self.chunked:
  477. if hdrs.CONTENT_LENGTH in self.headers:
  478. raise ValueError(
  479. "chunked can not be set if Content-Length header is set"
  480. )
  481. self.headers[hdrs.TRANSFER_ENCODING] = "chunked"
  482. else:
  483. if hdrs.CONTENT_LENGTH not in self.headers:
  484. self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
  485. def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None:
  486. """Set basic auth."""
  487. if auth is None:
  488. auth = self.auth
  489. if auth is None and trust_env and self.url.host is not None:
  490. netrc_obj = netrc_from_env()
  491. with contextlib.suppress(LookupError):
  492. auth = basicauth_from_netrc(netrc_obj, self.url.host)
  493. if auth is None:
  494. return
  495. if not isinstance(auth, helpers.BasicAuth):
  496. raise TypeError("BasicAuth() tuple is required instead")
  497. self.headers[hdrs.AUTHORIZATION] = auth.encode()
  498. def update_body_from_data(self, body: Any) -> None:
  499. if body is None:
  500. return
  501. # FormData
  502. if isinstance(body, FormData):
  503. body = body()
  504. try:
  505. body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
  506. except payload.LookupError:
  507. body = FormData(body)()
  508. self.body = body
  509. # enable chunked encoding if needed
  510. if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
  511. if (size := body.size) is not None:
  512. self.headers[hdrs.CONTENT_LENGTH] = str(size)
  513. else:
  514. self.chunked = True
  515. # copy payload headers
  516. assert body.headers
  517. headers = self.headers
  518. skip_headers = self._skip_auto_headers
  519. for key, value in body.headers.items():
  520. if key in headers or (skip_headers is not None and key in skip_headers):
  521. continue
  522. headers[key] = value
  523. def update_expect_continue(self, expect: bool = False) -> None:
  524. if expect:
  525. self.headers[hdrs.EXPECT] = "100-continue"
  526. elif (
  527. hdrs.EXPECT in self.headers
  528. and self.headers[hdrs.EXPECT].lower() == "100-continue"
  529. ):
  530. expect = True
  531. if expect:
  532. self._continue = self.loop.create_future()
  533. def update_proxy(
  534. self,
  535. proxy: Optional[URL],
  536. proxy_auth: Optional[BasicAuth],
  537. proxy_headers: Optional[LooseHeaders],
  538. ) -> None:
  539. self.proxy = proxy
  540. if proxy is None:
  541. self.proxy_auth = None
  542. self.proxy_headers = None
  543. return
  544. if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
  545. raise ValueError("proxy_auth must be None or BasicAuth() tuple")
  546. self.proxy_auth = proxy_auth
  547. if proxy_headers is not None and not isinstance(
  548. proxy_headers, (MultiDict, MultiDictProxy)
  549. ):
  550. proxy_headers = CIMultiDict(proxy_headers)
  551. self.proxy_headers = proxy_headers
  552. async def write_bytes(
  553. self, writer: AbstractStreamWriter, conn: "Connection"
  554. ) -> None:
  555. """Support coroutines that yields bytes objects."""
  556. # 100 response
  557. if self._continue is not None:
  558. await writer.drain()
  559. await self._continue
  560. protocol = conn.protocol
  561. assert protocol is not None
  562. try:
  563. if isinstance(self.body, payload.Payload):
  564. await self.body.write(writer)
  565. else:
  566. if isinstance(self.body, (bytes, bytearray)):
  567. self.body = (self.body,)
  568. for chunk in self.body:
  569. await writer.write(chunk)
  570. except OSError as underlying_exc:
  571. reraised_exc = underlying_exc
  572. exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
  573. underlying_exc, asyncio.TimeoutError
  574. )
  575. if exc_is_not_timeout:
  576. reraised_exc = ClientOSError(
  577. underlying_exc.errno,
  578. f"Can not write request body for {self.url !s}",
  579. )
  580. set_exception(protocol, reraised_exc, underlying_exc)
  581. except asyncio.CancelledError:
  582. # Body hasn't been fully sent, so connection can't be reused.
  583. conn.close()
  584. raise
  585. except Exception as underlying_exc:
  586. set_exception(
  587. protocol,
  588. ClientConnectionError(
  589. f"Failed to send bytes into the underlying connection {conn !s}",
  590. ),
  591. underlying_exc,
  592. )
  593. else:
  594. await writer.write_eof()
  595. protocol.start_timeout()
  596. async def send(self, conn: "Connection") -> "ClientResponse":
  597. # Specify request target:
  598. # - CONNECT request must send authority form URI
  599. # - not CONNECT proxy must send absolute form URI
  600. # - most common is origin form URI
  601. if self.method == hdrs.METH_CONNECT:
  602. connect_host = self.url.host_subcomponent
  603. assert connect_host is not None
  604. path = f"{connect_host}:{self.url.port}"
  605. elif self.proxy and not self.is_ssl():
  606. path = str(self.url)
  607. else:
  608. path = self.url.raw_path_qs
  609. protocol = conn.protocol
  610. assert protocol is not None
  611. writer = StreamWriter(
  612. protocol,
  613. self.loop,
  614. on_chunk_sent=(
  615. functools.partial(self._on_chunk_request_sent, self.method, self.url)
  616. if self._traces
  617. else None
  618. ),
  619. on_headers_sent=(
  620. functools.partial(self._on_headers_request_sent, self.method, self.url)
  621. if self._traces
  622. else None
  623. ),
  624. )
  625. if self.compress:
  626. writer.enable_compression(self.compress) # type: ignore[arg-type]
  627. if self.chunked is not None:
  628. writer.enable_chunking()
  629. # set default content-type
  630. if (
  631. self.method in self.POST_METHODS
  632. and (
  633. self._skip_auto_headers is None
  634. or hdrs.CONTENT_TYPE not in self._skip_auto_headers
  635. )
  636. and hdrs.CONTENT_TYPE not in self.headers
  637. ):
  638. self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"
  639. v = self.version
  640. if hdrs.CONNECTION not in self.headers:
  641. if conn._connector.force_close:
  642. if v == HttpVersion11:
  643. self.headers[hdrs.CONNECTION] = "close"
  644. elif v == HttpVersion10:
  645. self.headers[hdrs.CONNECTION] = "keep-alive"
  646. # status + headers
  647. status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
  648. await writer.write_headers(status_line, self.headers)
  649. task: Optional["asyncio.Task[None]"]
  650. if self.body or self._continue is not None or protocol.writing_paused:
  651. coro = self.write_bytes(writer, conn)
  652. if sys.version_info >= (3, 12):
  653. # Optimization for Python 3.12, try to write
  654. # bytes immediately to avoid having to schedule
  655. # the task on the event loop.
  656. task = asyncio.Task(coro, loop=self.loop, eager_start=True)
  657. else:
  658. task = self.loop.create_task(coro)
  659. if task.done():
  660. task = None
  661. else:
  662. self._writer = task
  663. else:
  664. # We have nothing to write because
  665. # - there is no body
  666. # - the protocol does not have writing paused
  667. # - we are not waiting for a 100-continue response
  668. protocol.start_timeout()
  669. writer.set_eof()
  670. task = None
  671. response_class = self.response_class
  672. assert response_class is not None
  673. self.response = response_class(
  674. self.method,
  675. self.original_url,
  676. writer=task,
  677. continue100=self._continue,
  678. timer=self._timer,
  679. request_info=self.request_info,
  680. traces=self._traces,
  681. loop=self.loop,
  682. session=self._session,
  683. )
  684. return self.response
  685. async def close(self) -> None:
  686. if self.__writer is not None:
  687. try:
  688. await self.__writer
  689. except asyncio.CancelledError:
  690. if (
  691. sys.version_info >= (3, 11)
  692. and (task := asyncio.current_task())
  693. and task.cancelling()
  694. ):
  695. raise
  696. def terminate(self) -> None:
  697. if self.__writer is not None:
  698. if not self.loop.is_closed():
  699. self.__writer.cancel()
  700. self.__writer.remove_done_callback(self.__reset_writer)
  701. self.__writer = None
  702. async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
  703. for trace in self._traces:
  704. await trace.send_request_chunk_sent(method, url, chunk)
  705. async def _on_headers_request_sent(
  706. self, method: str, url: URL, headers: "CIMultiDict[str]"
  707. ) -> None:
  708. for trace in self._traces:
  709. await trace.send_request_headers(method, url, headers)
# Single ClientConnectionError instance created once at import time.
# NOTE(review): its users are outside this view; presumably it is raised or set
# when a response's connection is closed early — confirm before relying on it.
_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed")
  711. class ClientResponse(HeadersMixin):
  712. # Some of these attributes are None when created,
  713. # but will be set by the start() method.
  714. # As the end user will likely never see the None values, we cheat the types below.
  715. # from the Status-Line of the response
  716. version: Optional[HttpVersion] = None # HTTP-Version
  717. status: int = None # type: ignore[assignment] # Status-Code
  718. reason: Optional[str] = None # Reason-Phrase
  719. content: StreamReader = None # type: ignore[assignment] # Payload stream
  720. _body: Optional[bytes] = None
  721. _headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
  722. _history: Tuple["ClientResponse", ...] = ()
  723. _raw_headers: RawHeaders = None # type: ignore[assignment]
  724. _connection: Optional["Connection"] = None # current connection
  725. _cookies: Optional[SimpleCookie] = None
  726. _continue: Optional["asyncio.Future[bool]"] = None
  727. _source_traceback: Optional[traceback.StackSummary] = None
  728. _session: Optional["ClientSession"] = None
  729. # set up by ClientRequest after ClientResponse object creation
  730. # post-init stage allows to not change ctor signature
  731. _closed = True # to allow __del__ for non-initialized properly response
  732. _released = False
  733. _in_context = False
  734. _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8"
  735. __writer: Optional["asyncio.Task[None]"] = None
  736. def __init__(
  737. self,
  738. method: str,
  739. url: URL,
  740. *,
  741. writer: "Optional[asyncio.Task[None]]",
  742. continue100: Optional["asyncio.Future[bool]"],
  743. timer: BaseTimerContext,
  744. request_info: RequestInfo,
  745. traces: List["Trace"],
  746. loop: asyncio.AbstractEventLoop,
  747. session: "ClientSession",
  748. ) -> None:
  749. # URL forbids subclasses, so a simple type check is enough.
  750. assert type(url) is URL
  751. self.method = method
  752. self._real_url = url
  753. self._url = url.with_fragment(None) if url.raw_fragment else url
  754. if writer is not None:
  755. self._writer = writer
  756. if continue100 is not None:
  757. self._continue = continue100
  758. self._request_info = request_info
  759. self._timer = timer if timer is not None else TimerNoop()
  760. self._cache: Dict[str, Any] = {}
  761. self._traces = traces
  762. self._loop = loop
  763. # Save reference to _resolve_charset, so that get_encoding() will still
  764. # work after the response has finished reading the body.
  765. # TODO: Fix session=None in tests (see ClientRequest.__init__).
  766. if session is not None:
  767. # store a reference to session #1985
  768. self._session = session
  769. self._resolve_charset = session._resolve_charset
  770. if loop.get_debug():
  771. self._source_traceback = traceback.extract_stack(sys._getframe(1))
    def __reset_writer(self, _: object = None) -> None:
        # Done-callback registered by the _writer setter: forget the finished
        # writer task. The positional argument (the task) is ignored.
        self.__writer = None
  774. @property
  775. def _writer(self) -> Optional["asyncio.Task[None]"]:
  776. """The writer task for streaming data.
  777. _writer is only provided for backwards compatibility
  778. for subclasses that may need to access it.
  779. """
  780. return self.__writer
  781. @_writer.setter
  782. def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
  783. """Set the writer task for streaming data."""
  784. if self.__writer is not None:
  785. self.__writer.remove_done_callback(self.__reset_writer)
  786. self.__writer = writer
  787. if writer is None:
  788. return
  789. if writer.done():
  790. # The writer is already done, so we can clear it immediately.
  791. self.__writer = None
  792. else:
  793. writer.add_done_callback(self.__reset_writer)
  794. @property
  795. def cookies(self) -> SimpleCookie:
  796. if self._cookies is None:
  797. self._cookies = SimpleCookie()
  798. return self._cookies
  799. @cookies.setter
  800. def cookies(self, cookies: SimpleCookie) -> None:
  801. self._cookies = cookies
    @reify
    def url(self) -> URL:
        """Response URL without its fragment (see ``__init__``)."""
        return self._url

    @reify
    def url_obj(self) -> URL:
        """Deprecated alias of :attr:`url`; emits a ``DeprecationWarning``."""
        warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2)
        return self._url

    @reify
    def real_url(self) -> URL:
        """Response URL exactly as passed to ``__init__`` (fragment kept)."""
        return self._real_url

    @reify
    def host(self) -> str:
        """Host of :attr:`url`; asserted to be present."""
        assert self._url.host is not None
        return self._url.host

    @reify
    def headers(self) -> "CIMultiDictProxy[str]":
        """Response headers.

        ``None`` until ``start()`` sets them, per the class-level comment.
        """
        return self._headers

    @reify
    def raw_headers(self) -> RawHeaders:
        """Unparsed response headers (likewise set by ``start()``)."""
        return self._raw_headers

    @reify
    def request_info(self) -> RequestInfo:
        """Information about the request that produced this response."""
        return self._request_info
  825. @reify
  826. def content_disposition(self) -> Optional[ContentDisposition]:
  827. raw = self._headers.get(hdrs.CONTENT_DISPOSITION)
  828. if raw is None:
  829. return None
  830. disposition_type, params_dct = multipart.parse_content_disposition(raw)
  831. params = MappingProxyType(params_dct)
  832. filename = multipart.content_disposition_filename(params)
  833. return ContentDisposition(disposition_type, params, filename)
  834. def __del__(self, _warnings: Any = warnings) -> None:
  835. if self._closed:
  836. return
  837. if self._connection is not None:
  838. self._connection.release()
  839. self._cleanup_writer()
  840. if self._loop.get_debug():
  841. kwargs = {"source": self}
  842. _warnings.warn(f"Unclosed response {self!r}", ResourceWarning, **kwargs)
  843. context = {"client_response": self, "message": "Unclosed response"}
  844. if self._source_traceback:
  845. context["source_traceback"] = self._source_traceback
  846. self._loop.call_exception_handler(context)
  847. def __repr__(self) -> str:
  848. out = io.StringIO()
  849. ascii_encodable_url = str(self.url)
  850. if self.reason:
  851. ascii_encodable_reason = self.reason.encode(
  852. "ascii", "backslashreplace"
  853. ).decode("ascii")
  854. else:
  855. ascii_encodable_reason = "None"
  856. print(
  857. "<ClientResponse({}) [{} {}]>".format(
  858. ascii_encodable_url, self.status, ascii_encodable_reason
  859. ),
  860. file=out,
  861. )
  862. print(self.headers, file=out)
  863. return out.getvalue()
  864. @property
  865. def connection(self) -> Optional["Connection"]:
  866. return self._connection
  867. @reify
  868. def history(self) -> Tuple["ClientResponse", ...]:
  869. """A sequence of of responses, if redirects occurred."""
  870. return self._history
  871. @reify
  872. def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]":
  873. links_str = ", ".join(self.headers.getall("link", []))
  874. if not links_str:
  875. return MultiDictProxy(MultiDict())
  876. links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict()
  877. for val in re.split(r",(?=\s*<)", links_str):
  878. match = re.match(r"\s*<(.*)>(.*)", val)
  879. if match is None: # pragma: no cover
  880. # the check exists to suppress mypy error
  881. continue
  882. url, params_str = match.groups()
  883. params = params_str.split(";")[1:]
  884. link: MultiDict[Union[str, URL]] = MultiDict()
  885. for param in params:
  886. match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M)
  887. if match is None: # pragma: no cover
  888. # the check exists to suppress mypy error
  889. continue
  890. key, _, value, _ = match.groups()
  891. link.add(key, value)
  892. key = link.get("rel", url)
  893. link.add("url", self.url.join(URL(url)))
  894. links.add(str(key), MultiDictProxy(link))
  895. return MultiDictProxy(links)
    async def start(self, connection: "Connection") -> "ClientResponse":
        """Start response processing.

        Reads messages from ``connection.protocol`` until a final response
        arrives. Interim 1xx messages are skipped; 101 counts as final.
        The response's status, headers, payload stream and cookies are then
        filled in from that final message.

        :param connection: the connection whose protocol yields the response.
        :returns: ``self``.
        :raises ClientResponseError: if the protocol raises
            ``http.HttpProcessingError`` while reading.
        """
        self._closed = False
        self._protocol = connection.protocol
        self._connection = connection
        # NOTE(review): ``self._timer`` presumably enforces a read timeout
        # around header reading. Confirm against BaseTimerContext.
        with self._timer:
            while True:
                # read response
                try:
                    protocol = self._protocol
                    # NOTE: ``payload`` shadows the module-level ``payload``
                    # import inside this method.
                    message, payload = await protocol.read()  # type: ignore[union-attr]
                except http.HttpProcessingError as exc:
                    raise ClientResponseError(
                        self.request_info,
                        self.history,
                        status=exc.code,
                        message=exc.message,
                        headers=exc.headers,
                    ) from exc
                # Codes outside 100-199, and 101 Switching Protocols, end the
                # loop. Any other 1xx is interim, and the next message is read.
                if message.code < 100 or message.code > 199 or message.code == 101:
                    break
                # An interim 1xx resolves the pending "continue" future (once),
                # presumably unblocking an ``Expect: 100-continue`` body send.
                if self._continue is not None:
                    set_result(self._continue, True)
                    self._continue = None
        # payload eof handler
        payload.on_eof(self._response_eof)
        # response status
        self.version = message.version
        self.status = message.code
        self.reason = message.reason
        # headers
        self._headers = message.headers  # type is CIMultiDictProxy
        self._raw_headers = message.raw_headers  # type is Tuple[bytes, bytes]
        # payload
        self.content = payload
        # cookies
        # A Set-Cookie header that raises CookieError is logged and skipped.
        # ``self._cookies`` is only replaced when at least one Set-Cookie
        # header exists.
        if cookie_hdrs := self.headers.getall(hdrs.SET_COOKIE, ()):
            cookies = SimpleCookie()
            for hdr in cookie_hdrs:
                try:
                    cookies.load(hdr)
                except CookieError as exc:
                    client_logger.warning("Can not load response cookies: %s", exc)
            self._cookies = cookies
        return self
  941. def _response_eof(self) -> None:
  942. if self._closed:
  943. return
  944. # protocol could be None because connection could be detached
  945. protocol = self._connection and self._connection.protocol
  946. if protocol is not None and protocol.upgraded:
  947. return
  948. self._closed = True
  949. self._cleanup_writer()
  950. self._release_connection()
  951. @property
  952. def closed(self) -> bool:
  953. return self._closed
  954. def close(self) -> None:
  955. if not self._released:
  956. self._notify_content()
  957. self._closed = True
  958. if self._loop is None or self._loop.is_closed():
  959. return
  960. self._cleanup_writer()
  961. if self._connection is not None:
  962. self._connection.close()
  963. self._connection = None
  964. def release(self) -> Any:
  965. if not self._released:
  966. self._notify_content()
  967. self._closed = True
  968. self._cleanup_writer()
  969. self._release_connection()
  970. return noop()
  971. @property
  972. def ok(self) -> bool:
  973. """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not.
  974. This is **not** a check for ``200 OK`` but a check that the response
  975. status is under 400.
  976. """
  977. return 400 > self.status
  978. def raise_for_status(self) -> None:
  979. if not self.ok:
  980. # reason should always be not None for a started response
  981. assert self.reason is not None
  982. # If we're in a context we can rely on __aexit__() to release as the
  983. # exception propagates.
  984. if not self._in_context:
  985. self.release()
  986. raise ClientResponseError(
  987. self.request_info,
  988. self.history,
  989. status=self.status,
  990. message=self.reason,
  991. headers=self.headers,
  992. )
  993. def _release_connection(self) -> None:
  994. if self._connection is not None:
  995. if self.__writer is None:
  996. self._connection.release()
  997. self._connection = None
  998. else:
  999. self.__writer.add_done_callback(lambda f: self._release_connection())
  1000. async def _wait_released(self) -> None:
  1001. if self.__writer is not None:
  1002. try:
  1003. await self.__writer
  1004. except asyncio.CancelledError:
  1005. if (
  1006. sys.version_info >= (3, 11)
  1007. and (task := asyncio.current_task())
  1008. and task.cancelling()
  1009. ):
  1010. raise
  1011. self._release_connection()
  1012. def _cleanup_writer(self) -> None:
  1013. if self.__writer is not None:
  1014. self.__writer.cancel()
  1015. self._session = None
  1016. def _notify_content(self) -> None:
  1017. content = self.content
  1018. if content and content.exception() is None:
  1019. set_exception(content, _CONNECTION_CLOSED_EXCEPTION)
  1020. self._released = True
  1021. async def wait_for_close(self) -> None:
  1022. if self.__writer is not None:
  1023. try:
  1024. await self.__writer
  1025. except asyncio.CancelledError:
  1026. if (
  1027. sys.version_info >= (3, 11)
  1028. and (task := asyncio.current_task())
  1029. and task.cancelling()
  1030. ):
  1031. raise
  1032. self.release()
  1033. async def read(self) -> bytes:
  1034. """Read response payload."""
  1035. if self._body is None:
  1036. try:
  1037. self._body = await self.content.read()
  1038. for trace in self._traces:
  1039. await trace.send_response_chunk_received(
  1040. self.method, self.url, self._body
  1041. )
  1042. except BaseException:
  1043. self.close()
  1044. raise
  1045. elif self._released: # Response explicitly released
  1046. raise ClientConnectionError("Connection closed")
  1047. protocol = self._connection and self._connection.protocol
  1048. if protocol is None or not protocol.upgraded:
  1049. await self._wait_released() # Underlying connection released
  1050. return self._body
  1051. def get_encoding(self) -> str:
  1052. ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
  1053. mimetype = helpers.parse_mimetype(ctype)
  1054. encoding = mimetype.parameters.get("charset")
  1055. if encoding:
  1056. with contextlib.suppress(LookupError, ValueError):
  1057. return codecs.lookup(encoding).name
  1058. if mimetype.type == "application" and (
  1059. mimetype.subtype == "json" or mimetype.subtype == "rdap"
  1060. ):
  1061. # RFC 7159 states that the default encoding is UTF-8.
  1062. # RFC 7483 defines application/rdap+json
  1063. return "utf-8"
  1064. if self._body is None:
  1065. raise RuntimeError(
  1066. "Cannot compute fallback encoding of a not yet read body"
  1067. )
  1068. return self._resolve_charset(self, self._body)
  1069. async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str:
  1070. """Read response payload and decode."""
  1071. if self._body is None:
  1072. await self.read()
  1073. if encoding is None:
  1074. encoding = self.get_encoding()
  1075. return self._body.decode(encoding, errors=errors) # type: ignore[union-attr]
  1076. async def json(
  1077. self,
  1078. *,
  1079. encoding: Optional[str] = None,
  1080. loads: JSONDecoder = DEFAULT_JSON_DECODER,
  1081. content_type: Optional[str] = "application/json",
  1082. ) -> Any:
  1083. """Read and decodes JSON response."""
  1084. if self._body is None:
  1085. await self.read()
  1086. if content_type:
  1087. ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
  1088. if not _is_expected_content_type(ctype, content_type):
  1089. raise ContentTypeError(
  1090. self.request_info,
  1091. self.history,
  1092. status=self.status,
  1093. message=(
  1094. "Attempt to decode JSON with unexpected mimetype: %s" % ctype
  1095. ),
  1096. headers=self.headers,
  1097. )
  1098. stripped = self._body.strip() # type: ignore[union-attr]
  1099. if not stripped:
  1100. return None
  1101. if encoding is None:
  1102. encoding = self.get_encoding()
  1103. return loads(stripped.decode(encoding))
  1104. async def __aenter__(self) -> "ClientResponse":
  1105. self._in_context = True
  1106. return self
  1107. async def __aexit__(
  1108. self,
  1109. exc_type: Optional[Type[BaseException]],
  1110. exc_val: Optional[BaseException],
  1111. exc_tb: Optional[TracebackType],
  1112. ) -> None:
  1113. self._in_context = False
  1114. # similar to _RequestContextManager, we do not need to check
  1115. # for exceptions, response object can close connection
  1116. # if state is broken
  1117. self.release()
  1118. await self.wait_for_close()