resolver.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import asyncio
  2. import socket
  3. from typing import Any, Dict, List, Optional, Tuple, Type, Union
  4. from .abc import AbstractResolver, ResolveResult
  5. __all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
  6. try:
  7. import aiodns
  8. aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
  9. except ImportError: # pragma: no cover
  10. aiodns = None # type: ignore[assignment]
  11. aiodns_default = False
  12. _NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
  13. _NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
  14. _AI_ADDRCONFIG = socket.AI_ADDRCONFIG
  15. if hasattr(socket, "AI_MASK"):
  16. _AI_ADDRCONFIG &= socket.AI_MASK
  17. class ThreadedResolver(AbstractResolver):
  18. """Threaded resolver.
  19. Uses an Executor for synchronous getaddrinfo() calls.
  20. concurrent.futures.ThreadPoolExecutor is used by default.
  21. """
  22. def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  23. self._loop = loop or asyncio.get_running_loop()
  24. async def resolve(
  25. self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
  26. ) -> List[ResolveResult]:
  27. infos = await self._loop.getaddrinfo(
  28. host,
  29. port,
  30. type=socket.SOCK_STREAM,
  31. family=family,
  32. flags=_AI_ADDRCONFIG,
  33. )
  34. hosts: List[ResolveResult] = []
  35. for family, _, proto, _, address in infos:
  36. if family == socket.AF_INET6:
  37. if len(address) < 3:
  38. # IPv6 is not supported by Python build,
  39. # or IPv6 is not enabled in the host
  40. continue
  41. if address[3]:
  42. # This is essential for link-local IPv6 addresses.
  43. # LL IPv6 is a VERY rare case. Strictly speaking, we should use
  44. # getnameinfo() unconditionally, but performance makes sense.
  45. resolved_host, _port = await self._loop.getnameinfo(
  46. address, _NAME_SOCKET_FLAGS
  47. )
  48. port = int(_port)
  49. else:
  50. resolved_host, port = address[:2]
  51. else: # IPv4
  52. assert family == socket.AF_INET
  53. resolved_host, port = address # type: ignore[misc]
  54. hosts.append(
  55. ResolveResult(
  56. hostname=host,
  57. host=resolved_host,
  58. port=port,
  59. family=family,
  60. proto=proto,
  61. flags=_NUMERIC_SOCKET_FLAGS,
  62. )
  63. )
  64. return hosts
  65. async def close(self) -> None:
  66. pass
  67. class AsyncResolver(AbstractResolver):
  68. """Use the `aiodns` package to make asynchronous DNS lookups"""
  69. def __init__(
  70. self,
  71. loop: Optional[asyncio.AbstractEventLoop] = None,
  72. *args: Any,
  73. **kwargs: Any,
  74. ) -> None:
  75. if aiodns is None:
  76. raise RuntimeError("Resolver requires aiodns library")
  77. self._resolver = aiodns.DNSResolver(*args, **kwargs)
  78. if not hasattr(self._resolver, "gethostbyname"):
  79. # aiodns 1.1 is not available, fallback to DNSResolver.query
  80. self.resolve = self._resolve_with_query # type: ignore
  81. async def resolve(
  82. self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
  83. ) -> List[ResolveResult]:
  84. try:
  85. resp = await self._resolver.getaddrinfo(
  86. host,
  87. port=port,
  88. type=socket.SOCK_STREAM,
  89. family=family,
  90. flags=_AI_ADDRCONFIG,
  91. )
  92. except aiodns.error.DNSError as exc:
  93. msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
  94. raise OSError(None, msg) from exc
  95. hosts: List[ResolveResult] = []
  96. for node in resp.nodes:
  97. address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
  98. family = node.family
  99. if family == socket.AF_INET6:
  100. if len(address) > 3 and address[3]:
  101. # This is essential for link-local IPv6 addresses.
  102. # LL IPv6 is a VERY rare case. Strictly speaking, we should use
  103. # getnameinfo() unconditionally, but performance makes sense.
  104. result = await self._resolver.getnameinfo(
  105. (address[0].decode("ascii"), *address[1:]),
  106. _NAME_SOCKET_FLAGS,
  107. )
  108. resolved_host = result.node
  109. else:
  110. resolved_host = address[0].decode("ascii")
  111. port = address[1]
  112. else: # IPv4
  113. assert family == socket.AF_INET
  114. resolved_host = address[0].decode("ascii")
  115. port = address[1]
  116. hosts.append(
  117. ResolveResult(
  118. hostname=host,
  119. host=resolved_host,
  120. port=port,
  121. family=family,
  122. proto=0,
  123. flags=_NUMERIC_SOCKET_FLAGS,
  124. )
  125. )
  126. if not hosts:
  127. raise OSError(None, "DNS lookup failed")
  128. return hosts
  129. async def _resolve_with_query(
  130. self, host: str, port: int = 0, family: int = socket.AF_INET
  131. ) -> List[Dict[str, Any]]:
  132. if family == socket.AF_INET6:
  133. qtype = "AAAA"
  134. else:
  135. qtype = "A"
  136. try:
  137. resp = await self._resolver.query(host, qtype)
  138. except aiodns.error.DNSError as exc:
  139. msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
  140. raise OSError(None, msg) from exc
  141. hosts = []
  142. for rr in resp:
  143. hosts.append(
  144. {
  145. "hostname": host,
  146. "host": rr.host,
  147. "port": port,
  148. "family": family,
  149. "proto": 0,
  150. "flags": socket.AI_NUMERICHOST,
  151. }
  152. )
  153. if not hosts:
  154. raise OSError(None, "DNS lookup failed")
  155. return hosts
  156. async def close(self) -> None:
  157. self._resolver.cancel()
  158. _DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
  159. DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver