12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652 |
- import asyncio
- import functools
- import random
- import socket
- import sys
- import traceback
- import warnings
- from collections import OrderedDict, defaultdict, deque
- from contextlib import suppress
- from http import HTTPStatus
- from itertools import chain, cycle, islice
- from time import monotonic
- from types import TracebackType
- from typing import (
- TYPE_CHECKING,
- Any,
- Awaitable,
- Callable,
- DefaultDict,
- Deque,
- Dict,
- Iterator,
- List,
- Literal,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
- cast,
- )
- import aiohappyeyeballs
- from . import hdrs, helpers
- from .abc import AbstractResolver, ResolveResult
- from .client_exceptions import (
- ClientConnectionError,
- ClientConnectorCertificateError,
- ClientConnectorDNSError,
- ClientConnectorError,
- ClientConnectorSSLError,
- ClientHttpProxyError,
- ClientProxyConnectionError,
- ServerFingerprintMismatch,
- UnixClientConnectorError,
- cert_errors,
- ssl_errors,
- )
- from .client_proto import ResponseHandler
- from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
- from .helpers import (
- ceil_timeout,
- is_ip_address,
- noop,
- sentinel,
- set_exception,
- set_result,
- )
- from .resolver import DefaultResolver
- if TYPE_CHECKING:
- import ssl
- SSLContext = ssl.SSLContext
- else:
- try:
- import ssl
- SSLContext = ssl.SSLContext
- except ImportError: # pragma: no cover
- ssl = None # type: ignore[assignment]
- SSLContext = object # type: ignore[misc,assignment]
# URL scheme groups. The empty string stands for a URL without a scheme.
EMPTY_SCHEMA_SET = frozenset(("",))
HTTP_SCHEMA_SET = frozenset(("http", "https"))
WS_SCHEMA_SET = frozenset(("ws", "wss"))

HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET.union(EMPTY_SCHEMA_SET)
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET.union(WS_SCHEMA_SET)

