rsa.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import annotations
  5. import abc
  6. import random
  7. import typing
  8. from math import gcd
  9. from cryptography.hazmat.bindings._rust import openssl as rust_openssl
  10. from cryptography.hazmat.primitives import _serialization, hashes
  11. from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
  12. from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
  13. class RSAPrivateKey(metaclass=abc.ABCMeta):
  14. @abc.abstractmethod
  15. def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
  16. """
  17. Decrypts the provided ciphertext.
  18. """
  19. @property
  20. @abc.abstractmethod
  21. def key_size(self) -> int:
  22. """
  23. The bit length of the public modulus.
  24. """
  25. @abc.abstractmethod
  26. def public_key(self) -> RSAPublicKey:
  27. """
  28. The RSAPublicKey associated with this private key.
  29. """
  30. @abc.abstractmethod
  31. def sign(
  32. self,
  33. data: bytes,
  34. padding: AsymmetricPadding,
  35. algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
  36. ) -> bytes:
  37. """
  38. Signs the data.
  39. """
  40. @abc.abstractmethod
  41. def private_numbers(self) -> RSAPrivateNumbers:
  42. """
  43. Returns an RSAPrivateNumbers.
  44. """
  45. @abc.abstractmethod
  46. def private_bytes(
  47. self,
  48. encoding: _serialization.Encoding,
  49. format: _serialization.PrivateFormat,
  50. encryption_algorithm: _serialization.KeySerializationEncryption,
  51. ) -> bytes:
  52. """
  53. Returns the key serialized as bytes.
  54. """
  55. RSAPrivateKeyWithSerialization = RSAPrivateKey
  56. RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey)
  57. class RSAPublicKey(metaclass=abc.ABCMeta):
  58. @abc.abstractmethod
  59. def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
  60. """
  61. Encrypts the given plaintext.
  62. """
  63. @property
  64. @abc.abstractmethod
  65. def key_size(self) -> int:
  66. """
  67. The bit length of the public modulus.
  68. """
  69. @abc.abstractmethod
  70. def public_numbers(self) -> RSAPublicNumbers:
  71. """
  72. Returns an RSAPublicNumbers
  73. """
  74. @abc.abstractmethod
  75. def public_bytes(
  76. self,
  77. encoding: _serialization.Encoding,
  78. format: _serialization.PublicFormat,
  79. ) -> bytes:
  80. """
  81. Returns the key serialized as bytes.
  82. """
  83. @abc.abstractmethod
  84. def verify(
  85. self,
  86. signature: bytes,
  87. data: bytes,
  88. padding: AsymmetricPadding,
  89. algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
  90. ) -> None:
  91. """
  92. Verifies the signature of the data.
  93. """
  94. @abc.abstractmethod
  95. def recover_data_from_signature(
  96. self,
  97. signature: bytes,
  98. padding: AsymmetricPadding,
  99. algorithm: hashes.HashAlgorithm | None,
  100. ) -> bytes:
  101. """
  102. Recovers the original data from the signature.
  103. """
  104. @abc.abstractmethod
  105. def __eq__(self, other: object) -> bool:
  106. """
  107. Checks equality.
  108. """
  109. RSAPublicKeyWithSerialization = RSAPublicKey
  110. RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey)
  111. RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers
  112. RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers
  113. def generate_private_key(
  114. public_exponent: int,
  115. key_size: int,
  116. backend: typing.Any = None,
  117. ) -> RSAPrivateKey:
  118. _verify_rsa_parameters(public_exponent, key_size)
  119. return rust_openssl.rsa.generate_private_key(public_exponent, key_size)
  120. def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
  121. if public_exponent not in (3, 65537):
  122. raise ValueError(
  123. "public_exponent must be either 3 (for legacy compatibility) or "
  124. "65537. Almost everyone should choose 65537 here!"
  125. )
  126. if key_size < 1024:
  127. raise ValueError("key_size must be at least 1024-bits.")
  128. def _modinv(e: int, m: int) -> int:
  129. """
  130. Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
  131. """
  132. x1, x2 = 1, 0
  133. a, b = e, m
  134. while b > 0:
  135. q, r = divmod(a, b)
  136. xn = x1 - q * x2
  137. a, b, x1, x2 = b, r, x2, xn
  138. return x1 % m
  139. def rsa_crt_iqmp(p: int, q: int) -> int:
  140. """
  141. Compute the CRT (q ** -1) % p value from RSA primes p and q.
  142. """
  143. return _modinv(q, p)
  144. def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
  145. """
  146. Compute the CRT private_exponent % (p - 1) value from the RSA
  147. private_exponent (d) and p.
  148. """
  149. return private_exponent % (p - 1)
  150. def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
  151. """
  152. Compute the CRT private_exponent % (q - 1) value from the RSA
  153. private_exponent (d) and q.
  154. """
  155. return private_exponent % (q - 1)
  156. def rsa_recover_private_exponent(e: int, p: int, q: int) -> int:
  157. """
  158. Compute the RSA private_exponent (d) given the public exponent (e)
  159. and the RSA primes p and q.
  160. This uses the Carmichael totient function to generate the
  161. smallest possible working value of the private exponent.
  162. """
  163. # This lambda_n is the Carmichael totient function.
  164. # The original RSA paper uses the Euler totient function
  165. # here: phi_n = (p - 1) * (q - 1)
  166. # Either version of the private exponent will work, but the
  167. # one generated by the older formulation may be larger
  168. # than necessary. (lambda_n always divides phi_n)
  169. #
  170. # TODO: Replace with lcm(p - 1, q - 1) once the minimum
  171. # supported Python version is >= 3.9.
  172. lambda_n = (p - 1) * (q - 1) // gcd(p - 1, q - 1)
  173. return _modinv(e, lambda_n)
  174. # Controls the number of iterations rsa_recover_prime_factors will perform
  175. # to obtain the prime factors.
  176. _MAX_RECOVERY_ATTEMPTS = 500
  177. def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]:
  178. """
  179. Compute factors p and q from the private exponent d. We assume that n has
  180. no more than two factors. This function is adapted from code in PyCrypto.
  181. """
  182. # reject invalid values early
  183. if 17 != pow(17, e * d, n):
  184. raise ValueError("n, d, e don't match")
  185. # See 8.2.2(i) in Handbook of Applied Cryptography.
  186. ktot = d * e - 1
  187. # The quantity d*e-1 is a multiple of phi(n), even,
  188. # and can be represented as t*2^s.
  189. t = ktot
  190. while t % 2 == 0:
  191. t = t // 2
  192. # Cycle through all multiplicative inverses in Zn.
  193. # The algorithm is non-deterministic, but there is a 50% chance
  194. # any candidate a leads to successful factoring.
  195. # See "Digitalized Signatures and Public Key Functions as Intractable
  196. # as Factorization", M. Rabin, 1979
  197. spotted = False
  198. tries = 0
  199. while not spotted and tries < _MAX_RECOVERY_ATTEMPTS:
  200. a = random.randint(2, n - 1)
  201. tries += 1
  202. k = t
  203. # Cycle through all values a^{t*2^i}=a^k
  204. while k < ktot:
  205. cand = pow(a, k, n)
  206. # Check if a^k is a non-trivial root of unity (mod n)
  207. if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
  208. # We have found a number such that (cand-1)(cand+1)=0 (mod n).
  209. # Either of the terms divides n.
  210. p = gcd(cand + 1, n)
  211. spotted = True
  212. break
  213. k *= 2
  214. if not spotted:
  215. raise ValueError("Unable to compute factors p and q from exponent d.")
  216. # Found !
  217. q, r = divmod(n, p)
  218. assert r == 0
  219. p, q = sorted((p, q), reverse=True)
  220. return (p, q)