impl.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. """Base implementation."""
  2. import asyncio
  3. import collections
  4. import contextlib
  5. import functools
  6. import itertools
  7. import socket
  8. from typing import List, Optional, Sequence, Set, Union
  9. from . import _staggered
  10. from .types import AddrInfoType, SocketFactoryType
  11. async def start_connection(
  12. addr_infos: Sequence[AddrInfoType],
  13. *,
  14. local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
  15. happy_eyeballs_delay: Optional[float] = None,
  16. interleave: Optional[int] = None,
  17. loop: Optional[asyncio.AbstractEventLoop] = None,
  18. socket_factory: Optional[SocketFactoryType] = None,
  19. ) -> socket.socket:
  20. """
  21. Connect to a TCP server.
  22. Create a socket connection to a specified destination. The
  23. destination is specified as a list of AddrInfoType tuples as
  24. returned from getaddrinfo().
  25. The arguments are, in order:
  26. * ``family``: the address family, e.g. ``socket.AF_INET`` or
  27. ``socket.AF_INET6``.
  28. * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
  29. ``socket.SOCK_DGRAM``.
  30. * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
  31. ``socket.IPPROTO_UDP``.
  32. * ``canonname``: the canonical name of the address, e.g.
  33. ``"www.python.org"``.
  34. * ``sockaddr``: the socket address
  35. This method is a coroutine which will try to establish the connection
  36. in the background. When successful, the coroutine returns a
  37. socket.
  38. The expected use case is to use this method in conjunction with
  39. loop.create_connection() to establish a connection to a server::
  40. socket = await start_connection(addr_infos)
  41. transport, protocol = await loop.create_connection(
  42. MyProtocol, sock=socket, ...)
  43. """
  44. if not (current_loop := loop):
  45. current_loop = asyncio.get_running_loop()
  46. single_addr_info = len(addr_infos) == 1
  47. if happy_eyeballs_delay is not None and interleave is None:
  48. # If using happy eyeballs, default to interleave addresses by family
  49. interleave = 1
  50. if interleave and not single_addr_info:
  51. addr_infos = _interleave_addrinfos(addr_infos, interleave)
  52. sock: Optional[socket.socket] = None
  53. # uvloop can raise RuntimeError instead of OSError
  54. exceptions: List[List[Union[OSError, RuntimeError]]] = []
  55. if happy_eyeballs_delay is None or single_addr_info:
  56. # not using happy eyeballs
  57. for addrinfo in addr_infos:
  58. try:
  59. sock = await _connect_sock(
  60. current_loop,
  61. exceptions,
  62. addrinfo,
  63. local_addr_infos,
  64. None,
  65. socket_factory,
  66. )
  67. break
  68. except (RuntimeError, OSError):
  69. continue
  70. else: # using happy eyeballs
  71. open_sockets: Set[socket.socket] = set()
  72. try:
  73. sock, _, _ = await _staggered.staggered_race(
  74. (
  75. functools.partial(
  76. _connect_sock,
  77. current_loop,
  78. exceptions,
  79. addrinfo,
  80. local_addr_infos,
  81. open_sockets,
  82. socket_factory,
  83. )
  84. for addrinfo in addr_infos
  85. ),
  86. happy_eyeballs_delay,
  87. )
  88. finally:
  89. # If we have a winner, staggered_race will
  90. # cancel the other tasks, however there is a
  91. # small race window where any of the other tasks
  92. # can be done before they are cancelled which
  93. # will leave the socket open. To avoid this problem
  94. # we pass a set to _connect_sock to keep track of
  95. # the open sockets and close them here if there
  96. # are any "runner up" sockets.
  97. for s in open_sockets:
  98. if s is not sock:
  99. with contextlib.suppress(OSError):
  100. s.close()
  101. open_sockets = None # type: ignore[assignment]
  102. if sock is None:
  103. all_exceptions = [exc for sub in exceptions for exc in sub]
  104. try:
  105. first_exception = all_exceptions[0]
  106. if len(all_exceptions) == 1:
  107. raise first_exception
  108. else:
  109. # If they all have the same str(), raise one.
  110. model = str(first_exception)
  111. if all(str(exc) == model for exc in all_exceptions):
  112. raise first_exception
  113. # Raise a combined exception so the user can see all
  114. # the various error messages.
  115. msg = "Multiple exceptions: {}".format(
  116. ", ".join(str(exc) for exc in all_exceptions)
  117. )
  118. # If the errno is the same for all exceptions, raise
  119. # an OSError with that errno.
  120. if isinstance(first_exception, OSError):
  121. first_errno = first_exception.errno
  122. if all(
  123. isinstance(exc, OSError) and exc.errno == first_errno
  124. for exc in all_exceptions
  125. ):
  126. raise OSError(first_errno, msg)
  127. elif isinstance(first_exception, RuntimeError) and all(
  128. isinstance(exc, RuntimeError) for exc in all_exceptions
  129. ):
  130. raise RuntimeError(msg)
  131. # We have a mix of OSError and RuntimeError
  132. # so we have to pick which one to raise.
  133. # and we raise OSError for compatibility
  134. raise OSError(msg)
  135. finally:
  136. all_exceptions = None # type: ignore[assignment]
  137. exceptions = None # type: ignore[assignment]
  138. return sock
  139. async def _connect_sock(
  140. loop: asyncio.AbstractEventLoop,
  141. exceptions: List[List[Union[OSError, RuntimeError]]],
  142. addr_info: AddrInfoType,
  143. local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
  144. open_sockets: Optional[Set[socket.socket]] = None,
  145. socket_factory: Optional[SocketFactoryType] = None,
  146. ) -> socket.socket:
  147. """
  148. Create, bind and connect one socket.
  149. If open_sockets is passed, add the socket to the set of open sockets.
  150. Any failure caught here will remove the socket from the set and close it.
  151. Callers can use this set to close any sockets that are not the winner
  152. of all staggered tasks in the result there are runner up sockets aka
  153. multiple winners.
  154. """
  155. my_exceptions: List[Union[OSError, RuntimeError]] = []
  156. exceptions.append(my_exceptions)
  157. family, type_, proto, _, address = addr_info
  158. sock = None
  159. try:
  160. if socket_factory is not None:
  161. sock = socket_factory(addr_info)
  162. else:
  163. sock = socket.socket(family=family, type=type_, proto=proto)
  164. if open_sockets is not None:
  165. open_sockets.add(sock)
  166. sock.setblocking(False)
  167. if local_addr_infos is not None:
  168. for lfamily, _, _, _, laddr in local_addr_infos:
  169. # skip local addresses of different family
  170. if lfamily != family:
  171. continue
  172. try:
  173. sock.bind(laddr)
  174. break
  175. except OSError as exc:
  176. msg = (
  177. f"error while attempting to bind on "
  178. f"address {laddr!r}: "
  179. f"{(exc.strerror or '').lower()}"
  180. )
  181. exc = OSError(exc.errno, msg)
  182. my_exceptions.append(exc)
  183. else: # all bind attempts failed
  184. if my_exceptions:
  185. raise my_exceptions.pop()
  186. else:
  187. raise OSError(f"no matching local address with {family=} found")
  188. await loop.sock_connect(sock, address)
  189. return sock
  190. except (RuntimeError, OSError) as exc:
  191. my_exceptions.append(exc)
  192. if sock is not None:
  193. if open_sockets is not None:
  194. open_sockets.remove(sock)
  195. try:
  196. sock.close()
  197. except OSError as e:
  198. my_exceptions.append(e)
  199. raise
  200. raise
  201. except:
  202. if sock is not None:
  203. if open_sockets is not None:
  204. open_sockets.remove(sock)
  205. try:
  206. sock.close()
  207. except OSError as e:
  208. my_exceptions.append(e)
  209. raise
  210. raise
  211. finally:
  212. exceptions = my_exceptions = None # type: ignore[assignment]
  213. def _interleave_addrinfos(
  214. addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
  215. ) -> List[AddrInfoType]:
  216. """Interleave list of addrinfo tuples by family."""
  217. # Group addresses by family
  218. addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
  219. collections.OrderedDict()
  220. )
  221. for addr in addrinfos:
  222. family = addr[0]
  223. if family not in addrinfos_by_family:
  224. addrinfos_by_family[family] = []
  225. addrinfos_by_family[family].append(addr)
  226. addrinfos_lists = list(addrinfos_by_family.values())
  227. reordered: List[AddrInfoType] = []
  228. if first_address_family_count > 1:
  229. reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
  230. del addrinfos_lists[0][: first_address_family_count - 1]
  231. reordered.extend(
  232. a
  233. for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
  234. if a is not None
  235. )
  236. return reordered