_sync.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import selectors
import socket
import ssl
import struct
import threading
import time

import aioquic.h3.events  # type: ignore
import aioquic.quic.configuration  # type: ignore
import aioquic.quic.connection  # type: ignore
import aioquic.quic.events  # type: ignore

import dns.exception
import dns.inet
from dns.quic._common import (
    QUIC_MAX_DATAGRAM,
    BaseQuicConnection,
    BaseQuicManager,
    BaseQuicStream,
    UnexpectedEOF,
)
  20. # Function used to create a socket. Can be overridden if needed in special
  21. # situations.
  22. socket_factory = socket.socket
  23. class SyncQuicStream(BaseQuicStream):
  24. def __init__(self, connection, stream_id):
  25. super().__init__(connection, stream_id)
  26. self._wake_up = threading.Condition()
  27. self._lock = threading.Lock()
  28. def wait_for(self, amount, expiration):
  29. while True:
  30. timeout = self._timeout_from_expiration(expiration)
  31. with self._lock:
  32. if self._buffer.have(amount):
  33. return
  34. self._expecting = amount
  35. with self._wake_up:
  36. if not self._wake_up.wait(timeout):
  37. raise dns.exception.Timeout
  38. self._expecting = 0
  39. def wait_for_end(self, expiration):
  40. while True:
  41. timeout = self._timeout_from_expiration(expiration)
  42. with self._lock:
  43. if self._buffer.seen_end():
  44. return
  45. with self._wake_up:
  46. if not self._wake_up.wait(timeout):
  47. raise dns.exception.Timeout
  48. def receive(self, timeout=None):
  49. expiration = self._expiration_from_timeout(timeout)
  50. if self._connection.is_h3():
  51. self.wait_for_end(expiration)
  52. with self._lock:
  53. return self._buffer.get_all()
  54. else:
  55. self.wait_for(2, expiration)
  56. with self._lock:
  57. (size,) = struct.unpack("!H", self._buffer.get(2))
  58. self.wait_for(size, expiration)
  59. with self._lock:
  60. return self._buffer.get(size)
  61. def send(self, datagram, is_end=False):
  62. data = self._encapsulate(datagram)
  63. self._connection.write(self._stream_id, data, is_end)
  64. def _add_input(self, data, is_end):
  65. if self._common_add_input(data, is_end):
  66. with self._wake_up:
  67. self._wake_up.notify()
  68. def close(self):
  69. with self._lock:
  70. self._close()
  71. def __enter__(self):
  72. return self
  73. def __exit__(self, exc_type, exc_val, exc_tb):
  74. self.close()
  75. with self._wake_up:
  76. self._wake_up.notify()
  77. return False
  78. class SyncQuicConnection(BaseQuicConnection):
  79. def __init__(self, connection, address, port, source, source_port, manager):
  80. super().__init__(connection, address, port, source, source_port, manager)
  81. self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
  82. if self._source is not None:
  83. try:
  84. self._socket.bind(
  85. dns.inet.low_level_address_tuple(self._source, self._af)
  86. )
  87. except Exception:
  88. self._socket.close()
  89. raise
  90. self._socket.connect(self._peer)
  91. (self._send_wakeup, self._receive_wakeup) = socket.socketpair()
  92. self._receive_wakeup.setblocking(False)
  93. self._socket.setblocking(False)
  94. self._handshake_complete = threading.Event()
  95. self._worker_thread = None
  96. self._lock = threading.Lock()
  97. def _read(self):
  98. count = 0
  99. while count < 10:
  100. count += 1
  101. try:
  102. datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
  103. except BlockingIOError:
  104. return
  105. with self._lock:
  106. self._connection.receive_datagram(datagram, self._peer, time.time())
  107. def _drain_wakeup(self):
  108. while True:
  109. try:
  110. self._receive_wakeup.recv(32)
  111. except BlockingIOError:
  112. return
  113. def _worker(self):
  114. try:
  115. sel = selectors.DefaultSelector()
  116. sel.register(self._socket, selectors.EVENT_READ, self._read)
  117. sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
  118. while not self._done:
  119. (expiration, interval) = self._get_timer_values(False)
  120. items = sel.select(interval)
  121. for key, _ in items:
  122. key.data()
  123. with self._lock:
  124. self._handle_timer(expiration)
  125. self._handle_events()
  126. with self._lock:
  127. datagrams = self._connection.datagrams_to_send(time.time())
  128. for datagram, _ in datagrams:
  129. try:
  130. self._socket.send(datagram)
  131. except BlockingIOError:
  132. # we let QUIC handle any lossage
  133. pass
  134. finally:
  135. with self._lock:
  136. self._done = True
  137. self._socket.close()
  138. # Ensure anyone waiting for this gets woken up.
  139. self._handshake_complete.set()
  140. def _handle_events(self):
  141. while True:
  142. with self._lock:
  143. event = self._connection.next_event()
  144. if event is None:
  145. return
  146. if isinstance(event, aioquic.quic.events.StreamDataReceived):
  147. if self.is_h3():
  148. h3_events = self._h3_conn.handle_event(event)
  149. for h3_event in h3_events:
  150. if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
  151. with self._lock:
  152. stream = self._streams.get(event.stream_id)
  153. if stream:
  154. if stream._headers is None:
  155. stream._headers = h3_event.headers
  156. elif stream._trailers is None:
  157. stream._trailers = h3_event.headers
  158. if h3_event.stream_ended:
  159. stream._add_input(b"", True)
  160. elif isinstance(h3_event, aioquic.h3.events.DataReceived):
  161. with self._lock:
  162. stream = self._streams.get(event.stream_id)
  163. if stream:
  164. stream._add_input(h3_event.data, h3_event.stream_ended)
  165. else:
  166. with self._lock:
  167. stream = self._streams.get(event.stream_id)
  168. if stream:
  169. stream._add_input(event.data, event.end_stream)
  170. elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
  171. self._handshake_complete.set()
  172. elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
  173. with self._lock:
  174. self._done = True
  175. elif isinstance(event, aioquic.quic.events.StreamReset):
  176. with self._lock:
  177. stream = self._streams.get(event.stream_id)
  178. if stream:
  179. stream._add_input(b"", True)
  180. def write(self, stream, data, is_end=False):
  181. with self._lock:
  182. self._connection.send_stream_data(stream, data, is_end)
  183. self._send_wakeup.send(b"\x01")
  184. def send_headers(self, stream_id, headers, is_end=False):
  185. with self._lock:
  186. super().send_headers(stream_id, headers, is_end)
  187. if is_end:
  188. self._send_wakeup.send(b"\x01")
  189. def send_data(self, stream_id, data, is_end=False):
  190. with self._lock:
  191. super().send_data(stream_id, data, is_end)
  192. if is_end:
  193. self._send_wakeup.send(b"\x01")
  194. def run(self):
  195. if self._closed:
  196. return
  197. self._worker_thread = threading.Thread(target=self._worker)
  198. self._worker_thread.start()
  199. def make_stream(self, timeout=None):
  200. if not self._handshake_complete.wait(timeout):
  201. raise dns.exception.Timeout
  202. with self._lock:
  203. if self._done:
  204. raise UnexpectedEOF
  205. stream_id = self._connection.get_next_available_stream_id(False)
  206. stream = SyncQuicStream(self, stream_id)
  207. self._streams[stream_id] = stream
  208. return stream
  209. def close_stream(self, stream_id):
  210. with self._lock:
  211. super().close_stream(stream_id)
  212. def close(self):
  213. with self._lock:
  214. if self._closed:
  215. return
  216. self._manager.closed(self._peer[0], self._peer[1])
  217. self._closed = True
  218. self._connection.close()
  219. self._send_wakeup.send(b"\x01")
  220. self._worker_thread.join()
  221. class SyncQuicManager(BaseQuicManager):
  222. def __init__(
  223. self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
  224. ):
  225. super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
  226. self._lock = threading.Lock()
  227. def connect(
  228. self,
  229. address,
  230. port=853,
  231. source=None,
  232. source_port=0,
  233. want_session_ticket=True,
  234. want_token=True,
  235. ):
  236. with self._lock:
  237. (connection, start) = self._connect(
  238. address, port, source, source_port, want_session_ticket, want_token
  239. )
  240. if start:
  241. connection.run()
  242. return connection
  243. def closed(self, address, port):
  244. with self._lock:
  245. super().closed(address, port)
  246. def save_session_ticket(self, address, port, ticket):
  247. with self._lock:
  248. super().save_session_ticket(address, port, ticket)
  249. def save_token(self, address, port, token):
  250. with self._lock:
  251. super().save_token(address, port, token)
  252. def __enter__(self):
  253. return self
  254. def __exit__(self, exc_type, exc_val, exc_tb):
  255. # Copy the iterator into a list as exiting things will mutate the connections
  256. # table.
  257. connections = list(self._connections.values())
  258. for connection in connections:
  259. connection.close()
  260. return False