# Cleanup closed is no longer needed after
# https://github.com/python/cpython/pull/118960 which first appeared in
# Python 3.12.7 and 3.13.1, so the flag is only true for interpreters
# older than 3.12.7 or for exactly the 3.13.0 series.
NEEDS_CLEANUP_CLOSED = (
    sys.version_info < (3, 12, 7)
    or (3, 13, 0) <= sys.version_info < (3, 13, 1)
)
# Public names exported by a star-import of this module.
__all__ = (
    "BaseConnector",
    "TCPConnector",
    "UnixConnector",
    "NamedPipeConnector",
)
- if TYPE_CHECKING:
- from .client import ClientTimeout
- from .client_reqrep import ConnectionKey
- from .tracing import Trace
- class _DeprecationWaiter:
- __slots__ = ("_awaitable", "_awaited")
- def __init__(self, awaitable: Awaitable[Any]) -> None:
- self._awaitable = awaitable
- self._awaited = False
- def __await__(self) -> Any:
- self._awaited = True
- return self._awaitable.__await__()
- def __del__(self) -> None:
- if not self._awaited:
- warnings.warn(
- "Connector.close() is a coroutine, "
- "please use await connector.close()",
- DeprecationWarning,
- )
class Connection:
    """Handle for one protocol that has been acquired from a connector.

    The handle remembers the connector and connection key it came from so
    that ``release()`` / ``close()`` can hand the protocol back through
    ``connector._release()``. Once handed back, the handle drops its
    reference to the protocol and reports itself as closed.
    """

    _source_traceback = None

    def __init__(
        self,
        connector: "BaseConnector",
        key: "ConnectionKey",
        protocol: ResponseHandler,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._connector = connector
        self._key = key
        self._protocol: Optional[ResponseHandler] = protocol
        self._loop = loop
        # Callbacks fired once, the next time the connection is released.
        self._callbacks: List[Callable[[], None]] = []

        if loop.get_debug():
            # Remember where the connection was created (caller's frame).
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __repr__(self) -> str:
        return f"Connection<{self._key}>"

    def __del__(self, _warnings: Any = warnings) -> None:
        protocol = self._protocol
        if protocol is None:
            # Already released or closed: nothing leaked.
            return

        _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, source=self)
        if self._loop.is_closed():
            return

        self._connector._release(self._key, protocol, should_close=True)

        context = {"client_connection": self, "message": "Unclosed connection"}
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __bool__(self) -> Literal[True]:
        """Force subclasses to not be falsy, to make checks simpler."""
        return True

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Event loop the connection was created with (deprecated accessor)."""
        warnings.warn(
            "connector.loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        return self._loop

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """Transport of the underlying protocol, or None after release."""
        protocol = self._protocol
        return None if protocol is None else protocol.transport

    @property
    def protocol(self) -> Optional[ResponseHandler]:
        """Underlying protocol, or None after release."""
        return self._protocol

    def add_callback(self, callback: Callable[[], None]) -> None:
        """Register *callback* to run on the next release; None is ignored."""
        if callback is None:
            return
        self._callbacks.append(callback)

    def _notify_release(self) -> None:
        """Run and forget every registered callback, ignoring their errors."""
        pending = list(self._callbacks)
        self._callbacks = []

        for callback in pending:
            try:
                callback()
            except Exception:
                # A failing callback must not prevent the release itself.
                pass

    def close(self) -> None:
        """Notify callbacks and return the protocol asking for it to be closed."""
        self._notify_release()

        protocol = self._protocol
        if protocol is None:
            return
        self._connector._release(self._key, protocol, should_close=True)
        self._protocol = None

    def release(self) -> None:
        """Notify callbacks and return the protocol to the connector."""
        self._notify_release()

        protocol = self._protocol
        if protocol is None:
            return
        self._connector._release(self._key, protocol)
        self._protocol = None

    @property
    def closed(self) -> bool:
        """True when there is no protocol or it is no longer connected."""
        protocol = self._protocol
        if protocol is None:
            return True
        return not protocol.is_connected()
- class _TransportPlaceholder:
- """placeholder for BaseConnector.connect function"""
- __slots__ = ()
- def close(self) -> None:
- """Close the placeholder transport."""
class BaseConnector:
    """Base connector class.

    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    enable_cleanup_closed - Enables clean-up closed ssl transports.
                            Disabled by default.
    timeout_ceil_threshold - Trigger ceiling of timeout values when
                             it's above timeout_ceil_threshold.
    loop - Optional event loop.
    """

    _closed = True  # prevent AttributeError in __del__ if ctor was failed

    _source_traceback = None

    # abort transport after 2 seconds (cleanup broken connections)
    _cleanup_closed_period = 2.0

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET

    def __init__(
        self,
        *,
        keepalive_timeout: Union[object, None, float] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
    ) -> None:
        """Set up the (initially empty) connection pool and its bookkeeping.

        Raises ValueError when ``force_close`` is combined with an explicit
        ``keepalive_timeout``. When ``loop`` is not given, the running loop
        is used, so the constructor must then be called from within one.
        """
        if force_close:
            if keepalive_timeout is not None and keepalive_timeout is not sentinel:
                raise ValueError(
                    "keepalive_timeout cannot be set if force_close is True"
                )
        else:
            if keepalive_timeout is sentinel:
                keepalive_timeout = 15.0

        loop = loop or asyncio.get_running_loop()
        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._closed = False
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Connection pool of reusable connections.
        # We use a deque to store connections because it has O(1) popleft()
        # and O(1) append() operations to implement a FIFO queue.
        self._conns: DefaultDict[
            ConnectionKey, Deque[Tuple[ResponseHandler, float]]
        ] = defaultdict(deque)
        self._limit = limit
        self._limit_per_host = limit_per_host
        self._acquired: Set[ResponseHandler] = set()
        self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = (
            defaultdict(set)
        )
        self._keepalive_timeout = cast(float, keepalive_timeout)
        self._force_close = force_close

        # {host_key: FIFO list of waiters}
        # The FIFO is implemented with an OrderedDict with None keys because
        # python does not have an ordered set.
        self._waiters: DefaultDict[
            ConnectionKey, OrderedDict[asyncio.Future[None], None]
        ] = defaultdict(OrderedDict)

        self._loop = loop
        self._factory = functools.partial(ResponseHandler, loop=loop)

        # start keep-alive connection cleanup task
        self._cleanup_handle: Optional[asyncio.TimerHandle] = None

        # start cleanup closed transports task
        self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None

        if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED:
            warnings.warn(
                "enable_cleanup_closed ignored because "
                "https://github.com/python/cpython/pull/118960 is fixed "
                f"in Python version {sys.version_info}",
                DeprecationWarning,
                stacklevel=2,
            )
            enable_cleanup_closed = False

        self._cleanup_closed_disabled = not enable_cleanup_closed
        self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
        self._cleanup_closed()

    def __del__(self, _warnings: Any = warnings) -> None:
        """Warn about, and close, a connector dropped with idle connections."""
        if self._closed:
            return
        if not self._conns:
            return

        conns = [repr(c) for c in self._conns.values()]

        self._close()

        kwargs = {"source": self}
        _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs)
        context = {
            "connector": self,
            "connections": conns,
            "message": "Unclosed connector",
        }
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __enter__(self) -> "BaseConnector":
        warnings.warn(
            '"with Connector():" is deprecated, '
            'use "async with Connector():" instead',
            DeprecationWarning,
        )
        return self

    def __exit__(self, *exc: Any) -> None:
        self._close()

    async def __aenter__(self) -> "BaseConnector":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        exc_traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    @property
    def force_close(self) -> bool:
        """Ultimately close connection on releasing if True."""
        return self._force_close

    @property
    def limit(self) -> int:
        """The total number for simultaneous connections.

        If limit is 0 the connector has no limit.
        The default limit size is 100.
        """
        return self._limit

    @property
    def limit_per_host(self) -> int:
        """The limit for simultaneous connections to the same endpoint.

        Endpoints are the same if they have equal
        (host, port, is_ssl) triple.
        """
        return self._limit_per_host

    def _cleanup(self) -> None:
        """Cleanup unused transports.

        Drops every pooled connection that is disconnected or idle for
        longer than the keep-alive timeout, and re-schedules itself while
        the pool is non-empty.
        """
        if self._cleanup_handle:
            self._cleanup_handle.cancel()
            # _cleanup_handle should be unset, otherwise _release() will not
            # recreate it ever!
            self._cleanup_handle = None

        now = monotonic()
        timeout = self._keepalive_timeout

        if self._conns:
            connections = defaultdict(deque)
            deadline = now - timeout
            for key, conns in self._conns.items():
                alive: Deque[Tuple[ResponseHandler, float]] = deque()
                for proto, use_time in conns:
                    if proto.is_connected() and use_time - deadline >= 0:
                        alive.append((proto, use_time))
                        continue
                    # Grab the transport before close() so it can still be
                    # aborted later by _cleanup_closed().
                    transport = proto.transport
                    proto.close()
                    if not self._cleanup_closed_disabled and key.is_ssl:
                        self._cleanup_closed_transports.append(transport)

                if alive:
                    connections[key] = alive

            self._conns = connections

        if self._conns:
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def _cleanup_closed(self) -> None:
        """Double confirmation for transport close.

        Some broken ssl servers may leave socket open without proper close.
        """
        if self._cleanup_closed_handle:
            self._cleanup_closed_handle.cancel()

        for transport in self._cleanup_closed_transports:
            if transport is not None:
                transport.abort()

        self._cleanup_closed_transports = []

        # Only keep re-scheduling when the feature is enabled.
        if not self._cleanup_closed_disabled:
            self._cleanup_closed_handle = helpers.weakref_handle(
                self,
                "_cleanup_closed",
                self._cleanup_closed_period,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def close(self) -> Awaitable[None]:
        """Close all opened transports."""
        self._close()
        # The work is already done synchronously; the returned awaitable only
        # exists so that callers can (and are reminded to) await close().
        return _DeprecationWaiter(noop())

    def _close(self) -> None:
        """Synchronously close every connection and reset all bookkeeping.

        Idempotent: a second call returns immediately. When the event loop
        is already closed, protocols are not touched but the bookkeeping is
        still cleared.
        """
        if self._closed:
            return

        self._closed = True

        try:
            if self._loop.is_closed():
                return

            # cancel cleanup task
            if self._cleanup_handle:
                self._cleanup_handle.cancel()

            # cancel cleanup close task
            if self._cleanup_closed_handle:
                self._cleanup_closed_handle.cancel()

            for data in self._conns.values():
                for proto, t0 in data:
                    proto.close()

            for proto in self._acquired:
                proto.close()

            for transport in self._cleanup_closed_transports:
                if transport is not None:
                    transport.abort()

        finally:
            self._conns.clear()
            self._acquired.clear()
            # _release_acquired() is a no-op once the connector is closed, so
            # the per-host sets would otherwise keep references to the
            # (now closed) protocols for the rest of the connector's life.
            self._acquired_per_host.clear()
            for keyed_waiters in self._waiters.values():
                for keyed_waiter in keyed_waiters:
                    keyed_waiter.cancel()
            self._waiters.clear()
            self._cleanup_handle = None
            self._cleanup_closed_transports.clear()
            self._cleanup_closed_handle = None

    @property
    def closed(self) -> bool:
        """Is connector closed.

        A readonly property.
        """
        return self._closed

    def _available_connections(self, key: "ConnectionKey") -> int:
        """
        Return number of available connections.

        The limit, limit_per_host and the connection key are taken into account.

        If it returns less than 1 means that there are no connections
        available.
        """
        # check total available connections
        # If there are no limits, this will always return 1
        total_remain = 1

        if self._limit and (total_remain := self._limit - len(self._acquired)) <= 0:
            return total_remain

        # check limit per host
        if host_remain := self._limit_per_host:
            if acquired := self._acquired_per_host.get(key):
                host_remain -= len(acquired)
            if total_remain > host_remain:
                return host_remain

        return total_remain

    async def connect(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Connection:
        """Get from pool or create new connection."""
        key = req.connection_key
        if (conn := await self._get(key, traces)) is not None:
            # If we do not have to wait and we can get a connection from the pool
            # we can avoid the timeout ceil logic and directly return the connection
            return conn

        async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
            if self._available_connections(key) <= 0:
                await self._wait_for_available_connection(key, traces)
                if (conn := await self._get(key, traces)) is not None:
                    return conn

            # Reserve a slot so the limits account for the connection that is
            # about to be created.
            placeholder = cast(ResponseHandler, _TransportPlaceholder())
            self._acquired.add(placeholder)
            if self._limit_per_host:
                self._acquired_per_host[key].add(placeholder)

            try:
                # Traces are done inside the try block to ensure that the
                # placeholder is still cleaned up if an exception is raised.
                if traces:
                    for trace in traces:
                        await trace.send_connection_create_start()
                proto = await self._create_connection(req, traces, timeout)
                if traces:
                    for trace in traces:
                        await trace.send_connection_create_end()
            except BaseException:
                self._release_acquired(key, placeholder)
                raise
            else:
                if self._closed:
                    proto.close()
                    raise ClientConnectionError("Connector is closed.")

        # The connection was successfully created, drop the placeholder
        # and add the real connection to the acquired set. There should
        # be no awaits after the proto is added to the acquired set
        # to ensure that the connection is not left in the acquired set
        # on cancellation.
        self._acquired.remove(placeholder)
        self._acquired.add(proto)
        if self._limit_per_host:
            acquired_per_host = self._acquired_per_host[key]
            acquired_per_host.remove(placeholder)
            acquired_per_host.add(proto)
        return Connection(self, key, proto, self._loop)

    async def _wait_for_available_connection(
        self, key: "ConnectionKey", traces: List["Trace"]
    ) -> None:
        """Wait for an available connection slot."""
        # We loop here because there is a race between
        # the connection limit check and the connection
        # being acquired. If the connection is acquired
        # between the check and the await statement, we
        # need to loop again to check if the connection
        # slot is still available.
        attempts = 0
        while True:
            fut: asyncio.Future[None] = self._loop.create_future()
            keyed_waiters = self._waiters[key]
            keyed_waiters[fut] = None
            if attempts:
                # If we have waited before, we need to move the waiter
                # to the front of the queue as otherwise we might get
                # starved and hit the timeout.
                keyed_waiters.move_to_end(fut, last=False)

            try:
                # Traces happen in the try block to ensure that the
                # waiter is still cleaned up if an exception is raised.
                if traces:
                    for trace in traces:
                        await trace.send_connection_queued_start()
                await fut
                if traces:
                    for trace in traces:
                        await trace.send_connection_queued_end()
            finally:
                # pop the waiter from the queue if it's still
                # there and not already removed by _release_waiter
                keyed_waiters.pop(fut, None)
                if not self._waiters.get(key, True):
                    del self._waiters[key]

            if self._available_connections(key) > 0:
                break
            attempts += 1

    async def _get(
        self, key: "ConnectionKey", traces: List["Trace"]
    ) -> Optional[Connection]:
        """Get next reusable connection for the key or None.

        The connection will be marked as acquired.
        """
        if (conns := self._conns.get(key)) is None:
            return None

        t1 = monotonic()
        while conns:
            proto, t0 = conns.popleft()
            # We will reuse the connection if it's connected and
            # the keepalive timeout has not been exceeded
            if proto.is_connected() and t1 - t0 <= self._keepalive_timeout:
                if not conns:
                    # The very last connection was reclaimed: drop the key
                    del self._conns[key]
                self._acquired.add(proto)
                if self._limit_per_host:
                    self._acquired_per_host[key].add(proto)
                if traces:
                    for trace in traces:
                        try:
                            await trace.send_connection_reuseconn()
                        except BaseException:
                            self._release_acquired(key, proto)
                            raise
                return Connection(self, key, proto, self._loop)

            # Connection cannot be reused, close it
            transport = proto.transport
            proto.close()
            # only for SSL transports
            if not self._cleanup_closed_disabled and key.is_ssl:
                self._cleanup_closed_transports.append(transport)

        # No more connections: drop the key
        del self._conns[key]
        return None

    def _release_waiter(self) -> None:
        """
        Iterates over all waiters until one to be released is found.

        The one to be released is not finished and
        belongs to a host that has available connections.
        """
        if not self._waiters:
            return

        # The keys are shuffled so that hosts are not always visited
        # in the same order on each call.
        queues = list(self._waiters)
        random.shuffle(queues)

        for key in queues:
            if self._available_connections(key) < 1:
                continue

            waiters = self._waiters[key]
            while waiters:
                waiter, _ = waiters.popitem(last=False)
                if not waiter.done():
                    waiter.set_result(None)
                    return

    def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
        """Release acquired connection."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._acquired.discard(proto)
        if self._limit_per_host and (conns := self._acquired_per_host.get(key)):
            conns.discard(proto)
            if not conns:
                del self._acquired_per_host[key]
        self._release_waiter()

    def _release(
        self,
        key: "ConnectionKey",
        protocol: ResponseHandler,
        *,
        should_close: bool = False,
    ) -> None:
        """Take back *protocol*: either close it or return it to the pool."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._release_acquired(key, protocol)

        if self._force_close or should_close or protocol.should_close:
            transport = protocol.transport
            protocol.close()

            if key.is_ssl and not self._cleanup_closed_disabled:
                self._cleanup_closed_transports.append(transport)
            return

        # Keep the connection for reuse, stamped with the release time.
        self._conns[key].append((protocol, monotonic()))

        if self._cleanup_handle is None:
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                self._keepalive_timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Create a new connection; must be implemented by subclasses."""
        raise NotImplementedError()
