123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363 |
- from typing import Optional, Union
- from urllib.parse import urlparse
- import dns.asyncbackend
- import dns.asyncquery
- import dns.inet
- import dns.message
- import dns.query
class Nameserver:
    """Base interface for a server the resolver can send a query to.

    Every method here is a placeholder that raises ``NotImplementedError``;
    concrete transports (plain DNS, DoH, DoT, DoQ) override them.
    """

    def __init__(self):
        """Initialize the base; it holds no state of its own."""

    def __str__(self):
        raise NotImplementedError

    def kind(self) -> str:
        """Return a short label naming the transport."""
        raise NotImplementedError

    def is_always_max_size(self) -> bool:
        """Report whether every query to this server is sent as max-size."""
        raise NotImplementedError

    def answer_nameserver(self) -> str:
        """Return the identifier to record as the answering nameserver."""
        raise NotImplementedError

    def answer_port(self) -> int:
        """Return the port to record as the answering port."""
        raise NotImplementedError

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* synchronously and return the response message.

        Args:
            request: the query to send.
            timeout: seconds to wait for the exchange.
            source: optional local address to send from.
            source_port: local port to send from.
            max_size: request the transport that permits the largest reply
                (for example, the Do53 subclass switches to TCP).
            one_rr_per_rrset: forwarded to the response parser.
            ignore_trailing: forwarded to the response parser.
        """
        raise NotImplementedError

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Asynchronous counterpart of :meth:`query`, run on *backend*."""
        raise NotImplementedError
class AddressAndPortNameserver(Nameserver):
    """Shared base for nameservers addressed by an ``address`` and ``port``.

    Subclasses still have to provide :meth:`kind` and the query methods.
    """

    def __init__(self, address: str, port: int):
        super().__init__()
        self.address = address
        self.port = port

    def kind(self) -> str:
        raise NotImplementedError

    def __str__(self):
        # Rendered as "<kind>:<address>@<port>".
        return f"{self.kind()}:{self.address}@{self.port}"

    def is_always_max_size(self) -> bool:
        """Return False: size handling is chosen per query."""
        return False

    def answer_port(self) -> int:
        """Return the configured port."""
        return self.port

    def answer_nameserver(self) -> str:
        """Return the configured address."""
        return self.address
class Do53Nameserver(AddressAndPortNameserver):
    """Plain DNS nameserver: UDP by default, TCP when ``max_size`` is set."""

    def __init__(self, address: str, port: int = 53):
        super().__init__(address, port)

    def kind(self):
        return "Do53"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* synchronously over UDP, or over TCP if *max_size*.

        The UDP path passes ``raise_on_truncation=True``, ``ignore_errors=True``
        and ``ignore_unexpected=True`` to ``dns.query.udp``.
        """
        if not max_size:
            return dns.query.udp(
                request,
                self.address,
                timeout=timeout,
                port=self.port,
                source=source,
                source_port=source_port,
                raise_on_truncation=True,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
                ignore_errors=True,
                ignore_unexpected=True,
            )
        return dns.query.tcp(
            request,
            self.address,
            timeout=timeout,
            port=self.port,
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Asynchronous twin of :meth:`query`; both paths run on *backend*."""
        if not max_size:
            return await dns.asyncquery.udp(
                request,
                self.address,
                timeout=timeout,
                port=self.port,
                source=source,
                source_port=source_port,
                raise_on_truncation=True,
                backend=backend,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
                ignore_errors=True,
                ignore_unexpected=True,
            )
        return await dns.asyncquery.tcp(
            request,
            self.address,
            timeout=timeout,
            port=self.port,
            source=source,
            source_port=source_port,
            backend=backend,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )
