_common.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy
import functools
import socket
import struct
import time
import urllib
import urllib.parse
from typing import Any, Optional

import aioquic.h3.connection  # type: ignore
import aioquic.h3.events  # type: ignore
import aioquic.quic.configuration  # type: ignore
import aioquic.quic.connection  # type: ignore

import dns.inet
  15. QUIC_MAX_DATAGRAM = 2048
  16. MAX_SESSION_TICKETS = 8
  17. # If we hit the max sessions limit we will delete this many of the oldest connections.
  18. # The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
  19. SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
  20. class UnexpectedEOF(Exception):
  21. pass
  22. class Buffer:
  23. def __init__(self):
  24. self._buffer = b""
  25. self._seen_end = False
  26. def put(self, data, is_end):
  27. if self._seen_end:
  28. return
  29. self._buffer += data
  30. if is_end:
  31. self._seen_end = True
  32. def have(self, amount):
  33. if len(self._buffer) >= amount:
  34. return True
  35. if self._seen_end:
  36. raise UnexpectedEOF
  37. return False
  38. def seen_end(self):
  39. return self._seen_end
  40. def get(self, amount):
  41. assert self.have(amount)
  42. data = self._buffer[:amount]
  43. self._buffer = self._buffer[amount:]
  44. return data
  45. def get_all(self):
  46. assert self.seen_end()
  47. data = self._buffer
  48. self._buffer = b""
  49. return data
  50. class BaseQuicStream:
  51. def __init__(self, connection, stream_id):
  52. self._connection = connection
  53. self._stream_id = stream_id
  54. self._buffer = Buffer()
  55. self._expecting = 0
  56. self._headers = None
  57. self._trailers = None
  58. def id(self):
  59. return self._stream_id
  60. def headers(self):
  61. return self._headers
  62. def trailers(self):
  63. return self._trailers
  64. def _expiration_from_timeout(self, timeout):
  65. if timeout is not None:
  66. expiration = time.time() + timeout
  67. else:
  68. expiration = None
  69. return expiration
  70. def _timeout_from_expiration(self, expiration):
  71. if expiration is not None:
  72. timeout = max(expiration - time.time(), 0.0)
  73. else:
  74. timeout = None
  75. return timeout
  76. # Subclass must implement receive() as sync / async and which returns a message
  77. # or raises.
  78. # Subclass must implement send() as sync / async and which takes a message and
  79. # an EOF indicator.
  80. def send_h3(self, url, datagram, post=True):
  81. if not self._connection.is_h3():
  82. raise SyntaxError("cannot send H3 to a non-H3 connection")
  83. url_parts = urllib.parse.urlparse(url)
  84. path = url_parts.path.encode()
  85. if post:
  86. method = b"POST"
  87. else:
  88. method = b"GET"
  89. path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
  90. headers = [
  91. (b":method", method),
  92. (b":scheme", url_parts.scheme.encode()),
  93. (b":authority", url_parts.netloc.encode()),
  94. (b":path", path),
  95. (b"accept", b"application/dns-message"),
  96. ]
  97. if post:
  98. headers.extend(
  99. [
  100. (b"content-type", b"application/dns-message"),
  101. (b"content-length", str(len(datagram)).encode()),
  102. ]
  103. )
  104. self._connection.send_headers(self._stream_id, headers, not post)
  105. if post:
  106. self._connection.send_data(self._stream_id, datagram, True)
  107. def _encapsulate(self, datagram):
  108. if self._connection.is_h3():
  109. return datagram
  110. l = len(datagram)
  111. return struct.pack("!H", l) + datagram
  112. def _common_add_input(self, data, is_end):
  113. self._buffer.put(data, is_end)
  114. try:
  115. return (
  116. self._expecting > 0 and self._buffer.have(self._expecting)
  117. ) or self._buffer.seen_end
  118. except UnexpectedEOF:
  119. return True
  120. def _close(self):
  121. self._connection.close_stream(self._stream_id)
  122. self._buffer.put(b"", True) # send EOF in case we haven't seen it.
  123. class BaseQuicConnection:
  124. def __init__(
  125. self,
  126. connection,
  127. address,
  128. port,
  129. source=None,
  130. source_port=0,
  131. manager=None,
  132. ):
  133. self._done = False
  134. self._connection = connection
  135. self._address = address
  136. self._port = port
  137. self._closed = False
  138. self._manager = manager
  139. self._streams = {}
  140. if manager.is_h3():
  141. self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
  142. else:
  143. self._h3_conn = None
  144. self._af = dns.inet.af_for_address(address)
  145. self._peer = dns.inet.low_level_address_tuple((address, port))
  146. if source is None and source_port != 0:
  147. if self._af == socket.AF_INET:
  148. source = "0.0.0.0"
  149. elif self._af == socket.AF_INET6:
  150. source = "::"
  151. else:
  152. raise NotImplementedError
  153. if source:
  154. self._source = (source, source_port)
  155. else:
  156. self._source = None
  157. def is_h3(self):
  158. return self._h3_conn is not None
  159. def close_stream(self, stream_id):
  160. del self._streams[stream_id]
  161. def send_headers(self, stream_id, headers, is_end=False):
  162. self._h3_conn.send_headers(stream_id, headers, is_end)
  163. def send_data(self, stream_id, data, is_end=False):
  164. self._h3_conn.send_data(stream_id, data, is_end)
  165. def _get_timer_values(self, closed_is_special=True):
  166. now = time.time()
  167. expiration = self._connection.get_timer()
  168. if expiration is None:
  169. expiration = now + 3600 # arbitrary "big" value
  170. interval = max(expiration - now, 0)
  171. if self._closed and closed_is_special:
  172. # lower sleep interval to avoid a race in the closing process
  173. # which can lead to higher latency closing due to sleeping when
  174. # we have events.
  175. interval = min(interval, 0.05)
  176. return (expiration, interval)
  177. def _handle_timer(self, expiration):
  178. now = time.time()
  179. if expiration <= now:
  180. self._connection.handle_timer(now)
  181. class AsyncQuicConnection(BaseQuicConnection):
  182. async def make_stream(self, timeout: Optional[float] = None) -> Any:
  183. pass
  184. class BaseQuicManager:
  185. def __init__(
  186. self, conf, verify_mode, connection_factory, server_name=None, h3=False
  187. ):
  188. self._connections = {}
  189. self._connection_factory = connection_factory
  190. self._session_tickets = {}
  191. self._tokens = {}
  192. self._h3 = h3
  193. if conf is None:
  194. verify_path = None
  195. if isinstance(verify_mode, str):
  196. verify_path = verify_mode
  197. verify_mode = True
  198. if h3:
  199. alpn_protocols = ["h3"]
  200. else:
  201. alpn_protocols = ["doq", "doq-i03"]
  202. conf = aioquic.quic.configuration.QuicConfiguration(
  203. alpn_protocols=alpn_protocols,
  204. verify_mode=verify_mode,
  205. server_name=server_name,
  206. )
  207. if verify_path is not None:
  208. conf.load_verify_locations(verify_path)
  209. self._conf = conf
  210. def _connect(
  211. self,
  212. address,
  213. port=853,
  214. source=None,
  215. source_port=0,
  216. want_session_ticket=True,
  217. want_token=True,
  218. ):
  219. connection = self._connections.get((address, port))
  220. if connection is not None:
  221. return (connection, False)
  222. conf = self._conf
  223. if want_session_ticket:
  224. try:
  225. session_ticket = self._session_tickets.pop((address, port))
  226. # We found a session ticket, so make a configuration that uses it.
  227. conf = copy.copy(conf)
  228. conf.session_ticket = session_ticket
  229. except KeyError:
  230. # No session ticket.
  231. pass
  232. # Whether or not we found a session ticket, we want a handler to save
  233. # one.
  234. session_ticket_handler = functools.partial(
  235. self.save_session_ticket, address, port
  236. )
  237. else:
  238. session_ticket_handler = None
  239. if want_token:
  240. try:
  241. token = self._tokens.pop((address, port))
  242. # We found a token, so make a configuration that uses it.
  243. conf = copy.copy(conf)
  244. conf.token = token
  245. except KeyError:
  246. # No token
  247. pass
  248. # Whether or not we found a token, we want a handler to save # one.
  249. token_handler = functools.partial(self.save_token, address, port)
  250. else:
  251. token_handler = None
  252. qconn = aioquic.quic.connection.QuicConnection(
  253. configuration=conf,
  254. session_ticket_handler=session_ticket_handler,
  255. token_handler=token_handler,
  256. )
  257. lladdress = dns.inet.low_level_address_tuple((address, port))
  258. qconn.connect(lladdress, time.time())
  259. connection = self._connection_factory(
  260. qconn, address, port, source, source_port, self
  261. )
  262. self._connections[(address, port)] = connection
  263. return (connection, True)
  264. def closed(self, address, port):
  265. try:
  266. del self._connections[(address, port)]
  267. except KeyError:
  268. pass
  269. def is_h3(self):
  270. return self._h3
  271. def save_session_ticket(self, address, port, ticket):
  272. # We rely on dictionaries keys() being in insertion order here. We
  273. # can't just popitem() as that would be LIFO which is the opposite of
  274. # what we want.
  275. l = len(self._session_tickets)
  276. if l >= MAX_SESSION_TICKETS:
  277. keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
  278. for key in keys_to_delete:
  279. del self._session_tickets[key]
  280. self._session_tickets[(address, port)] = ticket
  281. def save_token(self, address, port, token):
  282. # We rely on dictionaries keys() being in insertion order here. We
  283. # can't just popitem() as that would be LIFO which is the opposite of
  284. # what we want.
  285. l = len(self._tokens)
  286. if l >= MAX_SESSION_TICKETS:
  287. keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
  288. for key in keys_to_delete:
  289. del self._tokens[key]
  290. self._tokens[(address, port)] = token
  291. class AsyncQuicManager(BaseQuicManager):
  292. def connect(self, address, port=853, source=None, source_port=0):
  293. raise NotImplementedError