- class _DNSCacheTable:
- def __init__(self, ttl: Optional[float] = None) -> None:
- self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
- self._timestamps: Dict[Tuple[str, int], float] = {}
- self._ttl = ttl
- def __contains__(self, host: object) -> bool:
- return host in self._addrs_rr
- def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
- self._addrs_rr[key] = (cycle(addrs), len(addrs))
- if self._ttl is not None:
- self._timestamps[key] = monotonic()
- def remove(self, key: Tuple[str, int]) -> None:
- self._addrs_rr.pop(key, None)
- if self._ttl is not None:
- self._timestamps.pop(key, None)
- def clear(self) -> None:
- self._addrs_rr.clear()
- self._timestamps.clear()
- def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
- loop, length = self._addrs_rr[key]
- addrs = list(islice(loop, length))
- # Consume one more element to shift internal state of `cycle`
- next(loop)
- return addrs
- def expired(self, key: Tuple[str, int]) -> bool:
- if self._ttl is None:
- return False
- return self._timestamps[key] + self._ttl < monotonic()
- def _make_ssl_context(verified: bool) -> SSLContext:
- """Create SSL context.
- This method is not async-friendly and should be called from a thread
- because it will load certificates from disk and do other blocking I/O.
- """
- if ssl is None:
- # No ssl support
- return None
- if verified:
- sslcontext = ssl.create_default_context()
- else:
- sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
- sslcontext.options |= ssl.OP_NO_SSLv2
- sslcontext.options |= ssl.OP_NO_SSLv3
- sslcontext.check_hostname = False
- sslcontext.verify_mode = ssl.CERT_NONE
- sslcontext.options |= ssl.OP_NO_COMPRESSION
- sslcontext.set_default_verify_paths()
- sslcontext.set_alpn_protocols(("http/1.1",))
- return sslcontext
# Building an SSLContext does blocking I/O (certificates are loaded from
# disk), so the two default contexts are created once, at import time:
# imports should always be done before the event loop starts or in a thread.
_SSL_CONTEXT_VERIFIED, _SSL_CONTEXT_UNVERIFIED = (
    _make_ssl_context(True),
    _make_ssl_context(False),
)
- class TCPConnector(BaseConnector):
- """TCP connector.
- verify_ssl - Set to True to check ssl certificates.
- fingerprint - Pass the binary sha256
- digest of the expected certificate in DER format to verify
- that the certificate the server presents matches. See also
- https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning
- resolver - Enable DNS lookups and use this
- resolver
- use_dns_cache - Use memory cache for DNS lookups.
- ttl_dns_cache - Max seconds having cached a DNS entry, None forever.
- family - socket address family
- local_addr - local tuple of (host, port) to bind socket to
- keepalive_timeout - (optional) Keep-alive timeout.
- force_close - Set to True to force close and do reconnect
- after each request (and between redirects).
- limit - The total number of simultaneous connections.
- limit_per_host - Number of simultaneous connections to one host.
- enable_cleanup_closed - Enables clean-up closed ssl transports.
- Disabled by default.
- happy_eyeballs_delay - This is the “Connection Attempt Delay”
- as defined in RFC 8305. To disable
- the happy eyeballs algorithm, set to None.
- interleave - “First Address Family Count” as defined in RFC 8305
- loop - Optional event loop.
- """
- allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
def __init__(
    self,
    *,
    verify_ssl: bool = True,
    fingerprint: Optional[bytes] = None,
    use_dns_cache: bool = True,
    ttl_dns_cache: Optional[int] = 10,
    family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
    ssl_context: Optional[SSLContext] = None,
    ssl: Union[bool, Fingerprint, SSLContext] = True,
    local_addr: Optional[Tuple[str, int]] = None,
    resolver: Optional[AbstractResolver] = None,
    keepalive_timeout: Union[None, float, object] = sentinel,
    force_close: bool = False,
    limit: int = 100,
    limit_per_host: int = 0,
    enable_cleanup_closed: bool = False,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    timeout_ceil_threshold: float = 5,
    happy_eyeballs_delay: Optional[float] = 0.25,
    interleave: Optional[int] = None,
):
    """Initialise pooling via the base class, then the TCP/DNS/SSL state.

    The pooling arguments are forwarded unchanged to the base class; the
    four SSL-related arguments are folded into one setting by
    ``_merge_ssl_params``. A ``DefaultResolver`` bound to the connector's
    loop is created when no resolver is supplied.
    """
    super().__init__(
        keepalive_timeout=keepalive_timeout,
        force_close=force_close,
        limit=limit,
        limit_per_host=limit_per_host,
        enable_cleanup_closed=enable_cleanup_closed,
        loop=loop,
        timeout_ceil_threshold=timeout_ceil_threshold,
    )

    # SSL configuration.
    self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

    # DNS resolution and caching.
    self._resolver = (
        DefaultResolver(loop=self._loop) if resolver is None else resolver
    )
    self._use_dns_cache = use_dns_cache
    self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
    # Per (host, port): futures of callers waiting for an in-flight lookup.
    self._throttle_dns_futures: Dict[
        Tuple[str, int], Set["asyncio.Future[None]"]
    ] = {}

    # Socket / connection-attempt parameters.
    self._family = family
    self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
    self._happy_eyeballs_delay = happy_eyeballs_delay
    self._interleave = interleave

    # Lookup tasks still running; cancelled by close().
    self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
