query.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. # Copyright (C) 2003-2017 Nominum, Inc.
  3. #
  4. # Permission to use, copy, modify, and distribute this software and its
  5. # documentation for any purpose with or without fee is hereby granted,
  6. # provided that the above copyright notice and this permission notice
  7. # appear in all copies.
  8. #
  9. # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
  10. # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  11. # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
  12. # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  13. # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  14. # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
  15. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  16. """Talk to a DNS server."""
  17. import base64
  18. import contextlib
  19. import enum
  20. import errno
  21. import os
  22. import os.path
  23. import random
  24. import selectors
  25. import socket
  26. import struct
  27. import time
  28. import urllib.parse
  29. from typing import Any, Dict, Optional, Tuple, Union, cast
  30. import dns._features
  31. import dns.exception
  32. import dns.inet
  33. import dns.message
  34. import dns.name
  35. import dns.quic
  36. import dns.rcode
  37. import dns.rdataclass
  38. import dns.rdatatype
  39. import dns.serial
  40. import dns.transaction
  41. import dns.tsig
  42. import dns.xfr
  43. def _remaining(expiration):
  44. if expiration is None:
  45. return None
  46. timeout = expiration - time.time()
  47. if timeout <= 0.0:
  48. raise dns.exception.Timeout
  49. return timeout
  50. def _expiration_for_this_attempt(timeout, expiration):
  51. if expiration is None:
  52. return None
  53. return min(time.time() + timeout, expiration)
  54. _have_httpx = dns._features.have("doh")
  55. if _have_httpx:
  56. import httpcore._backends.sync
  57. import httpx
  58. _CoreNetworkBackend = httpcore.NetworkBackend
  59. _CoreSyncStream = httpcore._backends.sync.SyncStream
  60. class _NetworkBackend(_CoreNetworkBackend):
  61. def __init__(self, resolver, local_port, bootstrap_address, family):
  62. super().__init__()
  63. self._local_port = local_port
  64. self._resolver = resolver
  65. self._bootstrap_address = bootstrap_address
  66. self._family = family
  67. def connect_tcp(
  68. self, host, port, timeout, local_address, socket_options=None
  69. ): # pylint: disable=signature-differs
  70. addresses = []
  71. _, expiration = _compute_times(timeout)
  72. if dns.inet.is_address(host):
  73. addresses.append(host)
  74. elif self._bootstrap_address is not None:
  75. addresses.append(self._bootstrap_address)
  76. else:
  77. timeout = _remaining(expiration)
  78. family = self._family
  79. if local_address:
  80. family = dns.inet.af_for_address(local_address)
  81. answers = self._resolver.resolve_name(
  82. host, family=family, lifetime=timeout
  83. )
  84. addresses = answers.addresses()
  85. for address in addresses:
  86. af = dns.inet.af_for_address(address)
  87. if local_address is not None or self._local_port != 0:
  88. source = dns.inet.low_level_address_tuple(
  89. (local_address, self._local_port), af
  90. )
  91. else:
  92. source = None
  93. sock = _make_socket(af, socket.SOCK_STREAM, source)
  94. attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
  95. try:
  96. _connect(
  97. sock,
  98. dns.inet.low_level_address_tuple((address, port), af),
  99. attempt_expiration,
  100. )
  101. return _CoreSyncStream(sock)
  102. except Exception:
  103. pass
  104. raise httpcore.ConnectError
  105. def connect_unix_socket(
  106. self, path, timeout, socket_options=None
  107. ): # pylint: disable=signature-differs
  108. raise NotImplementedError
  109. class _HTTPTransport(httpx.HTTPTransport):
  110. def __init__(
  111. self,
  112. *args,
  113. local_port=0,
  114. bootstrap_address=None,
  115. resolver=None,
  116. family=socket.AF_UNSPEC,
  117. **kwargs,
  118. ):
  119. if resolver is None and bootstrap_address is None:
  120. # pylint: disable=import-outside-toplevel,redefined-outer-name
  121. import dns.resolver
  122. resolver = dns.resolver.Resolver()
  123. super().__init__(*args, **kwargs)
  124. self._pool._network_backend = _NetworkBackend(
  125. resolver, local_port, bootstrap_address, family
  126. )
  127. else:
  128. class _HTTPTransport: # type: ignore
  129. def connect_tcp(self, host, port, timeout, local_address):
  130. raise NotImplementedError
  131. have_doh = _have_httpx
try:
    import ssl
except ImportError:  # pragma: no cover
    # Python was built without TLS support.  Install a minimal stand-in named
    # ``ssl`` so that module-level references such as ``ssl.SSLSocket`` (used
    # in an isinstance() check in _wait_for) and constants like
    # ``ssl.CERT_NONE`` still resolve.  Any attempt to actually obtain a TLS
    # context fails via create_default_context().
    # NOTE(review): the real module names its want-read/want-write errors
    # SSLWantReadError / SSLWantWriteError; confirm whether code elsewhere in
    # this file expects those names on the stand-in.

    class ssl:  # type: ignore
        CERT_NONE = 0

        class WantReadException(Exception):
            pass

        class WantWriteException(Exception):
            pass

        class SSLContext:
            pass

        class SSLSocket:
            pass

        @classmethod
        def create_default_context(cls, *args, **kwargs):
            # Without the ssl module there is no way to do TLS; fail loudly.
            raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
# Function used to create a socket.  _make_socket() calls it as
# ``socket_factory(af, type)``.  It can be overridden if needed in special
# situations, as long as the replacement accepts the same arguments and
# returns a socket-like object.
socket_factory = socket.socket


# NOTE: the docstrings of the exception classes below are left exactly as-is;
# dnspython's DNSException presumably uses the class docstring as the default
# message when an exception is raised without arguments (e.g. ``raise NoDOH``),
# so editing them would change runtime error text.  Confirm before changing.


# Raised by _matches_destination() when a reply's source does not match the
# query's destination and unexpected sources are not being ignored.
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""


# Raised when a received message is not a response to the query that was sent
# (i.e. ``q.is_response(r)`` is false).
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""


