_ddr.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. #
  3. # Support for Discovery of Designated Resolvers
  4. import socket
  5. import time
  6. from urllib.parse import urlparse
  7. import dns.asyncbackend
  8. import dns.inet
  9. import dns.name
  10. import dns.nameserver
  11. import dns.query
  12. import dns.rdtypes.svcbbase
# The special name of the local resolver when using DDR (Discovery of
# Designated Resolvers).  It is parsed once at import time into a dns.name.Name
# from the text "_dns.resolver.arpa".
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
  15. #
  16. # Processing is split up into I/O independent and I/O dependent parts to
  17. # make supporting sync and async versions easy.
  18. #
  19. class _SVCBInfo:
  20. def __init__(self, bootstrap_address, port, hostname, nameservers):
  21. self.bootstrap_address = bootstrap_address
  22. self.port = port
  23. self.hostname = hostname
  24. self.nameservers = nameservers
  25. def ddr_check_certificate(self, cert):
  26. """Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
  27. for name, value in cert["subjectAltName"]:
  28. if name == "IP Address" and value == self.bootstrap_address:
  29. return True
  30. return False
  31. def make_tls_context(self):
  32. ssl = dns.query.ssl
  33. ctx = ssl.create_default_context()
  34. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  35. return ctx
  36. def ddr_tls_check_sync(self, lifetime):
  37. ctx = self.make_tls_context()
  38. expiration = time.time() + lifetime
  39. with socket.create_connection(
  40. (self.bootstrap_address, self.port), lifetime
  41. ) as s:
  42. with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
  43. ts.settimeout(dns.query._remaining(expiration))
  44. ts.do_handshake()
  45. cert = ts.getpeercert()
  46. return self.ddr_check_certificate(cert)
  47. async def ddr_tls_check_async(self, lifetime, backend=None):
  48. if backend is None:
  49. backend = dns.asyncbackend.get_default_backend()
  50. ctx = self.make_tls_context()
  51. expiration = time.time() + lifetime
  52. async with await backend.make_socket(
  53. dns.inet.af_for_address(self.bootstrap_address),
  54. socket.SOCK_STREAM,
  55. 0,
  56. None,
  57. (self.bootstrap_address, self.port),
  58. lifetime,
  59. ctx,
  60. self.hostname,
  61. ) as ts:
  62. cert = await ts.getpeercert(dns.query._remaining(expiration))
  63. return self.ddr_check_certificate(cert)
  64. def _extract_nameservers_from_svcb(answer):
  65. bootstrap_address = answer.nameserver
  66. if not dns.inet.is_address(bootstrap_address):
  67. return []
  68. infos = []
  69. for rr in answer.rrset.processing_order():
  70. nameservers = []
  71. param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
  72. if param is None:
  73. continue
  74. alpns = set(param.ids)
  75. host = rr.target.to_text(omit_final_dot=True)
  76. port = None
  77. param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
  78. if param is not None:
  79. port = param.port
  80. # For now we ignore address hints and address resolution and always use the
  81. # bootstrap address
  82. if b"h2" in alpns:
  83. param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
  84. if param is None or not param.value.endswith(b"{?dns}"):
  85. continue
  86. path = param.value[:-6].decode()
  87. if not path.startswith("/"):
  88. path = "/" + path
  89. if port is None:
  90. port = 443
  91. url = f"https://{host}:{port}{path}"
  92. # check the URL
  93. try:
  94. urlparse(url)
  95. nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
  96. except Exception:
  97. # continue processing other ALPN types
  98. pass
  99. if b"dot" in alpns:
  100. if port is None:
  101. port = 853
  102. nameservers.append(
  103. dns.nameserver.DoTNameserver(bootstrap_address, port, host)
  104. )
  105. if b"doq" in alpns:
  106. if port is None:
  107. port = 853
  108. nameservers.append(
  109. dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
  110. )
  111. if len(nameservers) > 0:
  112. infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
  113. return infos
  114. def _get_nameservers_sync(answer, lifetime):
  115. """Return a list of TLS-validated resolver nameservers extracted from an SVCB
  116. answer."""
  117. nameservers = []
  118. infos = _extract_nameservers_from_svcb(answer)
  119. for info in infos:
  120. try:
  121. if info.ddr_tls_check_sync(lifetime):
  122. nameservers.extend(info.nameservers)
  123. except Exception:
  124. pass
  125. return nameservers
  126. async def _get_nameservers_async(answer, lifetime):
  127. """Return a list of TLS-validated resolver nameservers extracted from an SVCB
  128. answer."""
  129. nameservers = []
  130. infos = _extract_nameservers_from_svcb(answer)
  131. for info in infos:
  132. try:
  133. if await info.ddr_tls_check_async(lifetime):
  134. nameservers.extend(info.nameservers)
  135. except Exception:
  136. pass
  137. return nameservers