def close(self) -> Awaitable[None]:
    """Close all ongoing DNS calls."""
    # Wake up (by cancellation) everyone waiting on a throttled lookup.
    for waiting in self._throttle_dns_futures.values():
        for waiter in waiting:
            waiter.cancel()

    # Cancel the lookups that are still in flight.
    for task in self._resolve_host_tasks:
        task.cancel()

    # The base class closes the transports themselves.
    return super().close()
- @property
- def family(self) -> int:
- """Socket family like AF_INET."""
- return self._family
- @property
- def use_dns_cache(self) -> bool:
- """True if local DNS caching is enabled."""
- return self._use_dns_cache
- def clear_dns_cache(
- self, host: Optional[str] = None, port: Optional[int] = None
- ) -> None:
- """Remove specified host/port or clear all dns local cache."""
- if host is not None and port is not None:
- self._cached_hosts.remove((host, port))
- elif host is not None or port is not None:
- raise ValueError("either both host and port or none of them are allowed")
- else:
- self._cached_hosts.clear()
async def _resolve_host(
    self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None
) -> List[ResolveResult]:
    """Resolve host and return list of addresses.

    IP literals are returned as a single synthetic result without any
    lookup. With the DNS cache disabled the resolver is called directly.
    Otherwise a fresh cache entry is served if present; if a lookup for
    the same (host, port) is already in flight this call waits for it;
    else a new lookup task is started and awaited behind a shield.
    """
    if is_ip_address(host):
        return [
            {
                "hostname": host,
                "host": host,
                "port": port,
                "family": self._family,
                "proto": 0,
                "flags": 0,
            }
        ]

    if not self._use_dns_cache:
        if traces:
            for trace in traces:
                await trace.send_dns_resolvehost_start(host)

        res = await self._resolver.resolve(host, port, family=self._family)

        if traces:
            for trace in traces:
                await trace.send_dns_resolvehost_end(host)

        return res

    key = (host, port)
    if key in self._cached_hosts and not self._cached_hosts.expired(key):
        # get result early, before any await (#4014)
        result = self._cached_hosts.next_addrs(key)

        if traces:
            for trace in traces:
                await trace.send_dns_cache_hit(host)
        return result

    futures: Set["asyncio.Future[None]"]
    #
    # If multiple callers are resolving the same host, we wait
    # for the first one to resolve and then use the result for all of them.
    # We use a throttle to ensure that we only resolve the host once
    # and then use the result for all the waiters.
    #
    if key in self._throttle_dns_futures:
        # get futures early, before any await (#4014)
        futures = self._throttle_dns_futures[key]
        future: asyncio.Future[None] = self._loop.create_future()
        futures.add(future)
        try:
            # Traces happen in the try block to ensure that the waiter
            # is still removed from the shared set if a trace callback
            # raises; otherwise an orphaned future would stay registered
            # and presumably be completed later with nobody awaiting it.
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            await future
        finally:
            futures.discard(future)
        return self._cached_hosts.next_addrs(key)

    # update dict early, before any await (#4014)
    self._throttle_dns_futures[key] = futures = set()
    # In this case we need to create a task to ensure that we can shield
    # the task from cancellation as cancelling this lookup should not cancel
    # the underlying lookup or else the cancel event will get broadcast to
    # all the waiters across all connections.
    #
    coro = self._resolve_host_with_throttle(key, host, port, futures, traces)
    loop = asyncio.get_running_loop()
    if sys.version_info >= (3, 12):
        # Optimization for Python 3.12, try to send immediately
        resolved_host_task = asyncio.Task(coro, loop=loop, eager_start=True)
    else:
        resolved_host_task = loop.create_task(coro)

    if not resolved_host_task.done():
        # Track the task so close() can cancel it; it untracks itself.
        self._resolve_host_tasks.add(resolved_host_task)
        resolved_host_task.add_done_callback(self._resolve_host_tasks.discard)

    try:
        return await asyncio.shield(resolved_host_task)
    except asyncio.CancelledError:
        # This caller was cancelled but the shielded lookup keeps running.
        # Retrieve its eventual outcome so that a failure is not reported
        # as a never-retrieved exception.

        def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
            with suppress(Exception, asyncio.CancelledError):
                fut.result()

        resolved_host_task.add_done_callback(drop_exception)
        raise
async def _resolve_host_with_throttle(
    self,
    key: Tuple[str, int],
    host: str,
    port: int,
    futures: Set["asyncio.Future[None]"],
    traces: Optional[Sequence["Trace"]],
) -> List[ResolveResult]:
    """Perform the DNS lookup for *key* and wake every waiter.

    On success the addresses are stored in the cache and each future in
    *futures* is completed with ``None``.  On failure each future gets
    the same exception, which is then re-raised.  The throttle entry for
    *key* is removed in both cases.

    This coroutine must run in a task that is shielded from cancellation
    so that cancelling one caller does not cancel the underlying lookup.
    """
    try:
        # All cache-miss notifications go out before any resolve-start one.
        for trace in traces or ():
            await trace.send_dns_cache_miss(host)
        for trace in traces or ():
            await trace.send_dns_resolvehost_start(host)

        addrs = await self._resolver.resolve(host, port, family=self._family)

        for trace in traces or ():
            await trace.send_dns_resolvehost_end(host)

        self._cached_hosts.add(key, addrs)
        for waiter in futures:
            set_result(waiter, None)

    except BaseException as e:
        # Hand the same failure to every waiter.  Since this coroutine is
        # run shielded, cancellation is not expected to be propagated here.
        for waiter in futures:
            set_exception(waiter, e)
        raise
    finally:
        self._throttle_dns_futures.pop(key)

    return self._cached_hosts.next_addrs(key)
