123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import asyncio
- import socket
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
- from .abc import AbstractResolver, ResolveResult
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")

# aiodns is optional. When it is missing, AsyncResolver refuses to construct
# and DefaultResolver falls back to the thread-pool implementation.
try:
    import aiodns
except ImportError:  # pragma: no cover
    aiodns = None  # type: ignore[assignment]
    aiodns_default = False
else:
    # The aiodns resolver becomes the default only when its DNSResolver
    # exposes getaddrinfo().
    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")

# Flags stored on every ResolveResult: the host/port in a result are numeric.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# Flags passed to getnameinfo() so it returns a numeric host and service.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV

# AI_ADDRCONFIG, restricted by AI_MASK on platforms that define that mask.
_AI_ADDRCONFIG = (
    socket.AI_ADDRCONFIG & socket.AI_MASK
    if hasattr(socket, "AI_MASK")
    else socket.AI_ADDRCONFIG
)
class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Without an explicit loop, bind to the running one.
        # asyncio.get_running_loop() raises RuntimeError if none is running.
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* into a list of numeric ``ResolveResult`` entries.

        Calls ``loop.getaddrinfo()`` with ``SOCK_STREAM`` and ``_AI_ADDRCONFIG``.
        IPv6 entries whose sockaddr has fewer than three items are skipped.
        IPv6 entries with a non-zero scope id go through
        ``loop.getnameinfo()``; every other entry uses the sockaddr's
        host/port directly.

        :param host: name or address to resolve.
        :param port: port passed to getaddrinfo().
        :param family: address family requested from getaddrinfo().
        :returns: one ResolveResult per usable getaddrinfo() entry.
        :raises: whatever ``loop.getaddrinfo()`` / ``loop.getnameinfo()`` raise.
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        for addr_family, _, proto, _, address in infos:
            if addr_family == socket.AF_INET6:
                if len(address) < 3:
                    # IPv6 is not supported by Python build,
                    # or IPv6 is not enabled in the host
                    continue
                # The guard above only guarantees three items, but the scope
                # id is the fourth, so check the length before indexing it.
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    resolved_host, service = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    resolved_port = int(service)
                else:
                    resolved_host, resolved_port = address[:2]
            else:  # IPv4
                assert addr_family == socket.AF_INET
                resolved_host, resolved_port = address  # type: ignore[misc]

            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=resolved_port,
                    family=addr_family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """Release resources; this resolver holds none of its own."""
        pass
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups"""

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the underlying ``aiodns.DNSResolver``.

        :param loop: accepted for signature compatibility; not used here.
        :param args: forwarded to ``aiodns.DNSResolver``.
        :param kwargs: forwarded to ``aiodns.DNSResolver``.
        :raises RuntimeError: if aiodns is not installed.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._resolver = aiodns.DNSResolver(*args, **kwargs)

        # resolve() is built on DNSResolver.getaddrinfo()/getnameinfo(), so
        # probe for that API. Probing for gethostbyname() (which older aiodns
        # releases also have) would leave resolve() raising AttributeError.
        # Without getaddrinfo(), fall back to the query()-based lookup.
        if not hasattr(self._resolver, "getaddrinfo"):
            self.resolve = self._resolve_with_query  # type: ignore

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* through ``aiodns`` getaddrinfo().

        IPv6 nodes with a non-zero scope id have their host rendered by
        ``getnameinfo()``. All other nodes use the decoded address directly.

        :returns: one ResolveResult per node in the getaddrinfo() response.
        :raises OSError: on an aiodns DNSError, or when no addresses come back.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            # Index 1 holds the message; guard against shorter args tuples
            # instead of raising IndexError while handling the DNS error.
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            addr_family = node.family
            # The port sits at index 1 for both IPv4 and IPv6 node addresses.
            resolved_port = address[1]
            if addr_family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
            else:  # IPv4
                assert addr_family == socket.AF_INET
                resolved_host = address[0].decode("ascii")

            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=resolved_port,
                    family=addr_family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Fallback lookup via ``DNSResolver.query()`` (A or AAAA records).

        Used when the aiodns resolver lacks getaddrinfo().

        :raises OSError: on an aiodns DNSError, or when no records come back.
        """
        qtype = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            # Same guard as in resolve(): args may not carry a message.
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        # The returned host and port are numeric, so advertise the same flags
        # as the getaddrinfo()-based resolvers.
        hosts = [
            {
                "hostname": host,
                "host": rr.host,
                "port": port,
                "family": family,
                "proto": 0,
                "flags": _NUMERIC_SOCKET_FLAGS,
            }
            for rr in resp
        ]

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Cancel outstanding lookups on the underlying aiodns resolver."""
        self._resolver.cancel()
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]

# Pick the aiodns-backed resolver only when aiodns imported successfully and
# aiodns_default is true; otherwise use the thread-pool resolver.
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver
|