_trio_backend.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """trio async I/O library query support"""
  3. import socket
  4. import trio
  5. import trio.socket # type: ignore
  6. import dns._asyncbackend
  7. import dns._features
  8. import dns.exception
  9. import dns.inet
  10. if not dns._features.have("trio"):
  11. raise ImportError("trio not found or too old")
  12. def _maybe_timeout(timeout):
  13. if timeout is not None:
  14. return trio.move_on_after(timeout)
  15. else:
  16. return dns._asyncbackend.NullContext()
  17. # for brevity
  18. _lltuple = dns.inet.low_level_address_tuple
  19. # pylint: disable=redefined-outer-name
  20. class DatagramSocket(dns._asyncbackend.DatagramSocket):
  21. def __init__(self, sock):
  22. super().__init__(sock.family, socket.SOCK_DGRAM)
  23. self.socket = sock
  24. async def sendto(self, what, destination, timeout):
  25. with _maybe_timeout(timeout):
  26. if destination is None:
  27. return await self.socket.send(what)
  28. else:
  29. return await self.socket.sendto(what, destination)
  30. raise dns.exception.Timeout(
  31. timeout=timeout
  32. ) # pragma: no cover lgtm[py/unreachable-statement]
  33. async def recvfrom(self, size, timeout):
  34. with _maybe_timeout(timeout):
  35. return await self.socket.recvfrom(size)
  36. raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
  37. async def close(self):
  38. self.socket.close()
  39. async def getpeername(self):
  40. return self.socket.getpeername()
  41. async def getsockname(self):
  42. return self.socket.getsockname()
  43. async def getpeercert(self, timeout):
  44. raise NotImplementedError
  45. class StreamSocket(dns._asyncbackend.StreamSocket):
  46. def __init__(self, family, stream, tls=False):
  47. super().__init__(family, socket.SOCK_STREAM)
  48. self.stream = stream
  49. self.tls = tls
  50. async def sendall(self, what, timeout):
  51. with _maybe_timeout(timeout):
  52. return await self.stream.send_all(what)
  53. raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
  54. async def recv(self, size, timeout):
  55. with _maybe_timeout(timeout):
  56. return await self.stream.receive_some(size)
  57. raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
  58. async def close(self):
  59. await self.stream.aclose()
  60. async def getpeername(self):
  61. if self.tls:
  62. return self.stream.transport_stream.socket.getpeername()
  63. else:
  64. return self.stream.socket.getpeername()
  65. async def getsockname(self):
  66. if self.tls:
  67. return self.stream.transport_stream.socket.getsockname()
  68. else:
  69. return self.stream.socket.getsockname()
  70. async def getpeercert(self, timeout):
  71. if self.tls:
  72. with _maybe_timeout(timeout):
  73. await self.stream.do_handshake()
  74. return self.stream.getpeercert()
  75. else:
  76. raise NotImplementedError
  77. if dns._features.have("doh"):
  78. import httpcore
  79. import httpcore._backends.trio
  80. import httpx
  81. _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
  82. _CoreTrioStream = httpcore._backends.trio.TrioStream
  83. from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
  84. class _NetworkBackend(_CoreAsyncNetworkBackend):
  85. def __init__(self, resolver, local_port, bootstrap_address, family):
  86. super().__init__()
  87. self._local_port = local_port
  88. self._resolver = resolver
  89. self._bootstrap_address = bootstrap_address
  90. self._family = family
  91. async def connect_tcp(
  92. self, host, port, timeout, local_address, socket_options=None
  93. ): # pylint: disable=signature-differs
  94. addresses = []
  95. _, expiration = _compute_times(timeout)
  96. if dns.inet.is_address(host):
  97. addresses.append(host)
  98. elif self._bootstrap_address is not None:
  99. addresses.append(self._bootstrap_address)
  100. else:
  101. timeout = _remaining(expiration)
  102. family = self._family
  103. if local_address:
  104. family = dns.inet.af_for_address(local_address)
  105. answers = await self._resolver.resolve_name(
  106. host, family=family, lifetime=timeout
  107. )
  108. addresses = answers.addresses()
  109. for address in addresses:
  110. try:
  111. af = dns.inet.af_for_address(address)
  112. if local_address is not None or self._local_port != 0:
  113. source = (local_address, self._local_port)
  114. else:
  115. source = None
  116. destination = (address, port)
  117. attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
  118. timeout = _remaining(attempt_expiration)
  119. sock = await Backend().make_socket(
  120. af, socket.SOCK_STREAM, 0, source, destination, timeout
  121. )
  122. return _CoreTrioStream(sock.stream)
  123. except Exception:
  124. continue
  125. raise httpcore.ConnectError
  126. async def connect_unix_socket(
  127. self, path, timeout, socket_options=None
  128. ): # pylint: disable=signature-differs
  129. raise NotImplementedError
  130. async def sleep(self, seconds): # pylint: disable=signature-differs
  131. await trio.sleep(seconds)
  132. class _HTTPTransport(httpx.AsyncHTTPTransport):
  133. def __init__(
  134. self,
  135. *args,
  136. local_port=0,
  137. bootstrap_address=None,
  138. resolver=None,
  139. family=socket.AF_UNSPEC,
  140. **kwargs,
  141. ):
  142. if resolver is None and bootstrap_address is None:
  143. # pylint: disable=import-outside-toplevel,redefined-outer-name
  144. import dns.asyncresolver
  145. resolver = dns.asyncresolver.Resolver()
  146. super().__init__(*args, **kwargs)
  147. self._pool._network_backend = _NetworkBackend(
  148. resolver, local_port, bootstrap_address, family
  149. )
  150. else:
  151. _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
  152. class Backend(dns._asyncbackend.Backend):
  153. def name(self):
  154. return "trio"
  155. async def make_socket(
  156. self,
  157. af,
  158. socktype,
  159. proto=0,
  160. source=None,
  161. destination=None,
  162. timeout=None,
  163. ssl_context=None,
  164. server_hostname=None,
  165. ):
  166. s = trio.socket.socket(af, socktype, proto)
  167. stream = None
  168. try:
  169. if source:
  170. await s.bind(_lltuple(source, af))
  171. if socktype == socket.SOCK_STREAM or destination is not None:
  172. connected = False
  173. with _maybe_timeout(timeout):
  174. await s.connect(_lltuple(destination, af))
  175. connected = True
  176. if not connected:
  177. raise dns.exception.Timeout(
  178. timeout=timeout
  179. ) # lgtm[py/unreachable-statement]
  180. except Exception: # pragma: no cover
  181. s.close()
  182. raise
  183. if socktype == socket.SOCK_DGRAM:
  184. return DatagramSocket(s)
  185. elif socktype == socket.SOCK_STREAM:
  186. stream = trio.SocketStream(s)
  187. tls = False
  188. if ssl_context:
  189. tls = True
  190. try:
  191. stream = trio.SSLStream(
  192. stream, ssl_context, server_hostname=server_hostname
  193. )
  194. except Exception: # pragma: no cover
  195. await stream.aclose()
  196. raise
  197. return StreamSocket(af, stream, tls)
  198. raise NotImplementedError(
  199. "unsupported socket " + f"type {socktype}"
  200. ) # pragma: no cover
  201. async def sleep(self, interval):
  202. await trio.sleep(interval)
  203. def get_transport_class(self):
  204. return _HTTPTransport
  205. async def wait_for(self, awaitable, timeout):
  206. with _maybe_timeout(timeout):
  207. return await awaitable
  208. raise dns.exception.Timeout(
  209. timeout=timeout
  210. ) # pragma: no cover lgtm[py/unreachable-statement]