async def _create_connection(
    self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
    """Open a connection for *req* and return its protocol.

    Requests with a proxy set go through ``_create_proxy_connection``;
    all others through ``_create_direct_connection``.  The transport of
    the returned pair is discarded.
    """
    if req.proxy:
        connect = self._create_proxy_connection
    else:
        connect = self._create_direct_connection
    _transport, protocol = await connect(req, traces, timeout)
    return protocol
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
    """Pick the SSL context to use for *req*.

    Returns ``None`` when the request is not an SSL request.  Otherwise
    the request's own ``ssl`` setting is examined first and the
    connector's ``_ssl`` setting second:

    * an ``ssl.SSLContext`` instance is returned as is;
    * any value other than ``True`` selects the shared unverified context;
    * ``True`` defers to the next setting.

    When both settings are ``True`` the shared verified context is used.

    Raises RuntimeError if the ``ssl`` module is unavailable.
    """
    if not req.is_ssl():
        return None

    if ssl is None:  # pragma: no cover
        raise RuntimeError("SSL is not supported.")

    # Request-level setting takes precedence over the connector-level one.
    for setting in (req.ssl, self._ssl):
        if isinstance(setting, ssl.SSLContext):
            return setting
        if setting is not True:
            # not verified or fingerprinted
            return _SSL_CONTEXT_UNVERIFIED

    return _SSL_CONTEXT_VERIFIED
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
    """Return the fingerprint to verify for *req*, if any.

    The request's ``ssl`` setting is checked first, then the connector's
    ``_ssl`` setting; ``None`` is returned when neither is a
    ``Fingerprint``.
    """
    for setting in (req.ssl, self._ssl):
        if isinstance(setting, Fingerprint):
            return setting
    return None
async def _wrap_create_connection(
    self,
    *args: Any,
    addr_infos: List[aiohappyeyeballs.AddrInfoType],
    req: ClientRequest,
    timeout: "ClientTimeout",
    client_error: Type[Exception] = ClientConnectorError,
    **kwargs: Any,
) -> Tuple[asyncio.Transport, ResponseHandler]:
    """Connect a socket to one of *addr_infos* and wrap it in a transport.

    The socket is obtained from ``aiohappyeyeballs.start_connection`` and
    then handed to ``loop.create_connection`` together with ``*args`` and
    ``**kwargs``.  Both steps run under the ``sock_connect`` timeout.

    Raises:
        ClientConnectorCertificateError: on certificate errors.
        ClientConnectorSSLError: on other SSL errors.
        client_error: on any other ``OSError``, except a timeout whose
            ``errno`` is ``None``, which is re-raised unchanged.
    """
    # Tracks a connected socket that the event loop has not yet taken
    # ownership of, so it can be closed if anything fails in between.
    sock: Optional[socket.socket] = None
    try:
        async with ceil_timeout(
            timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
        ):
            sock = await aiohappyeyeballs.start_connection(
                addr_infos=addr_infos,
                local_addr_infos=self._local_addr_infos,
                happy_eyeballs_delay=self._happy_eyeballs_delay,
                interleave=self._interleave,
                loop=self._loop,
            )
            connection = await self._loop.create_connection(
                *args, **kwargs, sock=sock
            )
            # The transport owns the socket from here on.
            sock = None
            return connection
    except cert_errors as exc:
        raise ClientConnectorCertificateError(req.connection_key, exc) from exc
    except ssl_errors as exc:
        raise ClientConnectorSSLError(req.connection_key, exc) from exc
    except OSError as exc:
        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
            raise
        raise client_error(req.connection_key, exc) from exc
    finally:
        if sock is not None:
            # Reached when an exception (including cancellation or the
            # timeout) fired after the socket was connected but before
            # create_connection() returned.  Close it proactively so the
            # descriptor is not leaked.  The close is best-effort, so the
            # original exception is the one that propagates.
            with suppress(OSError):
                sock.close()
async def _wrap_existing_connection(
    self,
    *args: Any,
    req: ClientRequest,
    timeout: "ClientTimeout",
    client_error: Type[Exception] = ClientConnectorError,
    **kwargs: Any,
) -> Tuple[asyncio.Transport, ResponseHandler]:
    """Call ``loop.create_connection`` with aiohttp's error translation.

    ``*args`` and ``**kwargs`` are forwarded untouched, and the call runs
    under the ``sock_connect`` timeout.

    Raises:
        ClientConnectorCertificateError: on certificate errors.
        ClientConnectorSSLError: on other SSL errors.
        client_error: on any other ``OSError``, except a timeout whose
            ``errno`` is ``None``, which is re-raised unchanged.
    """
    connect_timeout = ceil_timeout(
        timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
    )
    try:
        async with connect_timeout:
            pair = await self._loop.create_connection(*args, **kwargs)
            return pair
    except cert_errors as exc:
        raise ClientConnectorCertificateError(req.connection_key, exc) from exc
    except ssl_errors as exc:
        raise ClientConnectorSSLError(req.connection_key, exc) from exc
    except OSError as exc:
        is_plain_timeout = exc.errno is None and isinstance(
            exc, asyncio.TimeoutError
        )
        if is_plain_timeout:
            raise
        raise client_error(req.connection_key, exc) from exc
- def _fail_on_no_start_tls(self, req: "ClientRequest") -> None:
- """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``.
- It is necessary for TLS-in-TLS so that it is possible to
- send HTTPS queries through HTTPS proxies.
- This doesn't affect regular HTTP requests, though.
- """
- if not req.is_ssl():
- return
- proxy_url = req.proxy
- assert proxy_url is not None
- if proxy_url.scheme != "https":
- return
- self._check_loop_for_start_tls()
- def _check_loop_for_start_tls(self) -> None:
- try:
- self._loop.start_tls
- except AttributeError as attr_exc:
- raise RuntimeError(
- "An HTTPS request is being sent through an HTTPS proxy. "
- "This needs support for TLS in TLS but it is not implemented "
- "in your runtime for the stdlib asyncio.\n\n"
- "Please upgrade to Python 3.11 or higher. For more details, "
- "please see:\n"
- "* https://bugs.python.org/issue37179\n"
- "* https://github.com/python/cpython/pull/28073\n"
- "* https://docs.aiohttp.org/en/stable/"
- "client_advanced.html#proxy-support\n"
- "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
- ) from attr_exc
- def _loop_supports_start_tls(self) -> bool:
- try:
- self._check_loop_for_start_tls()
- except RuntimeError:
- return False
- else:
- return True
def _warn_about_tls_in_tls(
    self,
    underlying_transport: asyncio.Transport,
    req: ClientRequest,
) -> None:
    """Emit a RuntimeWarning when TLS-in-TLS is likely to fail.

    Nothing happens unless the request URL scheme is ``https`` and the
    underlying transport lacks a truthy ``_start_tls_compatible``
    attribute.
    """
    if req.request_info.url.scheme != "https":
        return

    if getattr(underlying_transport, "_start_tls_compatible", False):
        # The transport advertises TLS-in-TLS support; nothing to warn about.
        return

    explanation = (
        "An HTTPS request is being sent through an HTTPS proxy. "
        "This support for TLS in TLS is known to be disabled "
        "in the stdlib asyncio (Python <3.11). This is why you'll probably see "
        "an error in the log below.\n\n"
        "It is possible to enable it via monkeypatching. "
        "For more details, see:\n"
        "* https://bugs.python.org/issue37179\n"
        "* https://github.com/python/cpython/pull/28073\n\n"
        "You can temporarily patch this as follows:\n"
        "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n"
        "* https://github.com/aio-libs/aiohttp/discussions/6044\n"
    )
    warnings.warn(
        explanation,
        RuntimeWarning,
        source=self,
        # Skip frames so the warning is attributed above this helper.
        # NOTE(review): the earlier comment here spoke of a level of `4`
        # while the code passes 3 — confirm which value is intended.
        stacklevel=3,
    )
async def _start_tls_connection(
    self,
    underlying_transport: asyncio.Transport,
    req: ClientRequest,
    timeout: "ClientTimeout",
    client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
    """Upgrade an established transport to TLS via ``loop.start_tls()``.

    A fresh protocol is created for the TLS layer.  If ``start_tls()``
    fails, the underlying transport is closed.  When the result is an
    ``asyncio.Transport`` and a fingerprint is configured, the
    fingerprint is checked; on mismatch the TLS transport is closed.

    Returns the ``(tls_transport, tls_protocol)`` pair.

    Raises:
        ClientConnectorCertificateError: on certificate errors.
        ClientConnectorSSLError: on other SSL errors.
        ClientConnectionError: when ``start_tls()`` raises ``TypeError``.
        client_error: on any other ``OSError`` (a timeout whose ``errno``
            is ``None`` is re-raised unchanged), or when ``start_tls()``
            returns ``None``.
    """
    tls_proto = self._factory()  # Create a brand new proto for TLS
    sslcontext = self._get_ssl_context(req)
    if TYPE_CHECKING:
        # _start_tls_connection is unreachable in the current code path
        # if sslcontext is None.
        assert sslcontext is not None

    handshake_window = ceil_timeout(
        timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
    )
    try:
        async with handshake_window:
            try:
                tls_transport = await self._loop.start_tls(
                    underlying_transport,
                    tls_proto,
                    sslcontext,
                    server_hostname=req.server_hostname or req.host,
                    ssl_handshake_timeout=timeout.total,
                )
            except BaseException:
                # `start_tls()` probably failed before it had a chance
                # to close the underlying transport, so do it here.
                underlying_transport.close()
                raise
            if isinstance(tls_transport, asyncio.Transport):
                fingerprint = self._get_fingerprint(req)
                if fingerprint:
                    try:
                        fingerprint.check(tls_transport)
                    except ServerFingerprintMismatch:
                        tls_transport.close()
                        if not self._cleanup_closed_disabled:
                            self._cleanup_closed_transports.append(tls_transport)
                        raise
    except cert_errors as exc:
        raise ClientConnectorCertificateError(req.connection_key, exc) from exc
    except ssl_errors as exc:
        raise ClientConnectorSSLError(req.connection_key, exc) from exc
    except OSError as exc:
        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
            raise
        raise client_error(req.connection_key, exc) from exc
    except TypeError as type_err:
        # Example cause looks like this:
        # TypeError: transport <asyncio.sslproto._SSLProtocolTransport
        # object at 0x7f760615e460> is not supported by start_tls()

        raise ClientConnectionError(
            "Cannot initialize a TLS-in-TLS connection to host "
            f"{req.host!s}:{req.port:d} through an underlying connection "
            f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} "
            f"[{type_err!s}]"
        ) from type_err
    else:
        if tls_transport is not None:
            # Kick the state machine of the new TLS protocol
            tls_proto.connection_made(tls_transport)
            return tls_transport, tls_proto

        msg = "Failed to start TLS (possibly caused by closing transport)"
        raise client_error(req.connection_key, OSError(msg))
def _convert_hosts_to_addr_infos(
    self, hosts: List[ResolveResult]
) -> List[aiohappyeyeballs.AddrInfoType]:
    """Translate resolver results into ``getaddrinfo()``-style tuples.

    Each entry of *hosts* becomes
    ``(family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)``.  An address
    containing ``":"`` is treated as IPv6 and gets a 4-item sockaddr,
    anything else is IPv4 with a 2-item sockaddr.  When the connector
    has a non-zero family configured, entries of another family are
    skipped.
    """
    wanted_family = self._family
    converted: List[aiohappyeyeballs.AddrInfoType] = []
    for entry in hosts:
        address = entry["host"]
        looks_like_ipv6 = ":" in address
        family = socket.AF_INET6 if looks_like_ipv6 else socket.AF_INET
        if wanted_family and wanted_family != family:
            continue
        if looks_like_ipv6:
            sockaddr = (address, entry["port"], 0, 0)
        else:
            sockaddr = (address, entry["port"])
        converted.append(
            (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", sockaddr)
        )
    return converted
async def _create_direct_connection(
    self,
    req: ClientRequest,
    traces: List["Trace"],
    timeout: "ClientTimeout",
    *,
    client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.Transport, ResponseHandler]:
    """Resolve the request host and connect to it without a proxy.

    Connection attempts are repeated while candidate addresses remain.
    A failed attempt or a fingerprint mismatch removes candidates from
    the list before the next attempt.

    Raises:
        ClientConnectorDNSError: when resolving the host fails with an
            ``OSError`` (a timeout whose ``errno`` is ``None`` is
            re-raised unchanged).
        Exception: the last connection error, once no candidates remain.
    """
    sslcontext = self._get_ssl_context(req)
    fingerprint = self._get_fingerprint(req)

    host = req.url.raw_host
    assert host is not None
    # Collapse several trailing dots into one.  A trailing dot is only
    # present for fully-qualified domain names.
    # See https://github.com/aio-libs/aiohttp/pull/7364.
    if host.endswith(".."):
        host = host.rstrip(".") + "."
    port = req.port
    assert port is not None

    try:
        # Cancelling this lookup should not cancel the underlying lookup
        # or else the cancel event will get broadcast to all the waiters
        # across all connections.
        hosts = await self._resolve_host(host, port, traces=traces)
    except OSError as exc:
        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
            raise
        # Even when connecting to a proxy this is not a
        # ClientProxyConnectionError: resolving the address itself failed.
        raise ClientConnectorDNSError(req.connection_key, exc) from exc

    # Certificates contain the FQDN without trailing dots, so strip them.
    # See https://github.com/aio-libs/aiohttp/issues/3636
    if sslcontext:
        server_hostname = (req.server_hostname or host).rstrip(".")
    else:
        server_hostname = None

    last_exc: Optional[Exception] = None
    addr_infos = self._convert_hosts_to_addr_infos(hosts)
    while addr_infos:
        try:
            transp, proto = await self._wrap_create_connection(
                self._factory,
                timeout=timeout,
                ssl=sslcontext,
                addr_infos=addr_infos,
                server_hostname=server_hostname,
                req=req,
                client_error=client_error,
            )
        except (ClientConnectorError, asyncio.TimeoutError) as exc:
            last_exc = exc
            aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave)
            continue

        if req.is_ssl() and fingerprint:
            try:
                fingerprint.check(transp)
            except ServerFingerprintMismatch as exc:
                transp.close()
                if not self._cleanup_closed_disabled:
                    self._cleanup_closed_transports.append(transp)
                last_exc = exc
                # Remove the bad peer from the list of addr_infos
                sock: socket.socket = transp.get_extra_info("socket")
                bad_peer = sock.getpeername()
                aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer)
                continue

        return transp, proto

    # No candidate addresses remain: surface the last failure.
    assert last_exc is not None
    raise last_exc
