# rsa.py (3.5 KB)
  1. import math
  2. import struct
  3. from cryptography.hazmat.backends import default_backend
  4. from cryptography.hazmat.primitives import hashes
  5. from cryptography.hazmat.primitives.asymmetric import padding, rsa
  6. from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
  7. from dns.dnssectypes import Algorithm
  8. from dns.rdtypes.ANY.DNSKEY import DNSKEY
  9. class PublicRSA(CryptographyPublicKey):
  10. key: rsa.RSAPublicKey
  11. key_cls = rsa.RSAPublicKey
  12. algorithm: Algorithm
  13. chosen_hash: hashes.HashAlgorithm
  14. def verify(self, signature: bytes, data: bytes) -> None:
  15. self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
  16. def encode_key_bytes(self) -> bytes:
  17. """Encode a public key per RFC 3110, section 2."""
  18. pn = self.key.public_numbers()
  19. _exp_len = math.ceil(int.bit_length(pn.e) / 8)
  20. exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
  21. if _exp_len > 255:
  22. exp_header = b"\0" + struct.pack("!H", _exp_len)
  23. else:
  24. exp_header = struct.pack("!B", _exp_len)
  25. if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
  26. raise ValueError("unsupported RSA key length")
  27. return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
  28. @classmethod
  29. def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
  30. cls._ensure_algorithm_key_combination(key)
  31. keyptr = key.key
  32. (bytes_,) = struct.unpack("!B", keyptr[0:1])
  33. keyptr = keyptr[1:]
  34. if bytes_ == 0:
  35. (bytes_,) = struct.unpack("!H", keyptr[0:2])
  36. keyptr = keyptr[2:]
  37. rsa_e = keyptr[0:bytes_]
  38. rsa_n = keyptr[bytes_:]
  39. return cls(
  40. key=rsa.RSAPublicNumbers(
  41. int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
  42. ).public_key(default_backend())
  43. )
  44. class PrivateRSA(CryptographyPrivateKey):
  45. key: rsa.RSAPrivateKey
  46. key_cls = rsa.RSAPrivateKey
  47. public_cls = PublicRSA
  48. default_public_exponent = 65537
  49. def sign(
  50. self,
  51. data: bytes,
  52. verify: bool = False,
  53. deterministic: bool = True,
  54. ) -> bytes:
  55. """Sign using a private key per RFC 3110, section 3."""
  56. signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
  57. if verify:
  58. self.public_key().verify(signature, data)
  59. return signature
  60. @classmethod
  61. def generate(cls, key_size: int) -> "PrivateRSA":
  62. return cls(
  63. key=rsa.generate_private_key(
  64. public_exponent=cls.default_public_exponent,
  65. key_size=key_size,
  66. backend=default_backend(),
  67. )
  68. )
  69. class PublicRSAMD5(PublicRSA):
  70. algorithm = Algorithm.RSAMD5
  71. chosen_hash = hashes.MD5()
  72. class PrivateRSAMD5(PrivateRSA):
  73. public_cls = PublicRSAMD5
  74. class PublicRSASHA1(PublicRSA):
  75. algorithm = Algorithm.RSASHA1
  76. chosen_hash = hashes.SHA1()
  77. class PrivateRSASHA1(PrivateRSA):
  78. public_cls = PublicRSASHA1
  79. class PublicRSASHA1NSEC3SHA1(PublicRSA):
  80. algorithm = Algorithm.RSASHA1NSEC3SHA1
  81. chosen_hash = hashes.SHA1()
  82. class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
  83. public_cls = PublicRSASHA1NSEC3SHA1
  84. class PublicRSASHA256(PublicRSA):
  85. algorithm = Algorithm.RSASHA256
  86. chosen_hash = hashes.SHA256()
  87. class PrivateRSASHA256(PrivateRSA):
  88. public_cls = PublicRSASHA256
  89. class PublicRSASHA512(PublicRSA):
  90. algorithm = Algorithm.RSASHA512
  91. chosen_hash = hashes.SHA512()
  92. class PrivateRSASHA512(PrivateRSA):
  93. public_cls = PublicRSASHA512