class DoHNameserver(Nameserver):
    """DNS-over-HTTPS nameserver identified by its URL.

    Args:
        url: endpoint URL handed to the HTTPS query functions.
        bootstrap_address: optional address forwarded as ``bootstrap_address``.
        verify: forwarded unchanged as ``verify``.
        want_get: when true, queries are sent with ``post=False``.
        http_version: forwarded unchanged as ``http_version``.
    """

    def __init__(
        self,
        url: str,
        bootstrap_address: Optional[str] = None,
        verify: Union[bool, str] = True,
        want_get: bool = False,
        http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
    ):
        super().__init__()
        self.url = url
        self.bootstrap_address = bootstrap_address
        self.verify = verify
        self.want_get = want_get
        self.http_version = http_version

    def kind(self):
        return "DoH"

    def is_always_max_size(self) -> bool:
        """Return True unconditionally for this transport."""
        return True

    def __str__(self):
        return self.url

    def answer_nameserver(self) -> str:
        """Report the URL as the answering nameserver."""
        return self.url

    def answer_port(self) -> int:
        """Return the URL's explicit port, or 443 when it has none."""
        explicit_port = urlparse(self.url).port
        # Compare against None (not truthiness) so an explicit port 0 is kept.
        return 443 if explicit_port is None else explicit_port

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* via ``dns.query.https``; *max_size* is unused."""
        use_post = not self.want_get
        return dns.query.https(
            request,
            self.url,
            timeout=timeout,
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            bootstrap_address=self.bootstrap_address,
            verify=self.verify,
            post=use_post,
            http_version=self.http_version,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* via ``dns.asyncquery.https``.

        *max_size* and *backend* are not used by this transport.
        """
        use_post = not self.want_get
        return await dns.asyncquery.https(
            request,
            self.url,
            timeout=timeout,
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            bootstrap_address=self.bootstrap_address,
            verify=self.verify,
            post=use_post,
            http_version=self.http_version,
        )
class DoTNameserver(AddressAndPortNameserver):
    """DNS-over-TLS nameserver.

    Args:
        address: destination address passed to the TLS query functions.
        port: destination port; defaults to 853.
        hostname: forwarded as ``server_hostname`` to the TLS query functions.
        verify: forwarded unchanged as ``verify`` to the TLS query functions.
    """

    def __init__(
        self,
        address: str,
        port: int = 853,
        hostname: Optional[str] = None,
        verify: Union[bool, str] = True,
    ):
        super().__init__(address, port)
        self.hostname = hostname
        self.verify = verify

    def kind(self):
        return "DoT"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* synchronously via ``dns.query.tls``.

        *source* and *source_port* select the local endpoint, as they do for
        Do53. *max_size* is unused by this transport.
        """
        return dns.query.tls(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            # Previously dropped: honor the caller's requested local endpoint.
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            server_hostname=self.hostname,
            verify=self.verify,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* via ``dns.asyncquery.tls`` on *backend*.

        *max_size* is unused by this transport.
        """
        return await dns.asyncquery.tls(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            # Previously dropped: honor the local endpoint and run on the
            # caller's backend, matching Do53Nameserver.async_query.
            source=source,
            source_port=source_port,
            backend=backend,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            server_hostname=self.hostname,
            verify=self.verify,
        )
class DoQNameserver(AddressAndPortNameserver):
    """DNS-over-QUIC nameserver.

    Args:
        address: destination address passed to the QUIC query functions.
        port: destination port; defaults to 853.
        verify: forwarded unchanged as ``verify`` to the QUIC query functions.
        server_hostname: forwarded unchanged as ``server_hostname``.
    """

    def __init__(
        self,
        address: str,
        port: int = 853,
        verify: Union[bool, str] = True,
        server_hostname: Optional[str] = None,
    ):
        super().__init__(address, port)
        self.verify = verify
        self.server_hostname = server_hostname

    def kind(self):
        return "DoQ"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* synchronously via ``dns.query.quic``.

        *source* and *source_port* select the local endpoint, as they do for
        Do53. *max_size* is unused by this transport.
        """
        return dns.query.quic(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            # Previously dropped: honor the caller's requested local endpoint.
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            verify=self.verify,
            server_hostname=self.server_hostname,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Send *request* via ``dns.asyncquery.quic`` on *backend*.

        *max_size* is unused by this transport.
        """
        return await dns.asyncquery.quic(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            # Previously dropped: honor the local endpoint and run on the
            # caller's backend, matching Do53Nameserver.async_query.
            source=source,
            source_port=source_port,
            backend=backend,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            verify=self.verify,
            server_hostname=self.server_hostname,
        )
|