async def _create_proxy_connection(
    self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
    """Connect to the proxy of *req* and, for SSL requests, tunnel through it.

    For a non-SSL request the proxy connection itself is returned.  For
    an SSL request a ``CONNECT`` is sent to the proxy; on a 200 response
    the connection is upgraded to TLS, either with ``loop.start_tls()``
    or, when the loop lacks it, by duplicating the raw socket and
    wrapping it in a new connection.

    Raises:
        ClientHttpProxyError: when the proxy answers ``CONNECT`` with a
            status other than 200.
        RuntimeError: when the transport exposes no socket to duplicate.
    """
    self._fail_on_no_start_tls(req)
    runtime_has_start_tls = self._loop_supports_start_tls()

    headers: Dict[str, str] = {}
    if req.proxy_headers is not None:
        headers = req.proxy_headers  # type: ignore[assignment]
    headers[hdrs.HOST] = req.headers[hdrs.HOST]

    url = req.proxy
    assert url is not None
    proxy_req = ClientRequest(
        hdrs.METH_GET,
        url,
        headers=headers,
        auth=req.proxy_auth,
        loop=self._loop,
        ssl=req.ssl,
    )

    # Connect to the proxy server itself.
    transport, proto = await self._create_direct_connection(
        proxy_req, [], timeout, client_error=ClientProxyConnectionError
    )

    # Move the credentials to a Proxy-Authorization header: on the
    # CONNECT request for SSL requests, on the real request otherwise.
    auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
    if auth is not None:
        if req.is_ssl():
            proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
        else:
            req.headers[hdrs.PROXY_AUTHORIZATION] = auth

    if not req.is_ssl():
        return transport, proto

    if runtime_has_start_tls:
        self._warn_about_tls_in_tls(transport, req)

    # For HTTPS requests over HTTP proxy
    # we must notify proxy to tunnel connection
    # so we send CONNECT command:
    #   CONNECT www.python.org:443 HTTP/1.1
    #   Host: www.python.org
    #
    # next we must do TLS handshake and so on
    # to do this we must wrap raw socket into secure one
    # asyncio handles this perfectly
    proxy_req.method = hdrs.METH_CONNECT
    proxy_req.url = req.url
    key = req.connection_key._replace(
        proxy=None, proxy_auth=None, proxy_headers_hash=None
    )
    conn = Connection(self, key, proto, self._loop)
    proxy_resp = await proxy_req.send(conn)
    try:
        protocol = conn._protocol
        assert protocol is not None

        # read_until_eof=True will ensure the connection isn't closed
        # once the response is received and processed allowing
        # START_TLS to work on the connection below.
        protocol.set_response_params(
            read_until_eof=runtime_has_start_tls,
            timeout_ceil_threshold=self._timeout_ceil_threshold,
        )
        resp = await proxy_resp.start(conn)
    except BaseException:
        proxy_resp.close()
        conn.close()
        raise
    else:
        conn._protocol = None
        try:
            if resp.status != 200:
                message = resp.reason
                if message is None:
                    message = HTTPStatus(resp.status).phrase
                raise ClientHttpProxyError(
                    proxy_resp.request_info,
                    resp.history,
                    status=resp.status,
                    message=message,
                    headers=resp.headers,
                )
            if not runtime_has_start_tls:
                rawsock = transport.get_extra_info("socket", default=None)
                if rawsock is None:
                    raise RuntimeError(
                        "Transport does not expose socket instance"
                    )
                # Duplicate the socket, so now we can close proxy transport
                rawsock = rawsock.dup()
        except BaseException:
            # It shouldn't be closed in `finally` because it's fed to
            # `loop.start_tls()` and the docs say not to touch it after
            # passing there.
            transport.close()
            raise
        finally:
            if not runtime_has_start_tls:
                transport.close()

        if runtime_has_start_tls:
            return await self._start_tls_connection(
                # Access the old transport for the last time before it's
                # closed and forgotten forever:
                transport,
                req=req,
                timeout=timeout,
            )

        # HTTP proxy with support for upgrade to HTTPS
        sslcontext = self._get_ssl_context(req)
        return await self._wrap_existing_connection(
            self._factory,
            timeout=timeout,
            ssl=sslcontext,
            sock=rawsock,
            server_hostname=req.host,
            req=req,
        )
    finally:
        proxy_resp.close()
class UnixConnector(BaseConnector):
    """Connector that talks to a server over a Unix domain socket.

    path - Unix socket path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Store *path* and forward the pool options to the base connector."""
        super().__init__(
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
            limit=limit,
            limit_per_host=limit_per_host,
            loop=loop,
        )
        self._path = path

    @property
    def path(self) -> str:
        """Return the Unix socket path given at construction."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a Unix socket connection and return its protocol.

        Raises UnixClientConnectorError on ``OSError``, except a timeout
        whose ``errno`` is ``None``, which is re-raised unchanged.
        """
        connect_timeout = ceil_timeout(
            timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
        )
        try:
            async with connect_timeout:
                _transport, proto = await self._loop.create_unix_connection(
                    self._factory, self._path
                )
        except OSError as exc:
            is_plain_timeout = exc.errno is None and isinstance(
                exc, asyncio.TimeoutError
            )
            if is_plain_timeout:
                raise
            raise UnixClientConnectorError(self.path, req.connection_key, exc) from exc

        return proto
class NamedPipeConnector(BaseConnector):
    """Connector that talks to a server over a Windows named pipe.

    Only supported by the proactor event loop.
    See also: https://docs.python.org/3/library/asyncio-eventloop.html

    path - Windows named pipe path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Store *path* after checking the loop is a proactor loop.

        Raises RuntimeError when the connector's loop is not an
        ``asyncio.ProactorEventLoop``.
        """
        super().__init__(
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
            limit=limit,
            limit_per_host=limit_per_host,
            loop=loop,
        )
        proactor_loop_type = asyncio.ProactorEventLoop  # type: ignore[attr-defined]
        if not isinstance(self._loop, proactor_loop_type):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        self._path = path

    @property
    def path(self) -> str:
        """Return the named pipe path given at construction."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a named pipe connection and return its protocol.

        Raises ClientConnectorError on ``OSError``, except a timeout
        whose ``errno`` is ``None``, which is re-raised unchanged.
        """
        connect_timeout = ceil_timeout(
            timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
        )
        try:
            async with connect_timeout:
                _transport, proto = await self._loop.create_pipe_connection(  # type: ignore[attr-defined]
                    self._factory, self._path
                )
                # Yield to the loop once so that connection_made is called
                # and the transport is set; otherwise it is not set before
                # the `assert conn.transport is not None`
                # in client.py's _request method.
                # The other option is to set the transport manually, like
                # `proto.transport = trans`.
                await asyncio.sleep(0)
        except OSError as exc:
            is_plain_timeout = exc.errno is None and isinstance(
                exc, asyncio.TimeoutError
            )
            if is_plain_timeout:
                raise
            raise ClientConnectorError(req.connection_key, exc) from exc

        return cast(ResponseHandler, proto)
|