transaction.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import dns.exception
import dns.name
import dns.node
import dns.rdata
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset
import dns.serial
import dns.ttl
  13. class TransactionManager:
  14. def reader(self) -> "Transaction":
  15. """Begin a read-only transaction."""
  16. raise NotImplementedError # pragma: no cover
  17. def writer(self, replacement: bool = False) -> "Transaction":
  18. """Begin a writable transaction.
  19. *replacement*, a ``bool``. If `True`, the content of the
  20. transaction completely replaces any prior content. If False,
  21. the default, then the content of the transaction updates the
  22. existing content.
  23. """
  24. raise NotImplementedError # pragma: no cover
  25. def origin_information(
  26. self,
  27. ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
  28. """Returns a tuple
  29. (absolute_origin, relativize, effective_origin)
  30. giving the absolute name of the default origin for any
  31. relative domain names, the "effective origin", and whether
  32. names should be relativized. The "effective origin" is the
  33. absolute origin if relativize is False, and the empty name if
  34. relativize is true. (The effective origin is provided even
  35. though it can be computed from the absolute_origin and
  36. relativize setting because it avoids a lot of code
  37. duplication.)
  38. If the returned names are `None`, then no origin information is
  39. available.
  40. This information is used by code working with transactions to
  41. allow it to coordinate relativization. The transaction code
  42. itself takes what it gets (i.e. does not change name
  43. relativity).
  44. """
  45. raise NotImplementedError # pragma: no cover
  46. def get_class(self) -> dns.rdataclass.RdataClass:
  47. """The class of the transaction manager."""
  48. raise NotImplementedError # pragma: no cover
  49. def from_wire_origin(self) -> Optional[dns.name.Name]:
  50. """Origin to use in from_wire() calls."""
  51. (absolute_origin, relativize, _) = self.origin_information()
  52. if relativize:
  53. return absolute_origin
  54. else:
  55. return None
# Raised by Transaction._delete() on behalf of delete_exact() when the
# rdataset, some of the rdatas, or the name to be deleted is not present.
# NOTE(review): dns.exception.DNSException is understood to use the class
# docstring as the default exception message -- confirm before rewording it.
class DeleteNotExact(dns.exception.DNSException):
    """Existing data did not match data specified by an exact delete."""
# Raised by Transaction._check_read_only(), which the mutating high-level
# methods (add, replace, delete, delete_exact) call before doing any work.
# NOTE(review): dns.exception.DNSException is understood to use the class
# docstring as the default exception message -- confirm before rewording it.
class ReadOnly(dns.exception.DNSException):
    """Tried to write to a read-only transaction."""
# Raised by Transaction._check_ended() once the transaction has been
# committed or rolled back (i.e. after _end() has set the _ended flag).
# NOTE(review): dns.exception.DNSException is understood to use the class
# docstring as the default exception message -- confirm before rewording it.
class AlreadyEnded(dns.exception.DNSException):
    """Tried to use an already-ended transaction."""
  62. def _ensure_immutable_rdataset(rdataset):
  63. if rdataset is None or isinstance(rdataset, dns.rdataset.ImmutableRdataset):
  64. return rdataset
  65. return dns.rdataset.ImmutableRdataset(rdataset)
  66. def _ensure_immutable_node(node):
  67. if node is None or node.is_immutable():
  68. return node
  69. return dns.node.ImmutableNode(node)
# Signatures of the user-supplied check callbacks registered on a Transaction.
# Each callback returns None; Transaction's _checked_*() helpers call them
# before performing the corresponding low-level operation.

# Called as check(transaction, name, rdataset) before an rdataset is stored.
CheckPutRdatasetType = Callable[
    ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None
]

# Called as check(transaction, name, rdtype, covers) before an rdataset is
# deleted.
CheckDeleteRdatasetType = Callable[
    ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType],
    None,
]

