nameserver.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. from typing import Optional, Union
  2. from urllib.parse import urlparse
  3. import dns.asyncbackend
  4. import dns.asyncquery
  5. import dns.inet
  6. import dns.message
  7. import dns.query
  8. class Nameserver:
  9. def __init__(self):
  10. pass
  11. def __str__(self):
  12. raise NotImplementedError
  13. def kind(self) -> str:
  14. raise NotImplementedError
  15. def is_always_max_size(self) -> bool:
  16. raise NotImplementedError
  17. def answer_nameserver(self) -> str:
  18. raise NotImplementedError
  19. def answer_port(self) -> int:
  20. raise NotImplementedError
  21. def query(
  22. self,
  23. request: dns.message.QueryMessage,
  24. timeout: float,
  25. source: Optional[str],
  26. source_port: int,
  27. max_size: bool,
  28. one_rr_per_rrset: bool = False,
  29. ignore_trailing: bool = False,
  30. ) -> dns.message.Message:
  31. raise NotImplementedError
  32. async def async_query(
  33. self,
  34. request: dns.message.QueryMessage,
  35. timeout: float,
  36. source: Optional[str],
  37. source_port: int,
  38. max_size: bool,
  39. backend: dns.asyncbackend.Backend,
  40. one_rr_per_rrset: bool = False,
  41. ignore_trailing: bool = False,
  42. ) -> dns.message.Message:
  43. raise NotImplementedError
  44. class AddressAndPortNameserver(Nameserver):
  45. def __init__(self, address: str, port: int):
  46. super().__init__()
  47. self.address = address
  48. self.port = port
  49. def kind(self) -> str:
  50. raise NotImplementedError
  51. def is_always_max_size(self) -> bool:
  52. return False
  53. def __str__(self):
  54. ns_kind = self.kind()
  55. return f"{ns_kind}:{self.address}@{self.port}"
  56. def answer_nameserver(self) -> str:
  57. return self.address
  58. def answer_port(self) -> int:
  59. return self.port
  60. class Do53Nameserver(AddressAndPortNameserver):
  61. def __init__(self, address: str, port: int = 53):
  62. super().__init__(address, port)
  63. def kind(self):
  64. return "Do53"
  65. def query(
  66. self,
  67. request: dns.message.QueryMessage,
  68. timeout: float,
  69. source: Optional[str],
  70. source_port: int,
  71. max_size: bool,
  72. one_rr_per_rrset: bool = False,
  73. ignore_trailing: bool = False,
  74. ) -> dns.message.Message:
  75. if max_size:
  76. response = dns.query.tcp(
  77. request,
  78. self.address,
  79. timeout=timeout,
  80. port=self.port,
  81. source=source,
  82. source_port=source_port,
  83. one_rr_per_rrset=one_rr_per_rrset,
  84. ignore_trailing=ignore_trailing,
  85. )
  86. else:
  87. response = dns.query.udp(
  88. request,
  89. self.address,
  90. timeout=timeout,
  91. port=self.port,
  92. source=source,
  93. source_port=source_port,
  94. raise_on_truncation=True,
  95. one_rr_per_rrset=one_rr_per_rrset,
  96. ignore_trailing=ignore_trailing,
  97. ignore_errors=True,
  98. ignore_unexpected=True,
  99. )
  100. return response
  101. async def async_query(
  102. self,
  103. request: dns.message.QueryMessage,
  104. timeout: float,
  105. source: Optional[str],
  106. source_port: int,
  107. max_size: bool,
  108. backend: dns.asyncbackend.Backend,
  109. one_rr_per_rrset: bool = False,
  110. ignore_trailing: bool = False,
  111. ) -> dns.message.Message:
  112. if max_size:
  113. response = await dns.asyncquery.tcp(
  114. request,
  115. self.address,
  116. timeout=timeout,
  117. port=self.port,
  118. source=source,
  119. source_port=source_port,
  120. backend=backend,
  121. one_rr_per_rrset=one_rr_per_rrset,
  122. ignore_trailing=ignore_trailing,
  123. )
  124. else:
  125. response = await dns.asyncquery.udp(
  126. request,
  127. self.address,
  128. timeout=timeout,
  129. port=self.port,
  130. source=source,
  131. source_port=source_port,
  132. raise_on_truncation=True,
  133. backend=backend,
  134. one_rr_per_rrset=one_rr_per_rrset,
  135. ignore_trailing=ignore_trailing,
  136. ignore_errors=True,
  137. ignore_unexpected=True,
  138. )
  139. return response
  140. class DoHNameserver(Nameserver):
  141. def __init__(
  142. self,
  143. url: str,
  144. bootstrap_address: Optional[str] = None,
  145. verify: Union[bool, str] = True,
  146. want_get: bool = False,
  147. http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
  148. ):
  149. super().__init__()
  150. self.url = url
  151. self.bootstrap_address = bootstrap_address
  152. self.verify = verify
  153. self.want_get = want_get
  154. self.http_version = http_version
  155. def kind(self):
  156. return "DoH"
  157. def is_always_max_size(self) -> bool:
  158. return True
  159. def __str__(self):
  160. return self.url
  161. def answer_nameserver(self) -> str:
  162. return self.url
  163. def answer_port(self) -> int:
  164. port = urlparse(self.url).port
  165. if port is None:
  166. port = 443
  167. return port
  168. def query(
  169. self,
  170. request: dns.message.QueryMessage,
  171. timeout: float,
  172. source: Optional[str],
  173. source_port: int,
  174. max_size: bool = False,
  175. one_rr_per_rrset: bool = False,
  176. ignore_trailing: bool = False,
  177. ) -> dns.message.Message:
  178. return dns.query.https(
  179. request,
  180. self.url,
  181. timeout=timeout,
  182. source=source,
  183. source_port=source_port,
  184. bootstrap_address=self.bootstrap_address,
  185. one_rr_per_rrset=one_rr_per_rrset,
  186. ignore_trailing=ignore_trailing,
  187. verify=self.verify,
  188. post=(not self.want_get),
  189. http_version=self.http_version,
  190. )
  191. async def async_query(
  192. self,
  193. request: dns.message.QueryMessage,
  194. timeout: float,
  195. source: Optional[str],
  196. source_port: int,
  197. max_size: bool,
  198. backend: dns.asyncbackend.Backend,
  199. one_rr_per_rrset: bool = False,
  200. ignore_trailing: bool = False,
  201. ) -> dns.message.Message:
  202. return await dns.asyncquery.https(
  203. request,
  204. self.url,
  205. timeout=timeout,
  206. source=source,
  207. source_port=source_port,
  208. bootstrap_address=self.bootstrap_address,
  209. one_rr_per_rrset=one_rr_per_rrset,
  210. ignore_trailing=ignore_trailing,
  211. verify=self.verify,
  212. post=(not self.want_get),
  213. http_version=self.http_version,
  214. )
  215. class DoTNameserver(AddressAndPortNameserver):
  216. def __init__(
  217. self,
  218. address: str,
  219. port: int = 853,
  220. hostname: Optional[str] = None,
  221. verify: Union[bool, str] = True,
  222. ):
  223. super().__init__(address, port)
  224. self.hostname = hostname
  225. self.verify = verify
  226. def kind(self):
  227. return "DoT"
  228. def query(
  229. self,
  230. request: dns.message.QueryMessage,
  231. timeout: float,
  232. source: Optional[str],
  233. source_port: int,
  234. max_size: bool = False,
  235. one_rr_per_rrset: bool = False,
  236. ignore_trailing: bool = False,
  237. ) -> dns.message.Message:
  238. return dns.query.tls(
  239. request,
  240. self.address,
  241. port=self.port,
  242. timeout=timeout,
  243. one_rr_per_rrset=one_rr_per_rrset,
  244. ignore_trailing=ignore_trailing,
  245. server_hostname=self.hostname,
  246. verify=self.verify,
  247. )
  248. async def async_query(
  249. self,
  250. request: dns.message.QueryMessage,
  251. timeout: float,
  252. source: Optional[str],
  253. source_port: int,
  254. max_size: bool,
  255. backend: dns.asyncbackend.Backend,
  256. one_rr_per_rrset: bool = False,
  257. ignore_trailing: bool = False,
  258. ) -> dns.message.Message:
  259. return await dns.asyncquery.tls(
  260. request,
  261. self.address,
  262. port=self.port,
  263. timeout=timeout,
  264. one_rr_per_rrset=one_rr_per_rrset,
  265. ignore_trailing=ignore_trailing,
  266. server_hostname=self.hostname,
  267. verify=self.verify,
  268. )
  269. class DoQNameserver(AddressAndPortNameserver):
  270. def __init__(
  271. self,
  272. address: str,
  273. port: int = 853,
  274. verify: Union[bool, str] = True,
  275. server_hostname: Optional[str] = None,
  276. ):
  277. super().__init__(address, port)
  278. self.verify = verify
  279. self.server_hostname = server_hostname
  280. def kind(self):
  281. return "DoQ"
  282. def query(
  283. self,
  284. request: dns.message.QueryMessage,
  285. timeout: float,
  286. source: Optional[str],
  287. source_port: int,
  288. max_size: bool = False,
  289. one_rr_per_rrset: bool = False,
  290. ignore_trailing: bool = False,
  291. ) -> dns.message.Message:
  292. return dns.query.quic(
  293. request,
  294. self.address,
  295. port=self.port,
  296. timeout=timeout,
  297. one_rr_per_rrset=one_rr_per_rrset,
  298. ignore_trailing=ignore_trailing,
  299. verify=self.verify,
  300. server_hostname=self.server_hostname,
  301. )
  302. async def async_query(
  303. self,
  304. request: dns.message.QueryMessage,
  305. timeout: float,
  306. source: Optional[str],
  307. source_port: int,
  308. max_size: bool,
  309. backend: dns.asyncbackend.Backend,
  310. one_rr_per_rrset: bool = False,
  311. ignore_trailing: bool = False,
  312. ) -> dns.message.Message:
  313. return await dns.asyncquery.quic(
  314. request,
  315. self.address,
  316. port=self.port,
  317. timeout=timeout,
  318. one_rr_per_rrset=one_rr_per_rrset,
  319. ignore_trailing=ignore_trailing,
  320. verify=self.verify,
  321. server_hostname=self.server_hostname,
  322. )