# Also raised, with an explicit message, by _http3() when DNS-over-HTTP/3
# support is unavailable.
class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the httpx module is not
    available."""


class NoDOQ(dns.exception.DNSException):
    """DNS over QUIC (DOQ) was requested but the aioquic module is not
    available."""


# Alias kept for backwards compatibility; the class itself lives in dns.xfr.
TransferError = dns.xfr.TransferError
  163. def _compute_times(timeout):
  164. now = time.time()
  165. if timeout is None:
  166. return (now, None)
  167. else:
  168. return (now, now + timeout)
  169. def _wait_for(fd, readable, writable, _, expiration):
  170. # Use the selected selector class to wait for any of the specified
  171. # events. An "expiration" absolute time is converted into a relative
  172. # timeout.
  173. #
  174. # The unused parameter is 'error', which is always set when
  175. # selecting for read or write, and we have no error-only selects.
  176. if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
  177. return True
  178. sel = selectors.DefaultSelector()
  179. events = 0
  180. if readable:
  181. events |= selectors.EVENT_READ
  182. if writable:
  183. events |= selectors.EVENT_WRITE
  184. if events:
  185. sel.register(fd, events)
  186. if expiration is None:
  187. timeout = None
  188. else:
  189. timeout = expiration - time.time()
  190. if timeout <= 0.0:
  191. raise dns.exception.Timeout
  192. if not sel.select(timeout):
  193. raise dns.exception.Timeout
  194. def _wait_for_readable(s, expiration):
  195. _wait_for(s, True, False, True, expiration)
  196. def _wait_for_writable(s, expiration):
  197. _wait_for(s, False, True, True, expiration)
  198. def _addresses_equal(af, a1, a2):
  199. # Convert the first value of the tuple, which is a textual format
  200. # address into binary form, so that we are not confused by different
  201. # textual representations of the same address
  202. try:
  203. n1 = dns.inet.inet_pton(af, a1[0])
  204. n2 = dns.inet.inet_pton(af, a2[0])
  205. except dns.exception.SyntaxError:
  206. return False
  207. return n1 == n2 and a1[1:] == a2[1:]
  208. def _matches_destination(af, from_address, destination, ignore_unexpected):
  209. # Check that from_address is appropriate for a response to a query
  210. # sent to destination.
  211. if not destination:
  212. return True
  213. if _addresses_equal(af, from_address, destination) or (
  214. dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
  215. ):
  216. return True
  217. elif ignore_unexpected:
  218. return False
  219. raise UnexpectedSource(
  220. f"got a response from {from_address} instead of " f"{destination}"
  221. )
  222. def _destination_and_source(
  223. where, port, source, source_port, where_must_be_address=True
  224. ):
  225. # Apply defaults and compute destination and source tuples
  226. # suitable for use in connect(), sendto(), or bind().
  227. af = None
  228. destination = None
  229. try:
  230. af = dns.inet.af_for_address(where)
  231. destination = where
  232. except Exception:
  233. if where_must_be_address:
  234. raise
  235. # URLs are ok so eat the exception
  236. if source:
  237. saf = dns.inet.af_for_address(source)
  238. if af:
  239. # We know the destination af, so source had better agree!
  240. if saf != af:
  241. raise ValueError(
  242. "different address families for source and destination"
  243. )
  244. else:
  245. # We didn't know the destination af, but we know the source,
  246. # so that's our af.
  247. af = saf
  248. if source_port and not source:
  249. # Caller has specified a source_port but not an address, so we
  250. # need to return a source, and we need to use the appropriate
  251. # wildcard address as the address.
  252. try:
  253. source = dns.inet.any_for_af(af)
  254. except Exception:
  255. # we catch this and raise ValueError for backwards compatibility
  256. raise ValueError("source_port specified but address family is unknown")
  257. # Convert high-level (address, port) tuples into low-level address
  258. # tuples.
  259. if destination:
  260. destination = dns.inet.low_level_address_tuple((destination, port), af)
  261. if source:
  262. source = dns.inet.low_level_address_tuple((source, source_port), af)
  263. return (af, destination, source)
  264. def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
  265. s = socket_factory(af, type)
  266. try:
  267. s.setblocking(False)
  268. if source is not None:
  269. s.bind(source)
  270. if ssl_context:
  271. # LGTM gets a false positive here, as our default context is OK
  272. return ssl_context.wrap_socket(
  273. s,
  274. do_handshake_on_connect=False, # lgtm[py/insecure-protocol]
  275. server_hostname=server_hostname,
  276. )
  277. else:
  278. return s
  279. except Exception:
  280. s.close()
  281. raise
  282. def _maybe_get_resolver(
  283. resolver: Optional["dns.resolver.Resolver"],
  284. ) -> "dns.resolver.Resolver":
  285. # We need a separate method for this to avoid overriding the global
  286. # variable "dns" with the as-yet undefined local variable "dns"
  287. # in https().
  288. if resolver is None:
  289. # pylint: disable=import-outside-toplevel,redefined-outer-name
  290. import dns.resolver
  291. resolver = dns.resolver.Resolver()
  292. return resolver
  293. class HTTPVersion(enum.IntEnum):
  294. """Which version of HTTP should be used?
  295. DEFAULT will select the first version from the list [2, 1.1, 3] that
  296. is available.
  297. """
  298. DEFAULT = 0
  299. HTTP_1 = 1
  300. H1 = 1
  301. HTTP_2 = 2
  302. H2 = 2
  303. HTTP_3 = 3
  304. H3 = 3
def https(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 443,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    session: Optional[Any] = None,
    path: str = "/dns-query",
    post: bool = True,
    bootstrap_address: Optional[str] = None,
    verify: Union[bool, str] = True,
    resolver: Optional["dns.resolver.Resolver"] = None,
    family: int = socket.AF_UNSPEC,
    http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is
    given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, a ``int``, the port to send the query to. The default is 443.  An explicit
    port in a URL *where* takes precedence.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *session*, an ``httpx.Client``. If provided, the client session to use to send the
    queries.  It is not closed by this function.  Note that when a session is given,
    the transport-level options (*source*, *source_port*, *bootstrap_address*,
    *resolver*, *family*, *verify*) are not applied by this function; they are only
    used to build a new transport when no session is supplied.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used; otherwise
    GET is used with the query base64url-encoded in the ``dns`` URL parameter.

    *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.

    *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
    of the server is done using the default CA bundle; if ``False``, then no
    verification is done; if a `str` then it specifies the path to a certificate file or
    directory which will be used for verification.

    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
    resolution of hostnames in URLs. If not specified, a new resolver with a default
    configuration will be used; note this is *not* the default resolver as that resolver
    might have been configured to use DoH causing a chicken-and-egg problem. This
    parameter only has an effect if the HTTP library is httpx.

    *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
    and AAAA records will be retrieved.

    *http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
    ``HTTP_3`` (or ``DEFAULT`` when httpx is unavailable) is sent via HTTP/3 instead of
    httpx.

    Returns a ``dns.message.Message``.

    Raises ``ValueError`` if the URL has no hostname, if *session* is not an
    ``httpx.Client``, or if the server answers with a non-2xx status;
    ``NoDOH`` if the required HTTP support is unavailable; ``BadResponse`` if the
    reply is not a response to *q*.
    """

    (af, _, the_source) = _destination_and_source(
        where, port, source, source_port, False
    )
    # If *where* is a bare IP address, build the URL from it (IPv6 literals
    # need brackets); otherwise *where* is already the full URL.
    if af is not None and dns.inet.is_address(where):
        if af == socket.AF_INET:
            url = f"https://{where}:{port}{path}"
        elif af == socket.AF_INET6:
            url = f"https://[{where}]:{port}{path}"
    else:
        url = where

    extensions = {}
    if bootstrap_address is None:
        # pylint: disable=possibly-used-before-assignment
        parsed = urllib.parse.urlparse(url)
        if parsed.hostname is None:
            raise ValueError("no hostname in URL")
        if dns.inet.is_address(parsed.hostname):
            # The URL names a literal IP: connect to it directly and pass it
            # to httpx as the "sni_hostname" request extension.
            bootstrap_address = parsed.hostname
            extensions["sni_hostname"] = parsed.hostname
        if parsed.port is not None:
            port = parsed.port

    if http_version == HTTPVersion.H3 or (
        http_version == HTTPVersion.DEFAULT and not have_doh
    ):
        if bootstrap_address is None:
            # ``parsed`` is bound here: bootstrap_address can only still be
            # None if the block above ran.  Resolve the URL host ourselves
            # and pick one of its addresses at random.
            resolver = _maybe_get_resolver(resolver)
            assert parsed.hostname is not None  # for mypy
            answers = resolver.resolve_name(parsed.hostname, family)
            bootstrap_address = random.choice(list(answers.addresses()))
        return _http3(
            q,
            bootstrap_address,
            url,
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            verify=verify,
            post=post,
        )

    if not have_doh:
        raise NoDOH  # pragma: no cover
    if session and not isinstance(session, httpx.Client):
        raise ValueError("session parameter must be an httpx.Client")

    wire = q.to_wire()
    headers = {"accept": "application/dns-message"}

    # DEFAULT enables both HTTP/1.1 and HTTP/2 on the httpx client.
    h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
    h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)

    # set source port and source address

    if the_source is None:
        local_address = None
        local_port = 0
    else:
        local_address = the_source[0]
        local_port = the_source[1]

    if session:
        # Caller-owned session: wrap it so the ``with`` below does not close it.
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
    else:
        transport = _HTTPTransport(
            local_address=local_address,
            http1=h1,
            http2=h2,
            verify=verify,
            local_port=local_port,
            bootstrap_address=bootstrap_address,
            resolver=resolver,
            family=family,
        )

        cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
    with cm as session:
        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        if post:
            headers.update(
                {
                    "content-type": "application/dns-message",
                    "content-length": str(len(wire)),
                }
            )
            response = session.post(
                url,
                headers=headers,
                content=wire,
                timeout=timeout,
                extensions=extensions,
            )
        else:
            # GET carries the query as unpadded base64url in the "dns" param.
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            twire = wire.decode()  # httpx does a repr() if we give it bytes
            response = session.get(
                url,
                headers=headers,
                timeout=timeout,
                params={"dns": twire},
                extensions=extensions,
            )

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError(
            f"{where} responded with status code {response.status_code}"
            f"\nResponse body: {response.content}"
        )
    r = dns.message.from_wire(
        response.content,
        keyring=q.keyring,
        request_mac=q.request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    # Round-trip time as measured by httpx for this request.
    r.time = response.elapsed.total_seconds()
    if not q.is_response(r):
        raise BadResponse
    return r
  474. def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
  475. if headers is None:
  476. raise KeyError
  477. for header, value in headers:
  478. if header == name:
  479. return value
  480. raise KeyError
  481. def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
  482. value = _find_header(headers, b":status")
  483. if value is None:
  484. raise SyntaxError("no :status header in response")
  485. status = int(value)
  486. if status < 0:
  487. raise SyntaxError("status is negative")
  488. if status < 200 or status > 299:
  489. error = ""
  490. if len(wire) > 0:
  491. try:
  492. error = ": " + wire.decode()
  493. except Exception:
  494. pass
  495. raise ValueError(f"{peer} responded with status code {status}{error}")
  496. def _http3(
  497. q: dns.message.Message,
  498. where: str,
  499. url: str,
  500. timeout: Optional[float] = None,
  501. port: int = 853,
  502. source: Optional[str] = None,
  503. source_port: int = 0,
  504. one_rr_per_rrset: bool = False,
  505. ignore_trailing: bool = False,
  506. verify: Union[bool, str] = True,
  507. hostname: Optional[str] = None,
  508. post: bool = True,
  509. ) -> dns.message.Message:
  510. if not dns.quic.have_quic:
  511. raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
  512. url_parts = urllib.parse.urlparse(url)
  513. hostname = url_parts.hostname
  514. if url_parts.port is not None:
  515. port = url_parts.port
  516. q.id = 0
  517. wire = q.to_wire()
  518. manager = dns.quic.SyncQuicManager(
  519. verify_mode=verify, server_name=hostname, h3=True
  520. )
  521. with manager:
  522. connection = manager.connect(where, port, source, source_port)
  523. (start, expiration) = _compute_times(timeout)
  524. with connection.make_stream(timeout) as stream:
  525. stream.send_h3(url, wire, post)
  526. wire = stream.receive(_remaining(expiration))
  527. _check_status(stream.headers(), where, wire)
  528. finish = time.time()
  529. r = dns.message.from_wire(
  530. wire,
  531. keyring=q.keyring,
  532. request_mac=q.request_mac,
  533. one_rr_per_rrset=one_rr_per_rrset,
  534. ignore_trailing=ignore_trailing,
  535. )
  536. r.time = max(finish - start, 0.0)
  537. if not q.is_response(r):
  538. raise BadResponse
  539. return r
  540. def _udp_recv(sock, max_size, expiration):
  541. """Reads a datagram from the socket.
  542. A Timeout exception will be raised if the operation is not completed
  543. by the expiration time.
  544. """
  545. while True:
  546. try:
  547. return sock.recvfrom(max_size)
  548. except BlockingIOError:
  549. _wait_for_readable(sock, expiration)
  550. def _udp_send(sock, data, destination, expiration):
  551. """Sends the specified datagram to destination over the socket.
  552. A Timeout exception will be raised if the operation is not completed
  553. by the expiration time.
  554. """
  555. while True:
  556. try:
  557. if destination:
  558. return sock.sendto(data, destination)
  559. else:
  560. return sock.send(data)
  561. except BlockingIOError: # pragma: no cover
  562. _wait_for_writable(sock, expiration)
  563. def send_udp(
  564. sock: Any,
  565. what: Union[dns.message.Message, bytes],
  566. destination: Any,
  567. expiration: Optional[float] = None,
  568. ) -> Tuple[int, float]:
  569. """Send a DNS message to the specified UDP socket.
  570. *sock*, a ``socket``.
  571. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  572. *destination*, a destination tuple appropriate for the address family
  573. of the socket, specifying where to send the query.
  574. *expiration*, a ``float`` or ``None``, the absolute time at which
  575. a timeout exception should be raised. If ``None``, no timeout will
  576. occur.
  577. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  578. """
  579. if isinstance(what, dns.message.Message):
  580. what = what.to_wire()
  581. sent_time = time.time()
  582. n = _udp_send(sock, what, destination, expiration)
  583. return (n, sent_time)
  584. def receive_udp(
  585. sock: Any,
  586. destination: Optional[Any] = None,
  587. expiration: Optional[float] = None,
  588. ignore_unexpected: bool = False,
  589. one_rr_per_rrset: bool = False,
  590. keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
  591. request_mac: Optional[bytes] = b"",
  592. ignore_trailing: bool = False,
  593. raise_on_truncation: bool = False,
  594. ignore_errors: bool = False,
  595. query: Optional[dns.message.Message] = None,
  596. ) -> Any:
  597. """Read a DNS message from a UDP socket.
  598. *sock*, a ``socket``.
  599. *destination*, a destination tuple appropriate for the address family
  600. of the socket, specifying where the message is expected to arrive from.
  601. When receiving a response, this would be where the associated query was
  602. sent.
  603. *expiration*, a ``float`` or ``None``, the absolute time at which
  604. a timeout exception should be raised. If ``None``, no timeout will
  605. occur.
  606. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  607. unexpected sources.
  608. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  609. RRset.
  610. *keyring*, a ``dict``, the keyring to use for TSIG.
  611. *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
  612. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  613. junk at end of the received message.
  614. *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
  615. the TC bit is set.
  616. Raises if the message is malformed, if network errors occur, of if
  617. there is a timeout.
  618. If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
  619. tuple of the received message and the received time.
  620. If *destination* is ``None``, returns a
  621. ``(dns.message.Message, float, tuple)``
  622. tuple of the received message, the received time, and the address where
  623. the message arrived from.
  624. *ignore_errors*, a ``bool``. If various format errors or response
  625. mismatches occur, ignore them and keep listening for a valid response.
  626. The default is ``False``.
  627. *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
  628. *ignore_errors* is ``True``, check that the received message is a response
  629. to this query, and if not keep listening for a valid response.
  630. """
  631. wire = b""
  632. while True:
  633. (wire, from_address) = _udp_recv(sock, 65535, expiration)
  634. if not _matches_destination(
  635. sock.family, from_address, destination, ignore_unexpected
  636. ):
  637. continue
  638. received_time = time.time()
  639. try:
  640. r = dns.message.from_wire(
  641. wire,
  642. keyring=keyring,
  643. request_mac=request_mac,
  644. one_rr_per_rrset=one_rr_per_rrset,
  645. ignore_trailing=ignore_trailing,
  646. raise_on_truncation=raise_on_truncation,
  647. )
  648. except dns.message.Truncated as e:
  649. # If we got Truncated and not FORMERR, we at least got the header with TC
  650. # set, and very likely the question section, so we'll re-raise if the
  651. # message seems to be a response as we need to know when truncation happens.
  652. # We need to check that it seems to be a response as we don't want a random
  653. # injected message with TC set to cause us to bail out.
  654. if (
  655. ignore_errors
  656. and query is not None
  657. and not query.is_response(e.message())
  658. ):
  659. continue
  660. else:
  661. raise
  662. except Exception:
  663. if ignore_errors:
  664. continue
  665. else:
  666. raise
  667. if ignore_errors and query is not None and not query.is_response(r):
  668. continue
  669. if destination:
  670. return (r, received_time)
  671. else:
  672. return (r, received_time, from_address)
  673. def udp(
  674. q: dns.message.Message,
  675. where: str,
  676. timeout: Optional[float] = None,
  677. port: int = 53,
  678. source: Optional[str] = None,
  679. source_port: int = 0,
  680. ignore_unexpected: bool = False,
  681. one_rr_per_rrset: bool = False,
  682. ignore_trailing: bool = False,
  683. raise_on_truncation: bool = False,
  684. sock: Optional[Any] = None,
  685. ignore_errors: bool = False,
  686. ) -> dns.message.Message:
  687. """Return the response obtained after sending a query via UDP.
  688. *q*, a ``dns.message.Message``, the query to send
  689. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  690. to send the message.
  691. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  692. query times out. If ``None``, the default, wait forever.
  693. *port*, an ``int``, the port send the message to. The default is 53.
  694. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  695. the source address. The default is the wildcard address.
  696. *source_port*, an ``int``, the port from which to send the message.
  697. The default is 0.
  698. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  699. unexpected sources.
  700. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  701. RRset.
  702. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  703. junk at end of the received message.
  704. *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
  705. the TC bit is set.
  706. *sock*, a ``socket.socket``, or ``None``, the socket to use for the
  707. query. If ``None``, the default, a socket is created. Note that
  708. if a socket is provided, it must be a nonblocking datagram socket,
  709. and the *source* and *source_port* are ignored.
  710. *ignore_errors*, a ``bool``. If various format errors or response
  711. mismatches occur, ignore them and keep listening for a valid response.
  712. The default is ``False``.
  713. Returns a ``dns.message.Message``.
  714. """
  715. wire = q.to_wire()
  716. (af, destination, source) = _destination_and_source(
  717. where, port, source, source_port
  718. )
  719. (begin_time, expiration) = _compute_times(timeout)
  720. if sock:
  721. cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
  722. else:
  723. cm = _make_socket(af, socket.SOCK_DGRAM, source)
  724. with cm as s:
  725. send_udp(s, wire, destination, expiration)
  726. (r, received_time) = receive_udp(
  727. s,
  728. destination,
  729. expiration,
  730. ignore_unexpected,
  731. one_rr_per_rrset,
  732. q.keyring,
  733. q.mac,
  734. ignore_trailing,
  735. raise_on_truncation,
  736. ignore_errors,
  737. q,
  738. )
  739. r.time = received_time - begin_time
  740. # We don't need to check q.is_response() if we are in ignore_errors mode
  741. # as receive_udp() will have checked it.
  742. if not (ignore_errors or q.is_response(r)):
  743. raise BadResponse
  744. return r
  745. assert (
  746. False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
  747. )
  748. def udp_with_fallback(
  749. q: dns.message.Message,
  750. where: str,
  751. timeout: Optional[float] = None,
  752. port: int = 53,
  753. source: Optional[str] = None,
  754. source_port: int = 0,
  755. ignore_unexpected: bool = False,
  756. one_rr_per_rrset: bool = False,
  757. ignore_trailing: bool = False,
  758. udp_sock: Optional[Any] = None,
  759. tcp_sock: Optional[Any] = None,
  760. ignore_errors: bool = False,
  761. ) -> Tuple[dns.message.Message, bool]:
  762. """Return the response to the query, trying UDP first and falling back
  763. to TCP if UDP results in a truncated response.
  764. *q*, a ``dns.message.Message``, the query to send
  765. *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.
  766. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
  767. times out. If ``None``, the default, wait forever.
  768. *port*, an ``int``, the port send the message to. The default is 53.
  769. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
  770. address. The default is the wildcard address.
  771. *source_port*, an ``int``, the port from which to send the message. The default is
  772. 0.
  773. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
  774. sources.
  775. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
  776. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
  777. received message.
  778. *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
  779. If ``None``, the default, a socket is created. Note that if a socket is provided,
  780. it must be a nonblocking datagram socket, and the *source* and *source_port* are
  781. ignored for the UDP query.
  782. *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
  783. TCP query. If ``None``, the default, a socket is created. Note that if a socket is
  784. provided, it must be a nonblocking connected stream socket, and *where*, *source*
  785. and *source_port* are ignored for the TCP query.
  786. *ignore_errors*, a ``bool``. If various format errors or response mismatches occur
  787. while listening for UDP, ignore them and keep listening for a valid response. The
  788. default is ``False``.
  789. Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
  790. TCP was used.
  791. """
  792. try:
  793. response = udp(
  794. q,
  795. where,
  796. timeout,
  797. port,
  798. source,
  799. source_port,
  800. ignore_unexpected,
  801. one_rr_per_rrset,
  802. ignore_trailing,
  803. True,
  804. udp_sock,
  805. ignore_errors,
  806. )
  807. return (response, False)
  808. except dns.message.Truncated:
  809. response = tcp(
  810. q,
  811. where,
  812. timeout,
  813. port,
  814. source,
  815. source_port,
  816. one_rr_per_rrset,
  817. ignore_trailing,
  818. tcp_sock,
  819. )
  820. return (response, True)
  821. def _net_read(sock, count, expiration):
  822. """Read the specified number of bytes from sock. Keep trying until we
  823. either get the desired amount, or we hit EOF.
  824. A Timeout exception will be raised if the operation is not completed
  825. by the expiration time.
  826. """
  827. s = b""
  828. while count > 0:
  829. try:
  830. n = sock.recv(count)
  831. if n == b"":
  832. raise EOFError("EOF")
  833. count -= len(n)
  834. s += n
  835. except (BlockingIOError, ssl.SSLWantReadError):
  836. _wait_for_readable(sock, expiration)
  837. except ssl.SSLWantWriteError: # pragma: no cover
  838. _wait_for_writable(sock, expiration)
  839. return s
  840. def _net_write(sock, data, expiration):
  841. """Write the specified data to the socket.
  842. A Timeout exception will be raised if the operation is not completed
  843. by the expiration time.
  844. """
  845. current = 0
  846. l = len(data)
  847. while current < l:
  848. try:
  849. current += sock.send(data[current:])
  850. except (BlockingIOError, ssl.SSLWantWriteError):
  851. _wait_for_writable(sock, expiration)
  852. except ssl.SSLWantReadError: # pragma: no cover
  853. _wait_for_readable(sock, expiration)
  854. def send_tcp(
  855. sock: Any,
  856. what: Union[dns.message.Message, bytes],
  857. expiration: Optional[float] = None,
  858. ) -> Tuple[int, float]:
  859. """Send a DNS message to the specified TCP socket.
  860. *sock*, a ``socket``.
  861. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  862. *expiration*, a ``float`` or ``None``, the absolute time at which
  863. a timeout exception should be raised. If ``None``, no timeout will
  864. occur.
  865. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  866. """
  867. if isinstance(what, dns.message.Message):
  868. tcpmsg = what.to_wire(prepend_length=True)
  869. else:
  870. # copying the wire into tcpmsg is inefficient, but lets us
  871. # avoid writev() or doing a short write that would get pushed
  872. # onto the net
  873. tcpmsg = len(what).to_bytes(2, "big") + what
  874. sent_time = time.time()
  875. _net_write(sock, tcpmsg, expiration)
  876. return (len(tcpmsg), sent_time)
  877. def receive_tcp(
  878. sock: Any,
  879. expiration: Optional[float] = None,
  880. one_rr_per_rrset: bool = False,
  881. keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
  882. request_mac: Optional[bytes] = b"",
  883. ignore_trailing: bool = False,
  884. ) -> Tuple[dns.message.Message, float]:
  885. """Read a DNS message from a TCP socket.
  886. *sock*, a ``socket``.
  887. *expiration*, a ``float`` or ``None``, the absolute time at which
  888. a timeout exception should be raised. If ``None``, no timeout will
  889. occur.
  890. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  891. RRset.
  892. *keyring*, a ``dict``, the keyring to use for TSIG.
  893. *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
  894. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  895. junk at end of the received message.
  896. Raises if the message is malformed, if network errors occur, of if
  897. there is a timeout.
  898. Returns a ``(dns.message.Message, float)`` tuple of the received message
  899. and the received time.
  900. """
  901. ldata = _net_read(sock, 2, expiration)
  902. (l,) = struct.unpack("!H", ldata)
  903. wire = _net_read(sock, l, expiration)
  904. received_time = time.time()
  905. r = dns.message.from_wire(
  906. wire,
  907. keyring=keyring,
  908. request_mac=request_mac,
  909. one_rr_per_rrset=one_rr_per_rrset,
  910. ignore_trailing=ignore_trailing,
  911. )
  912. return (r, received_time)
  913. def _connect(s, address, expiration):
  914. err = s.connect_ex(address)
  915. if err == 0:
  916. return
  917. if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
  918. _wait_for_writable(s, expiration)
  919. err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
  920. if err != 0:
  921. raise OSError(err, os.strerror(err))
  922. def tcp(
  923. q: dns.message.Message,
  924. where: str,
  925. timeout: Optional[float] = None,
  926. port: int = 53,
  927. source: Optional[str] = None,
  928. source_port: int = 0,
  929. one_rr_per_rrset: bool = False,
  930. ignore_trailing: bool = False,
  931. sock: Optional[Any] = None,
  932. ) -> dns.message.Message:
  933. """Return the response obtained after sending a query via TCP.
  934. *q*, a ``dns.message.Message``, the query to send
  935. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  936. to send the message.
  937. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  938. query times out. If ``None``, the default, wait forever.
  939. *port*, an ``int``, the port send the message to. The default is 53.
  940. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  941. the source address. The default is the wildcard address.
  942. *source_port*, an ``int``, the port from which to send the message.
  943. The default is 0.
  944. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  945. RRset.
  946. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  947. junk at end of the received message.
  948. *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
  949. query. If ``None``, the default, a socket is created. Note that
  950. if a socket is provided, it must be a nonblocking connected stream
  951. socket, and *where*, *port*, *source* and *source_port* are ignored.
  952. Returns a ``dns.message.Message``.
  953. """
  954. wire = q.to_wire()
  955. (begin_time, expiration) = _compute_times(timeout)
  956. if sock:
  957. cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
  958. else:
  959. (af, destination, source) = _destination_and_source(
  960. where, port, source, source_port
  961. )
  962. cm = _make_socket(af, socket.SOCK_STREAM, source)
  963. with cm as s:
  964. if not sock:
  965. # pylint: disable=possibly-used-before-assignment
  966. _connect(s, destination, expiration)
  967. send_tcp(s, wire, expiration)
  968. (r, received_time) = receive_tcp(
  969. s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
  970. )
  971. r.time = received_time - begin_time
  972. if not q.is_response(r):
  973. raise BadResponse
  974. return r
  975. assert (
  976. False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
  977. )
  978. def _tls_handshake(s, expiration):
  979. while True:
  980. try:
  981. s.do_handshake()
  982. return
  983. except ssl.SSLWantReadError:
  984. _wait_for_readable(s, expiration)
  985. except ssl.SSLWantWriteError: # pragma: no cover
  986. _wait_for_writable(s, expiration)
  987. def _make_dot_ssl_context(
  988. server_hostname: Optional[str], verify: Union[bool, str]
  989. ) -> ssl.SSLContext:
  990. cafile: Optional[str] = None
  991. capath: Optional[str] = None
  992. if isinstance(verify, str):
  993. if os.path.isfile(verify):
  994. cafile = verify
  995. elif os.path.isdir(verify):
  996. capath = verify
  997. else:
  998. raise ValueError("invalid verify string")
  999. ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
  1000. ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
  1001. if server_hostname is None:
  1002. ssl_context.check_hostname = False
  1003. ssl_context.set_alpn_protocols(["dot"])
  1004. if verify is False:
  1005. ssl_context.verify_mode = ssl.CERT_NONE
  1006. return ssl_context
  1007. def tls(
  1008. q: dns.message.Message,
  1009. where: str,
  1010. timeout: Optional[float] = None,
  1011. port: int = 853,
  1012. source: Optional[str] = None,
  1013. source_port: int = 0,
  1014. one_rr_per_rrset: bool = False,
  1015. ignore_trailing: bool = False,
  1016. sock: Optional[ssl.SSLSocket] = None,
  1017. ssl_context: Optional[ssl.SSLContext] = None,
  1018. server_hostname: Optional[str] = None,
  1019. verify: Union[bool, str] = True,
  1020. ) -> dns.message.Message:
  1021. """Return the response obtained after sending a query via TLS.
  1022. *q*, a ``dns.message.Message``, the query to send
  1023. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  1024. to send the message.
  1025. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  1026. query times out. If ``None``, the default, wait forever.
  1027. *port*, an ``int``, the port send the message to. The default is 853.
  1028. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  1029. the source address. The default is the wildcard address.
  1030. *source_port*, an ``int``, the port from which to send the message.
  1031. The default is 0.
  1032. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  1033. RRset.
  1034. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  1035. junk at end of the received message.
  1036. *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
  1037. the query. If ``None``, the default, a socket is created. Note
  1038. that if a socket is provided, it must be a nonblocking connected
  1039. SSL stream socket, and *where*, *port*, *source*, *source_port*,
  1040. and *ssl_context* are ignored.
  1041. *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
  1042. a TLS connection. If ``None``, the default, creates one with the default
  1043. configuration.
  1044. *server_hostname*, a ``str`` containing the server's hostname. The
  1045. default is ``None``, which means that no hostname is known, and if an
  1046. SSL context is created, hostname checking will be disabled.
  1047. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
  1048. of the server is done using the default CA bundle; if ``False``, then no
  1049. verification is done; if a `str` then it specifies the path to a certificate file or
  1050. directory which will be used for verification.
  1051. Returns a ``dns.message.Message``.
  1052. """
  1053. if sock:
  1054. #
  1055. # If a socket was provided, there's no special TLS handling needed.
  1056. #
  1057. return tcp(
  1058. q,
  1059. where,
  1060. timeout,
  1061. port,
  1062. source,
  1063. source_port,
  1064. one_rr_per_rrset,
  1065. ignore_trailing,
  1066. sock,
  1067. )
  1068. wire = q.to_wire()
  1069. (begin_time, expiration) = _compute_times(timeout)
  1070. (af, destination, source) = _destination_and_source(
  1071. where, port, source, source_port
  1072. )
  1073. if ssl_context is None and not sock:
  1074. ssl_context = _make_dot_ssl_context(server_hostname, verify)
  1075. with _make_socket(
  1076. af,
  1077. socket.SOCK_STREAM,
  1078. source,
  1079. ssl_context=ssl_context,
  1080. server_hostname=server_hostname,
  1081. ) as s:
  1082. _connect(s, destination, expiration)
  1083. _tls_handshake(s, expiration)
  1084. send_tcp(s, wire, expiration)
  1085. (r, received_time) = receive_tcp(
  1086. s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
  1087. )
  1088. r.time = received_time - begin_time
  1089. if not q.is_response(r):
  1090. raise BadResponse
  1091. return r
  1092. assert (
  1093. False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
  1094. )
  1095. def quic(
  1096. q: dns.message.Message,
  1097. where: str,
  1098. timeout: Optional[float] = None,
  1099. port: int = 853,
  1100. source: Optional[str] = None,
  1101. source_port: int = 0,
  1102. one_rr_per_rrset: bool = False,
  1103. ignore_trailing: bool = False,
  1104. connection: Optional[dns.quic.SyncQuicConnection] = None,
  1105. verify: Union[bool, str] = True,
  1106. hostname: Optional[str] = None,
  1107. server_hostname: Optional[str] = None,
  1108. ) -> dns.message.Message:
  1109. """Return the response obtained after sending a query via DNS-over-QUIC.
  1110. *q*, a ``dns.message.Message``, the query to send.
  1111. *where*, a ``str``, the nameserver IP address.
  1112. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
  1113. times out. If ``None``, the default, wait forever.
  1114. *port*, a ``int``, the port to send the query to. The default is 853.
  1115. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
  1116. address. The default is the wildcard address.
  1117. *source_port*, an ``int``, the port from which to send the message. The default is
  1118. 0.
  1119. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
  1120. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
  1121. received message.
  1122. *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
  1123. to send the query.
  1124. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
  1125. of the server is done using the default CA bundle; if ``False``, then no
  1126. verification is done; if a `str` then it specifies the path to a certificate file or
  1127. directory which will be used for verification.
  1128. *hostname*, a ``str`` containing the server's hostname or ``None``. The default is
  1129. ``None``, which means that no hostname is known, and if an SSL context is created,
  1130. hostname checking will be disabled. This value is ignored if *url* is not
  1131. ``None``.
  1132. *server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
  1133. only, and has the same meaning as *hostname*.
  1134. Returns a ``dns.message.Message``.
  1135. """
  1136. if not dns.quic.have_quic:
  1137. raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
  1138. if server_hostname is not None and hostname is None:
  1139. hostname = server_hostname
  1140. q.id = 0
  1141. wire = q.to_wire()
  1142. the_connection: dns.quic.SyncQuicConnection
  1143. the_manager: dns.quic.SyncQuicManager
  1144. if connection:
  1145. manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
  1146. the_connection = connection
  1147. else:
  1148. manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
  1149. the_manager = manager # for type checking happiness
  1150. with manager:
  1151. if not connection:
  1152. the_connection = the_manager.connect(where, port, source, source_port)
  1153. (start, expiration) = _compute_times(timeout)
  1154. with the_connection.make_stream(timeout) as stream:
  1155. stream.send(wire, True)
  1156. wire = stream.receive(_remaining(expiration))
  1157. finish = time.time()
  1158. r = dns.message.from_wire(
  1159. wire,
  1160. keyring=q.keyring,
  1161. request_mac=q.request_mac,
  1162. one_rr_per_rrset=one_rr_per_rrset,
  1163. ignore_trailing=ignore_trailing,
  1164. )
  1165. r.time = max(finish - start, 0.0)
  1166. if not q.is_response(r):
  1167. raise BadResponse
  1168. return r
  1169. class UDPMode(enum.IntEnum):
  1170. """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
  1171. NEVER means "never use UDP; always use TCP"
  1172. TRY_FIRST means "try to use UDP but fall back to TCP if needed"
  1173. ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
  1174. """
  1175. NEVER = 0
  1176. TRY_FIRST = 1
  1177. ONLY = 2
def _inbound_xfr(
    txn_manager: dns.transaction.TransactionManager,
    s: socket.socket,
    query: dns.message.Message,
    serial: Optional[int],
    timeout: Optional[float],
    expiration: Optional[float],
) -> Any:
    """Given a connected socket, do the zone transfer, yielding each response.

    This is a generator: the query is sent and messages are read only as the
    caller iterates.

    *txn_manager*, the ``dns.transaction.TransactionManager`` handed to
    ``dns.xfr.Inbound``, which processes every received message.

    *s*, a connected ``socket.socket``.  A ``SOCK_DGRAM`` socket selects the UDP
    code path; any other type is treated as a TCP stream with 2-byte length
    framing.

    *query*, the ``dns.message.Message`` (AXFR or IXFR) to send.

    *serial*, an ``int`` or ``None``, passed through to ``dns.xfr.Inbound``.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait for each
    response message.

    *expiration*, a ``float`` or ``None``, the absolute deadline for the whole
    transfer.

    Raises ``dns.exception.FormError`` if *query* has a keyring but the last
    response received carried no TSIG.  Errors from the network layer,
    message parsing and ``dns.xfr.Inbound`` propagate to the caller.
    """
    rdtype = query.question[0].rdtype
    is_ixfr = rdtype == dns.rdatatype.IXFR
    origin = txn_manager.from_wire_origin()
    wire = query.to_wire()
    is_udp = s.type == socket.SOCK_DGRAM
    if is_udp:
        _udp_send(s, wire, None, expiration)
    else:
        # DNS over TCP: prefix the message with its 2-byte big-endian length.
        tcpmsg = struct.pack("!H", len(wire)) + wire
        _net_write(s, tcpmsg, expiration)
    with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
        done = False
        tsig_ctx = None
        while not done:
            # Per-message deadline: *timeout* from now, clamped so it never
            # extends past the overall transfer *expiration*.
            (_, mexpiration) = _compute_times(timeout)
            if mexpiration is None or (
                expiration is not None and mexpiration > expiration
            ):
                mexpiration = expiration
            if is_udp:
                (rwire, _) = _udp_recv(s, 65535, mexpiration)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (l,) = struct.unpack("!H", ldata)
                rwire = _net_read(s, l, mexpiration)
            # TSIG state is threaded from one message to the next through
            # tsig_ctx; the multi-message mode is only enabled for TCP.
            r = dns.message.from_wire(
                rwire,
                keyring=query.keyring,
                request_mac=query.mac,
                xfr=True,
                origin=origin,
                tsig_ctx=tsig_ctx,
                multi=(not is_udp),
                one_rr_per_rrset=is_ixfr,
            )
            # process_message() returns a true value once the transfer is
            # complete, which ends the loop.
            done = inbound.process_message(r)
            yield r
            tsig_ctx = r.tsig_ctx
        # Only the final message is checked here for the presence of TSIG.
        if query.keyring and not r.had_tsig:
            raise dns.exception.FormError("missing TSIG")
def xfr(
    where: str,
    zone: Union[dns.name.Name, str],
    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR,
    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
    timeout: Optional[float] = None,
    port: int = 53,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    keyname: Optional[Union[dns.name.Name, str]] = None,
    relativize: bool = True,
    lifetime: Optional[float] = None,
    source: Optional[str] = None,
    source_port: int = 0,
    serial: int = 0,
    use_udp: bool = False,
    keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
) -> Any:
    """Return a generator for the responses to a zone transfer.

    Because this function is a generator, nothing happens (no socket is opened
    and no argument errors are raised) until the first response is requested.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``str``, the type of zone transfer. The
    default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message. If None, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``. If ``True``, all names in the zone will be
    relativized to the zone origin. It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer. If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).

    *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.

    Raises on errors, and so does the generator.  In particular, a
    ``ValueError`` is raised on first iteration if *use_udp* is requested for
    anything other than IXFR.

    Returns a generator of ``dns.message.Message`` objects.
    """

    # Transaction manager that only supplies origin information.  Its writer
    # hands back a transaction whose every method is a no-op, so the inbound
    # transfer machinery runs but no zone is actually built; callers just get
    # the raw response messages.
    class DummyTransactionManager(dns.transaction.TransactionManager):
        def __init__(self, origin, relativize):
            # (origin, relativize, effective origin used for relativized names)
            self.info = (origin, relativize, dns.name.empty if relativize else origin)

        def origin_information(self):
            return self.info

        def get_class(self) -> dns.rdataclass.RdataClass:
            raise NotImplementedError  # pragma: no cover

        def reader(self):
            raise NotImplementedError  # pragma: no cover

        def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
            class DummyTransaction:
                def nop(self, *args, **kw):
                    pass

                # Any attribute lookup returns nop, so every method call is
                # accepted and ignored.
                def __getattr__(self, _):
                    return self.nop

            return cast(dns.transaction.Transaction, DummyTransaction())

    if isinstance(zone, str):
        zone = dns.name.from_text(zone)
    rdtype = dns.rdatatype.RdataType.make(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # An IXFR query carries the client's current SOA serial in the
        # authority section.  NOTE(review): this SOA is always class IN, even
        # when *rdclass* is something else; confirm that is intended.
        rrset = q.find_rrset(
            q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
        )
        soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
        rrset.add(soa, 0)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port
    )
    (_, expiration) = _compute_times(lifetime)
    tm = DummyTransactionManager(zone, relativize)
    if use_udp and rdtype != dns.rdatatype.IXFR:
        raise ValueError("cannot do a UDP AXFR")
    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
    with _make_socket(af, sock_type, source) as s:
        _connect(s, destination, expiration)
        yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
  1317. def inbound_xfr(
  1318. where: str,
  1319. txn_manager: dns.transaction.TransactionManager,
  1320. query: Optional[dns.message.Message] = None,
  1321. port: int = 53,
  1322. timeout: Optional[float] = None,
  1323. lifetime: Optional[float] = None,
  1324. source: Optional[str] = None,
  1325. source_port: int = 0,
  1326. udp_mode: UDPMode = UDPMode.NEVER,
  1327. ) -> None:
  1328. """Conduct an inbound transfer and apply it via a transaction from the
  1329. txn_manager.
  1330. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  1331. to send the message.
  1332. *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
  1333. for this transfer (typically a ``dns.zone.Zone``).
  1334. *query*, the query to send. If not supplied, a default query is
  1335. constructed using information from the *txn_manager*.
  1336. *port*, an ``int``, the port send the message to. The default is 53.
  1337. *timeout*, a ``float``, the number of seconds to wait for each
  1338. response message. If None, the default, wait forever.
  1339. *lifetime*, a ``float``, the total number of seconds to spend
  1340. doing the transfer. If ``None``, the default, then there is no
  1341. limit on the time the transfer may take.
  1342. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  1343. the source address. The default is the wildcard address.
  1344. *source_port*, an ``int``, the port from which to send the message.
  1345. The default is 0.
  1346. *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
  1347. for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use
  1348. TCP. Other possibilities are ``dns.UDPMode.TRY_FIRST``, which
  1349. means "try UDP but fallback to TCP if needed", and
  1350. ``dns.UDPMode.ONLY``, which means "try UDP and raise
  1351. ``dns.xfr.UseTCP`` if it does not succeed.
  1352. Raises on errors.
  1353. """
  1354. if query is None:
  1355. (query, serial) = dns.xfr.make_query(txn_manager)
  1356. else:
  1357. serial = dns.xfr.extract_serial_from_query(query)
  1358. (af, destination, source) = _destination_and_source(
  1359. where, port, source, source_port
  1360. )
  1361. (_, expiration) = _compute_times(lifetime)
  1362. if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
  1363. with _make_socket(af, socket.SOCK_DGRAM, source) as s:
  1364. _connect(s, destination, expiration)
  1365. try:
  1366. for _ in _inbound_xfr(
  1367. txn_manager, s, query, serial, timeout, expiration
  1368. ):
  1369. pass
  1370. return
  1371. except dns.xfr.UseTCP:
  1372. if udp_mode == UDPMode.ONLY:
  1373. raise
  1374. with _make_socket(af, socket.SOCK_STREAM, source) as s:
  1375. _connect(s, destination, expiration)
  1376. for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
  1377. pass