_trio.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import struct
import time

import aioquic.h3.events  # type: ignore
import aioquic.quic.configuration  # type: ignore
import aioquic.quic.connection  # type: ignore
import aioquic.quic.events  # type: ignore
import trio

import dns.exception
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
    QUIC_MAX_DATAGRAM,
    AsyncQuicConnection,
    AsyncQuicManager,
    BaseQuicStream,
    UnexpectedEOF,
)
  20. class TrioQuicStream(BaseQuicStream):
  21. def __init__(self, connection, stream_id):
  22. super().__init__(connection, stream_id)
  23. self._wake_up = trio.Condition()
  24. async def wait_for(self, amount):
  25. while True:
  26. if self._buffer.have(amount):
  27. return
  28. self._expecting = amount
  29. async with self._wake_up:
  30. await self._wake_up.wait()
  31. self._expecting = 0
  32. async def wait_for_end(self):
  33. while True:
  34. if self._buffer.seen_end():
  35. return
  36. async with self._wake_up:
  37. await self._wake_up.wait()
  38. async def receive(self, timeout=None):
  39. if timeout is None:
  40. context = NullContext(None)
  41. else:
  42. context = trio.move_on_after(timeout)
  43. with context:
  44. if self._connection.is_h3():
  45. await self.wait_for_end()
  46. return self._buffer.get_all()
  47. else:
  48. await self.wait_for(2)
  49. (size,) = struct.unpack("!H", self._buffer.get(2))
  50. await self.wait_for(size)
  51. return self._buffer.get(size)
  52. raise dns.exception.Timeout
  53. async def send(self, datagram, is_end=False):
  54. data = self._encapsulate(datagram)
  55. await self._connection.write(self._stream_id, data, is_end)
  56. async def _add_input(self, data, is_end):
  57. if self._common_add_input(data, is_end):
  58. async with self._wake_up:
  59. self._wake_up.notify()
  60. async def close(self):
  61. self._close()
  62. # Streams are async context managers
  63. async def __aenter__(self):
  64. return self
  65. async def __aexit__(self, exc_type, exc_val, exc_tb):
  66. await self.close()
  67. async with self._wake_up:
  68. self._wake_up.notify()
  69. return False
  70. class TrioQuicConnection(AsyncQuicConnection):
  71. def __init__(self, connection, address, port, source, source_port, manager=None):
  72. super().__init__(connection, address, port, source, source_port, manager)
  73. self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
  74. self._handshake_complete = trio.Event()
  75. self._run_done = trio.Event()
  76. self._worker_scope = None
  77. self._send_pending = False
  78. async def _worker(self):
  79. try:
  80. if self._source:
  81. await self._socket.bind(
  82. dns.inet.low_level_address_tuple(self._source, self._af)
  83. )
  84. await self._socket.connect(self._peer)
  85. while not self._done:
  86. (expiration, interval) = self._get_timer_values(False)
  87. if self._send_pending:
  88. # Do not block forever if sends are pending. Even though we
  89. # have a wake-up mechanism if we've already started the blocking
  90. # read, the possibility of context switching in send means that
  91. # more writes can happen while we have no wake up context, so
  92. # we need self._send_pending to avoid (effectively) a "lost wakeup"
  93. # race.
  94. interval = 0.0
  95. with trio.CancelScope(
  96. deadline=trio.current_time() + interval
  97. ) as self._worker_scope:
  98. datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
  99. self._connection.receive_datagram(datagram, self._peer, time.time())
  100. self._worker_scope = None
  101. self._handle_timer(expiration)
  102. await self._handle_events()
  103. # We clear this now, before sending anything, as sending can cause
  104. # context switches that do more sends. We want to know if that
  105. # happens so we don't block a long time on the recv() above.
  106. self._send_pending = False
  107. datagrams = self._connection.datagrams_to_send(time.time())
  108. for datagram, _ in datagrams:
  109. await self._socket.send(datagram)
  110. finally:
  111. self._done = True
  112. self._socket.close()
  113. self._handshake_complete.set()
  114. async def _handle_events(self):
  115. count = 0
  116. while True:
  117. event = self._connection.next_event()
  118. if event is None:
  119. return
  120. if isinstance(event, aioquic.quic.events.StreamDataReceived):
  121. if self.is_h3():
  122. h3_events = self._h3_conn.handle_event(event)
  123. for h3_event in h3_events:
  124. if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
  125. stream = self._streams.get(event.stream_id)
  126. if stream:
  127. if stream._headers is None:
  128. stream._headers = h3_event.headers
  129. elif stream._trailers is None:
  130. stream._trailers = h3_event.headers
  131. if h3_event.stream_ended:
  132. await stream._add_input(b"", True)
  133. elif isinstance(h3_event, aioquic.h3.events.DataReceived):
  134. stream = self._streams.get(event.stream_id)
  135. if stream:
  136. await stream._add_input(
  137. h3_event.data, h3_event.stream_ended
  138. )
  139. else:
  140. stream = self._streams.get(event.stream_id)
  141. if stream:
  142. await stream._add_input(event.data, event.end_stream)
  143. elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
  144. self._handshake_complete.set()
  145. elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
  146. self._done = True
  147. self._socket.close()
  148. elif isinstance(event, aioquic.quic.events.StreamReset):
  149. stream = self._streams.get(event.stream_id)
  150. if stream:
  151. await stream._add_input(b"", True)
  152. count += 1
  153. if count > 10:
  154. # yield
  155. count = 0
  156. await trio.sleep(0)
  157. async def write(self, stream, data, is_end=False):
  158. self._connection.send_stream_data(stream, data, is_end)
  159. self._send_pending = True
  160. if self._worker_scope is not None:
  161. self._worker_scope.cancel()
  162. async def run(self):
  163. if self._closed:
  164. return
  165. async with trio.open_nursery() as nursery:
  166. nursery.start_soon(self._worker)
  167. self._run_done.set()
  168. async def make_stream(self, timeout=None):
  169. if timeout is None:
  170. context = NullContext(None)
  171. else:
  172. context = trio.move_on_after(timeout)
  173. with context:
  174. await self._handshake_complete.wait()
  175. if self._done:
  176. raise UnexpectedEOF
  177. stream_id = self._connection.get_next_available_stream_id(False)
  178. stream = TrioQuicStream(self, stream_id)
  179. self._streams[stream_id] = stream
  180. return stream
  181. raise dns.exception.Timeout
  182. async def close(self):
  183. if not self._closed:
  184. self._manager.closed(self._peer[0], self._peer[1])
  185. self._closed = True
  186. self._connection.close()
  187. self._send_pending = True
  188. if self._worker_scope is not None:
  189. self._worker_scope.cancel()
  190. await self._run_done.wait()
  191. class TrioQuicManager(AsyncQuicManager):
  192. def __init__(
  193. self,
  194. nursery,
  195. conf=None,
  196. verify_mode=ssl.CERT_REQUIRED,
  197. server_name=None,
  198. h3=False,
  199. ):
  200. super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
  201. self._nursery = nursery
  202. def connect(
  203. self, address, port=853, source=None, source_port=0, want_session_ticket=True
  204. ):
  205. (connection, start) = self._connect(
  206. address, port, source, source_port, want_session_ticket
  207. )
  208. if start:
  209. self._nursery.start_soon(connection.run)
  210. return connection
  211. async def __aenter__(self):
  212. return self
  213. async def __aexit__(self, exc_type, exc_val, exc_tb):
  214. # Copy the iterator into a list as exiting things will mutate the connections
  215. # table.
  216. connections = list(self._connections.values())
  217. for connection in connections:
  218. await connection.close()
  219. return False