# Called as check(transaction, name) before a whole name is deleted.
CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None]
  78. class Transaction:
  79. def __init__(
  80. self,
  81. manager: TransactionManager,
  82. replacement: bool = False,
  83. read_only: bool = False,
  84. ):
  85. self.manager = manager
  86. self.replacement = replacement
  87. self.read_only = read_only
  88. self._ended = False
  89. self._check_put_rdataset: List[CheckPutRdatasetType] = []
  90. self._check_delete_rdataset: List[CheckDeleteRdatasetType] = []
  91. self._check_delete_name: List[CheckDeleteNameType] = []
  92. #
  93. # This is the high level API
  94. #
  95. # Note that we currently use non-immutable types in the return type signature to
  96. # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be
  97. # unhappy if we return an ImmutableRdataset.
  98. def get(
  99. self,
  100. name: Optional[Union[dns.name.Name, str]],
  101. rdtype: Union[dns.rdatatype.RdataType, str],
  102. covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
  103. ) -> dns.rdataset.Rdataset:
  104. """Return the rdataset associated with *name*, *rdtype*, and *covers*,
  105. or `None` if not found.
  106. Note that the returned rdataset is immutable.
  107. """
  108. self._check_ended()
  109. if isinstance(name, str):
  110. name = dns.name.from_text(name, None)
  111. rdtype = dns.rdatatype.RdataType.make(rdtype)
  112. covers = dns.rdatatype.RdataType.make(covers)
  113. rdataset = self._get_rdataset(name, rdtype, covers)
  114. return _ensure_immutable_rdataset(rdataset)
  115. def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
  116. """Return the node at *name*, if any.
  117. Returns an immutable node or ``None``.
  118. """
  119. return _ensure_immutable_node(self._get_node(name))
  120. def _check_read_only(self) -> None:
  121. if self.read_only:
  122. raise ReadOnly
  123. def add(self, *args: Any) -> None:
  124. """Add records.
  125. The arguments may be:
  126. - rrset
  127. - name, rdataset...
  128. - name, ttl, rdata...
  129. """
  130. self._check_ended()
  131. self._check_read_only()
  132. self._add(False, args)
  133. def replace(self, *args: Any) -> None:
  134. """Replace the existing rdataset at the name with the specified
  135. rdataset, or add the specified rdataset if there was no existing
  136. rdataset.
  137. The arguments may be:
  138. - rrset
  139. - name, rdataset...
  140. - name, ttl, rdata...
  141. Note that if you want to replace the entire node, you should do
  142. a delete of the name followed by one or more calls to add() or
  143. replace().
  144. """
  145. self._check_ended()
  146. self._check_read_only()
  147. self._add(True, args)
  148. def delete(self, *args: Any) -> None:
  149. """Delete records.
  150. It is not an error if some of the records are not in the existing
  151. set.
  152. The arguments may be:
  153. - rrset
  154. - name
  155. - name, rdatatype, [covers]
  156. - name, rdataset...
  157. - name, rdata...
  158. """
  159. self._check_ended()
  160. self._check_read_only()
  161. self._delete(False, args)
  162. def delete_exact(self, *args: Any) -> None:
  163. """Delete records.
  164. The arguments may be:
  165. - rrset
  166. - name
  167. - name, rdatatype, [covers]
  168. - name, rdataset...
  169. - name, rdata...
  170. Raises dns.transaction.DeleteNotExact if some of the records
  171. are not in the existing set.
  172. """
  173. self._check_ended()
  174. self._check_read_only()
  175. self._delete(True, args)
  176. def name_exists(self, name: Union[dns.name.Name, str]) -> bool:
  177. """Does the specified name exist?"""
  178. self._check_ended()
  179. if isinstance(name, str):
  180. name = dns.name.from_text(name, None)
  181. return self._name_exists(name)
  182. def update_serial(
  183. self,
  184. value: int = 1,
  185. relative: bool = True,
  186. name: dns.name.Name = dns.name.empty,
  187. ) -> None:
  188. """Update the serial number.
  189. *value*, an `int`, is an increment if *relative* is `True`, or the
  190. actual value to set if *relative* is `False`.
  191. Raises `KeyError` if there is no SOA rdataset at *name*.
  192. Raises `ValueError` if *value* is negative or if the increment is
  193. so large that it would cause the new serial to be less than the
  194. prior value.
  195. """
  196. self._check_ended()
  197. if value < 0:
  198. raise ValueError("negative update_serial() value")
  199. if isinstance(name, str):
  200. name = dns.name.from_text(name, None)
  201. rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE)
  202. if rdataset is None or len(rdataset) == 0:
  203. raise KeyError
  204. if relative:
  205. serial = dns.serial.Serial(rdataset[0].serial) + value
  206. else:
  207. serial = dns.serial.Serial(value)
  208. serial = serial.value # convert back to int
  209. if serial == 0:
  210. serial = 1
  211. rdata = rdataset[0].replace(serial=serial)
  212. new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
  213. self.replace(name, new_rdataset)
  214. def __iter__(self):
  215. self._check_ended()
  216. return self._iterate_rdatasets()
  217. def changed(self) -> bool:
  218. """Has this transaction changed anything?
  219. For read-only transactions, the result is always `False`.
  220. For writable transactions, the result is `True` if at some time
  221. during the life of the transaction, the content was changed.
  222. """
  223. self._check_ended()
  224. return self._changed()
  225. def commit(self) -> None:
  226. """Commit the transaction.
  227. Normally transactions are used as context managers and commit
  228. or rollback automatically, but it may be done explicitly if needed.
  229. A ``dns.transaction.Ended`` exception will be raised if you try
  230. to use a transaction after it has been committed or rolled back.
  231. Raises an exception if the commit fails (in which case the transaction
  232. is also rolled back.
  233. """
  234. self._end(True)
  235. def rollback(self) -> None:
  236. """Rollback the transaction.
  237. Normally transactions are used as context managers and commit
  238. or rollback automatically, but it may be done explicitly if needed.
  239. A ``dns.transaction.AlreadyEnded`` exception will be raised if you try
  240. to use a transaction after it has been committed or rolled back.
  241. Rollback cannot otherwise fail.
  242. """
  243. self._end(False)
  244. def check_put_rdataset(self, check: CheckPutRdatasetType) -> None:
  245. """Call *check* before putting (storing) an rdataset.
  246. The function is called with the transaction, the name, and the rdataset.
  247. The check function may safely make non-mutating transaction method
  248. calls, but behavior is undefined if mutating transaction methods are
  249. called. The check function should raise an exception if it objects to
  250. the put, and otherwise should return ``None``.
  251. """
  252. self._check_put_rdataset.append(check)
  253. def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None:
  254. """Call *check* before deleting an rdataset.
  255. The function is called with the transaction, the name, the rdatatype,
  256. and the covered rdatatype.
  257. The check function may safely make non-mutating transaction method
  258. calls, but behavior is undefined if mutating transaction methods are
  259. called. The check function should raise an exception if it objects to
  260. the put, and otherwise should return ``None``.
  261. """
  262. self._check_delete_rdataset.append(check)
  263. def check_delete_name(self, check: CheckDeleteNameType) -> None:
  264. """Call *check* before putting (storing) an rdataset.
  265. The function is called with the transaction and the name.
  266. The check function may safely make non-mutating transaction method
  267. calls, but behavior is undefined if mutating transaction methods are
  268. called. The check function should raise an exception if it objects to
  269. the put, and otherwise should return ``None``.
  270. """
  271. self._check_delete_name.append(check)
  272. def iterate_rdatasets(
  273. self,
  274. ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
  275. """Iterate all the rdatasets in the transaction, returning
  276. (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
  277. Note that as is usual with python iterators, adding or removing items
  278. while iterating will invalidate the iterator and may raise `RuntimeError`
  279. or fail to iterate over all entries."""
  280. self._check_ended()
  281. return self._iterate_rdatasets()
  282. def iterate_names(self) -> Iterator[dns.name.Name]:
  283. """Iterate all the names in the transaction.
  284. Note that as is usual with python iterators, adding or removing names
  285. while iterating will invalidate the iterator and may raise `RuntimeError`
  286. or fail to iterate over all entries."""
  287. self._check_ended()
  288. return self._iterate_names()
  289. #
  290. # Helper methods
  291. #
  292. def _raise_if_not_empty(self, method, args):
  293. if len(args) != 0:
  294. raise TypeError(f"extra parameters to {method}")
  295. def _rdataset_from_args(self, method, deleting, args):
  296. try:
  297. arg = args.popleft()
  298. if isinstance(arg, dns.rrset.RRset):
  299. rdataset = arg.to_rdataset()
  300. elif isinstance(arg, dns.rdataset.Rdataset):
  301. rdataset = arg
  302. else:
  303. if deleting:
  304. ttl = 0
  305. else:
  306. if isinstance(arg, int):
  307. ttl = arg
  308. if ttl > dns.ttl.MAX_TTL:
  309. raise ValueError(f"{method}: TTL value too big")
  310. else:
  311. raise TypeError(f"{method}: expected a TTL")
  312. arg = args.popleft()
  313. if isinstance(arg, dns.rdata.Rdata):
  314. rdataset = dns.rdataset.from_rdata(ttl, arg)
  315. else:
  316. raise TypeError(f"{method}: expected an Rdata")
  317. return rdataset
  318. except IndexError:
  319. if deleting:
  320. return None
  321. else:
  322. # reraise
  323. raise TypeError(f"{method}: expected more arguments")
  324. def _add(self, replace, args):
  325. try:
  326. args = collections.deque(args)
  327. if replace:
  328. method = "replace()"
  329. else:
  330. method = "add()"
  331. arg = args.popleft()
  332. if isinstance(arg, str):
  333. arg = dns.name.from_text(arg, None)
  334. if isinstance(arg, dns.name.Name):
  335. name = arg
  336. rdataset = self._rdataset_from_args(method, False, args)
  337. elif isinstance(arg, dns.rrset.RRset):
  338. rrset = arg
  339. name = rrset.name
  340. # rrsets are also rdatasets, but they don't print the
  341. # same and can't be stored in nodes, so convert.
  342. rdataset = rrset.to_rdataset()
  343. else:
  344. raise TypeError(
  345. f"{method} requires a name or RRset as the first argument"
  346. )
  347. if rdataset.rdclass != self.manager.get_class():
  348. raise ValueError(f"{method} has objects of wrong RdataClass")
  349. if rdataset.rdtype == dns.rdatatype.SOA:
  350. (_, _, origin) = self._origin_information()
  351. if name != origin:
  352. raise ValueError(f"{method} has non-origin SOA")
  353. self._raise_if_not_empty(method, args)
  354. if not replace:
  355. existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
  356. if existing is not None:
  357. if isinstance(existing, dns.rdataset.ImmutableRdataset):
  358. trds = dns.rdataset.Rdataset(
  359. existing.rdclass, existing.rdtype, existing.covers
  360. )
  361. trds.update(existing)
  362. existing = trds
  363. rdataset = existing.union(rdataset)
  364. self._checked_put_rdataset(name, rdataset)
  365. except IndexError:
  366. raise TypeError(f"not enough parameters to {method}")
  367. def _delete(self, exact, args):
  368. try:
  369. args = collections.deque(args)
  370. if exact:
  371. method = "delete_exact()"
  372. else:
  373. method = "delete()"
  374. arg = args.popleft()
  375. if isinstance(arg, str):
  376. arg = dns.name.from_text(arg, None)
  377. if isinstance(arg, dns.name.Name):
  378. name = arg
  379. if len(args) > 0 and (
  380. isinstance(args[0], int) or isinstance(args[0], str)
  381. ):
  382. # deleting by type and (optionally) covers
  383. rdtype = dns.rdatatype.RdataType.make(args.popleft())
  384. if len(args) > 0:
  385. covers = dns.rdatatype.RdataType.make(args.popleft())
  386. else:
  387. covers = dns.rdatatype.NONE
  388. self._raise_if_not_empty(method, args)
  389. existing = self._get_rdataset(name, rdtype, covers)
  390. if existing is None:
  391. if exact:
  392. raise DeleteNotExact(f"{method}: missing rdataset")
  393. else:
  394. self._checked_delete_rdataset(name, rdtype, covers)
  395. return
  396. else:
  397. rdataset = self._rdataset_from_args(method, True, args)
  398. elif isinstance(arg, dns.rrset.RRset):
  399. rdataset = arg # rrsets are also rdatasets
  400. name = rdataset.name
  401. else:
  402. raise TypeError(
  403. f"{method} requires a name or RRset as the first argument"
  404. )
  405. self._raise_if_not_empty(method, args)
  406. if rdataset:
  407. if rdataset.rdclass != self.manager.get_class():
  408. raise ValueError(f"{method} has objects of wrong RdataClass")
  409. existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
  410. if existing is not None:
  411. if exact:
  412. intersection = existing.intersection(rdataset)
  413. if intersection != rdataset:
  414. raise DeleteNotExact(f"{method}: missing rdatas")
  415. rdataset = existing.difference(rdataset)
  416. if len(rdataset) == 0:
  417. self._checked_delete_rdataset(
  418. name, rdataset.rdtype, rdataset.covers
  419. )
  420. else:
  421. self._checked_put_rdataset(name, rdataset)
  422. elif exact:
  423. raise DeleteNotExact(f"{method}: missing rdataset")
  424. else:
  425. if exact and not self._name_exists(name):
  426. raise DeleteNotExact(f"{method}: name not known")
  427. self._checked_delete_name(name)
  428. except IndexError:
  429. raise TypeError(f"not enough parameters to {method}")
  430. def _check_ended(self):
  431. if self._ended:
  432. raise AlreadyEnded
  433. def _end(self, commit):
  434. self._check_ended()
  435. try:
  436. self._end_transaction(commit)
  437. finally:
  438. self._ended = True
  439. def _checked_put_rdataset(self, name, rdataset):
  440. for check in self._check_put_rdataset:
  441. check(self, name, rdataset)
  442. self._put_rdataset(name, rdataset)
  443. def _checked_delete_rdataset(self, name, rdtype, covers):
  444. for check in self._check_delete_rdataset:
  445. check(self, name, rdtype, covers)
  446. self._delete_rdataset(name, rdtype, covers)
  447. def _checked_delete_name(self, name):
  448. for check in self._check_delete_name:
  449. check(self, name)
  450. self._delete_name(name)
  451. #
  452. # Transactions are context managers.
  453. #
  454. def __enter__(self):
  455. return self
  456. def __exit__(self, exc_type, exc_val, exc_tb):
  457. if not self._ended:
  458. if exc_type is None:
  459. self.commit()
  460. else:
  461. self.rollback()
  462. return False
  463. #
  464. # This is the low level API, which must be implemented by subclasses
  465. # of Transaction.
  466. #
  467. def _get_rdataset(self, name, rdtype, covers):
  468. """Return the rdataset associated with *name*, *rdtype*, and *covers*,
  469. or `None` if not found.
  470. """
  471. raise NotImplementedError # pragma: no cover
  472. def _put_rdataset(self, name, rdataset):
  473. """Store the rdataset."""
  474. raise NotImplementedError # pragma: no cover
  475. def _delete_name(self, name):
  476. """Delete all data associated with *name*.
  477. It is not an error if the name does not exist.
  478. """
  479. raise NotImplementedError # pragma: no cover
  480. def _delete_rdataset(self, name, rdtype, covers):
  481. """Delete all data associated with *name*, *rdtype*, and *covers*.
  482. It is not an error if the rdataset does not exist.
  483. """
  484. raise NotImplementedError # pragma: no cover
  485. def _name_exists(self, name):
  486. """Does name exist?
  487. Returns a bool.
  488. """
  489. raise NotImplementedError # pragma: no cover
  490. def _changed(self):
  491. """Has this transaction changed anything?"""
  492. raise NotImplementedError # pragma: no cover
  493. def _end_transaction(self, commit):
  494. """End the transaction.
  495. *commit*, a bool. If ``True``, commit the transaction, otherwise
  496. roll it back.
  497. If committing and the commit fails, then roll back and raise an
  498. exception.
  499. """
  500. raise NotImplementedError # pragma: no cover
  501. def _set_origin(self, origin):
  502. """Set the origin.
  503. This method is called when reading a possibly relativized
  504. source, and an origin setting operation occurs (e.g. $ORIGIN
  505. in a zone file).
  506. """
  507. raise NotImplementedError # pragma: no cover
  508. def _iterate_rdatasets(self):
  509. """Return an iterator that yields (name, rdataset) tuples."""
  510. raise NotImplementedError # pragma: no cover
  511. def _iterate_names(self):
  512. """Return an iterator that yields a name."""
  513. raise NotImplementedError # pragma: no cover
  514. def _get_node(self, name):
  515. """Return the node at *name*, if any.
  516. Returns a node or ``None``.
  517. """
  518. raise NotImplementedError # pragma: no cover
  519. #
  520. # Low-level API with a default implementation, in case a subclass needs
  521. # to override.
  522. #
  523. def _origin_information(self):
  524. # This is only used by _add()
  525. return self.manager.origin_information()