wire.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import contextlib
  3. import struct
  4. from typing import Iterator, Optional, Tuple
  5. import dns.exception
  6. import dns.name
  7. class Parser:
  8. def __init__(self, wire: bytes, current: int = 0):
  9. self.wire = wire
  10. self.current = 0
  11. self.end = len(self.wire)
  12. if current:
  13. self.seek(current)
  14. self.furthest = current
  15. def remaining(self) -> int:
  16. return self.end - self.current
  17. def get_bytes(self, size: int) -> bytes:
  18. assert size >= 0
  19. if size > self.remaining():
  20. raise dns.exception.FormError
  21. output = self.wire[self.current : self.current + size]
  22. self.current += size
  23. self.furthest = max(self.furthest, self.current)
  24. return output
  25. def get_counted_bytes(self, length_size: int = 1) -> bytes:
  26. length = int.from_bytes(self.get_bytes(length_size), "big")
  27. return self.get_bytes(length)
  28. def get_remaining(self) -> bytes:
  29. return self.get_bytes(self.remaining())
  30. def get_uint8(self) -> int:
  31. return struct.unpack("!B", self.get_bytes(1))[0]
  32. def get_uint16(self) -> int:
  33. return struct.unpack("!H", self.get_bytes(2))[0]
  34. def get_uint32(self) -> int:
  35. return struct.unpack("!I", self.get_bytes(4))[0]
  36. def get_uint48(self) -> int:
  37. return int.from_bytes(self.get_bytes(6), "big")
  38. def get_struct(self, format: str) -> Tuple:
  39. return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
  40. def get_name(self, origin: Optional["dns.name.Name"] = None) -> "dns.name.Name":
  41. name = dns.name.from_wire_parser(self)
  42. if origin:
  43. name = name.relativize(origin)
  44. return name
  45. def seek(self, where: int) -> None:
  46. # Note that seeking to the end is OK! (If you try to read
  47. # after such a seek, you'll get an exception as expected.)
  48. if where < 0 or where > self.end:
  49. raise dns.exception.FormError
  50. self.current = where
  51. @contextlib.contextmanager
  52. def restrict_to(self, size: int) -> Iterator:
  53. assert size >= 0
  54. if size > self.remaining():
  55. raise dns.exception.FormError
  56. saved_end = self.end
  57. try:
  58. self.end = self.current + size
  59. yield
  60. # We make this check here and not in the finally as we
  61. # don't want to raise if we're already raising for some
  62. # other reason.
  63. if self.current != self.end:
  64. raise dns.exception.FormError
  65. finally:
  66. self.end = saved_end
  67. @contextlib.contextmanager
  68. def restore_furthest(self) -> Iterator:
  69. try:
  70. yield None
  71. finally:
  72. self.current = self.furthest