algorithms.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  1. from __future__ import annotations
  2. import hashlib
  3. import hmac
  4. import json
  5. from abc import ABC, abstractmethod
  6. from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
  7. from .exceptions import InvalidKeyError
  8. from .types import HashlibHash, JWKDict
  9. from .utils import (
  10. base64url_decode,
  11. base64url_encode,
  12. der_to_raw_signature,
  13. force_bytes,
  14. from_base64url_uint,
  15. is_pem_format,
  16. is_ssh_key,
  17. raw_to_der_signature,
  18. to_base64url_uint,
  19. )
  20. try:
  21. from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
  22. from cryptography.hazmat.backends import default_backend
  23. from cryptography.hazmat.primitives import hashes
  24. from cryptography.hazmat.primitives.asymmetric import padding
  25. from cryptography.hazmat.primitives.asymmetric.ec import (
  26. ECDSA,
  27. SECP256K1,
  28. SECP256R1,
  29. SECP384R1,
  30. SECP521R1,
  31. EllipticCurve,
  32. EllipticCurvePrivateKey,
  33. EllipticCurvePrivateNumbers,
  34. EllipticCurvePublicKey,
  35. EllipticCurvePublicNumbers,
  36. )
  37. from cryptography.hazmat.primitives.asymmetric.ed448 import (
  38. Ed448PrivateKey,
  39. Ed448PublicKey,
  40. )
  41. from cryptography.hazmat.primitives.asymmetric.ed25519 import (
  42. Ed25519PrivateKey,
  43. Ed25519PublicKey,
  44. )
  45. from cryptography.hazmat.primitives.asymmetric.rsa import (
  46. RSAPrivateKey,
  47. RSAPrivateNumbers,
  48. RSAPublicKey,
  49. RSAPublicNumbers,
  50. rsa_crt_dmp1,
  51. rsa_crt_dmq1,
  52. rsa_crt_iqmp,
  53. rsa_recover_prime_factors,
  54. )
  55. from cryptography.hazmat.primitives.serialization import (
  56. Encoding,
  57. NoEncryption,
  58. PrivateFormat,
  59. PublicFormat,
  60. load_pem_private_key,
  61. load_pem_public_key,
  62. load_ssh_public_key,
  63. )
  64. has_crypto = True
  65. except ModuleNotFoundError:
  66. has_crypto = False
  if TYPE_CHECKING:
      # Key-type aliases used in the method signatures below. They are only
      # seen by static type checkers: TYPE_CHECKING is False at runtime, so
      # none of these names exist on the module when it is imported.
      AllowedRSAKeys = RSAPublicKey | RSAPrivateKey
      AllowedECKeys = EllipticCurvePublicKey | EllipticCurvePrivateKey
      AllowedOKPKeys = (
          Ed25519PublicKey | Ed25519PrivateKey | Ed448PublicKey | Ed448PrivateKey
      )

      # Umbrella aliases spanning every supported asymmetric key family.
      AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
      AllowedPrivateKeys = (
          RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
      )
      AllowedPublicKeys = (
          RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
      )
  # Algorithm names that are only registered when the optional `cryptography`
  # package imported successfully (see get_default_algorithms).
  requires_cryptography = {
      # RSASSA-PKCS1-v1_5 (RSAAlgorithm)
      "RS256", "RS384", "RS512",
      # ECDSA (ECAlgorithm); ES521 and ES512 are both registered names
      "ES256", "ES256K", "ES384", "ES521", "ES512",
      # RSASSA-PSS (RSAPSSAlgorithm)
      "PS256", "PS384", "PS512",
      # EdDSA (OKPAlgorithm)
      "EdDSA",
  }
  def get_default_algorithms() -> dict[str, Algorithm]:
      """
      Return a new ``{name: instance}`` mapping of every algorithm the
      library implements.

      ``"none"`` and the HMAC family are always present. The RSA, ECDSA,
      RSA-PSS and EdDSA entries are added only when ``has_crypto`` is true,
      i.e. when ``cryptography`` could be imported. Every call builds fresh
      algorithm instances.
      """
      algorithms: dict[str, Algorithm] = {
          "none": NoneAlgorithm(),
          "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
          "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
          "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
      }
      if not has_crypto:
          return algorithms

      algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
      algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
      algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
      algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256)
      algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256)
      algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384)
      algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512)
      # Backward compat for #219 fix
      algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512)
      algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
      algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
      algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
      algorithms["EdDSA"] = OKPAlgorithm()
      return algorithms
  class Algorithm(ABC):
      """
      Abstract interface shared by every token signing/verification algorithm.

      Concrete subclasses supply key preparation, signing, verification and
      JWK (de)serialization. Subclasses that define a ``hash_alg`` attribute
      also get :meth:`compute_hash_digest` for free.
      """

      def compute_hash_digest(self, bytestr: bytes) -> bytes:
          """
          Hash ``bytestr`` with this algorithm's ``hash_alg``.

          ``hash_alg`` may be a ``cryptography`` ``HashAlgorithm`` subclass
          (used when ``has_crypto`` is true) or a hashlib-style constructor
          such as ``hashlib.sha256``.

          :raises NotImplementedError: if the algorithm has no ``hash_alg``.
          """
          # getattr rather than self.hash_alg: only some subclasses define it,
          # and this form is understood by mypy.
          hash_alg = getattr(self, "hash_alg", None)
          if hash_alg is None:
              raise NotImplementedError

          uses_cryptography = (
              has_crypto
              and isinstance(hash_alg, type)
              and issubclass(hash_alg, hashes.HashAlgorithm)
          )
          if not uses_cryptography:
              # hashlib-style constructor: call it with the data, then digest.
              return bytes(hash_alg(bytestr).digest())

          hasher = hashes.Hash(hash_alg(), backend=default_backend())
          hasher.update(bytestr)
          return bytes(hasher.finalize())

      @abstractmethod
      def prepare_key(self, key: Any) -> Any:
          """
          Validate ``key`` and convert it into the form that :meth:`sign`
          and :meth:`verify` expect.
          """

      @abstractmethod
      def sign(self, msg: bytes, key: Any) -> bytes:
          """
          Produce a signature over ``msg`` using the prepared ``key``.
          """

      @abstractmethod
      def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
          """
          Report whether ``sig`` is a valid signature of ``msg`` under ``key``.
          """

      @overload
      @staticmethod
      @abstractmethod
      def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

      @overload
      @staticmethod
      @abstractmethod
      def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

      @staticmethod
      @abstractmethod
      def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
          """
          Serialize ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
          otherwise a JSON string.
          """

      @staticmethod
      @abstractmethod
      def from_jwk(jwk: str | JWKDict) -> Any:
          """
          Rebuild a key object from a JWK given as a JSON string or a dict.
          """
  class NoneAlgorithm(Algorithm):
      """
      The ``"none"`` algorithm: signs with an empty signature and never
      reports a signature as valid.

      Intended for cases where no signing or verification is wanted.
      """

      def prepare_key(self, key: str | None) -> None:
          # The empty string is accepted as a synonym for "no key".
          if key == "" or key is None:
              return None
          raise InvalidKeyError('When alg = "none", key value must be None.')

      def sign(self, msg: bytes, key: None) -> bytes:
          # No signature material is produced for alg = "none".
          return b""

      def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
          # Deliberately never treats any signature as valid.
          return False

      @staticmethod
      def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
          # There is no key to serialize for alg = "none".
          raise NotImplementedError()

      @staticmethod
      def from_jwk(jwk: str | JWKDict) -> NoReturn:
          raise NotImplementedError()
  class HMACAlgorithm(Algorithm):
      """
      Performs signing and verification operations using HMAC
      and the specified hash function.
      """

      SHA256: ClassVar[HashlibHash] = hashlib.sha256
      SHA384: ClassVar[HashlibHash] = hashlib.sha384
      SHA512: ClassVar[HashlibHash] = hashlib.sha512

      def __init__(self, hash_alg: HashlibHash) -> None:
          # hashlib constructor (e.g. hashlib.sha256); passed to hmac.new as
          # the digest in sign().
          self.hash_alg = hash_alg

      def prepare_key(self, key: str | bytes) -> bytes:
          """
          Return the HMAC secret as bytes.

          :raises InvalidKeyError: if ``is_pem_format`` or ``is_ssh_key``
              flags the key, i.e. it appears to be an asymmetric key or
              certificate, which must not be reused as a shared secret.
          """
          key_bytes = force_bytes(key)

          if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
              raise InvalidKeyError(
                  "The specified key is an asymmetric key or x509 certificate and"
                  " should not be used as an HMAC secret."
              )

          return key_bytes

      @overload
      @staticmethod
      def to_jwk(
          key_obj: str | bytes, as_dict: Literal[True]
      ) -> JWKDict: ...  # pragma: no cover

      @overload
      @staticmethod
      def to_jwk(
          key_obj: str | bytes, as_dict: Literal[False] = False
      ) -> str: ...  # pragma: no cover

      @staticmethod
      def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
          """
          Serialize the secret as an ``"oct"`` JWK.

          :return: the JWK as a dict if ``as_dict`` is true, else a JSON string.
          """
          jwk = {
              "k": base64url_encode(force_bytes(key_obj)).decode(),
              "kty": "oct",
          }

          if as_dict:
              return jwk
          else:
              return json.dumps(jwk)

      @staticmethod
      def from_jwk(jwk: str | JWKDict) -> bytes:
          """
          Decode the secret from an ``"oct"`` JWK given as JSON text or a dict.

          :raises InvalidKeyError: if the input is not a JSON object, its
              ``kty`` is not ``"oct"``, or the ``k`` member is missing.
          """
          try:
              if isinstance(jwk, str):
                  obj: JWKDict = json.loads(jwk)
              elif isinstance(jwk, dict):
                  obj = jwk
              else:
                  raise ValueError
              # Valid JSON that is not an object (e.g. "[]" or "null") cannot
              # be a JWK; without this check obj.get() raises AttributeError.
              if not isinstance(obj, dict):
                  raise ValueError
          except ValueError:
              raise InvalidKeyError("Key is not valid JSON") from None

          if obj.get("kty") != "oct":
              raise InvalidKeyError("Not an HMAC key")

          # Report a missing secret as an invalid key rather than a bare KeyError.
          if "k" not in obj:
              raise InvalidKeyError('HMAC key should have "k" parameter')

          return base64url_decode(obj["k"])

      def sign(self, msg: bytes, key: bytes) -> bytes:
          return hmac.new(key, msg, self.hash_alg).digest()

      def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
          # compare_digest avoids a timing side channel on the comparison.
          return hmac.compare_digest(sig, self.sign(msg, key))
  263. if has_crypto:
  class RSAAlgorithm(Algorithm):
      """
      Performs signing and verification operations using
      RSASSA-PKCS-v1_5 and the specified hash function.
      """

      SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
      SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
      SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

      def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
          # A cryptography HashAlgorithm class; instantiated on each sign/verify.
          self.hash_alg = hash_alg

      def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
          """
          Return an RSA key object suitable for :meth:`sign` / :meth:`verify`.

          RSA key objects are returned unchanged. ``str``/``bytes`` input is
          loaded as an OpenSSH public key when it starts with ``ssh-rsa``,
          otherwise as an unencrypted PEM private key, falling back to a PEM
          public key.

          :raises TypeError: if ``key`` is neither a key object nor str/bytes.
          :raises InvalidKeyError: if the key cannot be parsed, or parses to a
              key that is not an RSA key.
          """
          if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
              return key

          if not isinstance(key, (bytes, str)):
              raise TypeError("Expecting a PEM-formatted key.")

          key_bytes = force_bytes(key)

          try:
              if key_bytes.startswith(b"ssh-rsa"):
                  crypto_key = load_ssh_public_key(key_bytes)
              else:
                  crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
          except ValueError:
              try:
                  crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
              except (ValueError, UnsupportedAlgorithm):
                  raise InvalidKeyError(
                      "Could not parse the provided public key."
                  ) from None

          # The PEM loaders accept any key type. Reject non-RSA keys (e.g. EC
          # or Ed25519) here instead of letting them fail later in sign()/verify(),
          # matching the explicit checks done by the EC and OKP algorithms.
          if not isinstance(crypto_key, (RSAPrivateKey, RSAPublicKey)):
              raise InvalidKeyError(
                  "Expecting a RSAPrivateKey/RSAPublicKey. Wrong key provided for RSA algorithms"
              )

          return crypto_key

      @overload
      @staticmethod
      def to_jwk(
          key_obj: AllowedRSAKeys, as_dict: Literal[True]
      ) -> JWKDict: ...  # pragma: no cover

      @overload
      @staticmethod
      def to_jwk(
          key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
      ) -> str: ...  # pragma: no cover

      @staticmethod
      def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
          """
          Serialize an RSA key as an ``"RSA"`` JWK.

          Private keys (objects with ``private_numbers``) include the private
          exponent and CRT parameters; public keys include only ``n`` and ``e``.

          :return: the JWK as a dict if ``as_dict`` is true, else a JSON string.
          :raises InvalidKeyError: if ``key_obj`` is neither kind of key.
          """
          obj: dict[str, Any] | None = None

          if hasattr(key_obj, "private_numbers"):
              # Private key
              numbers = key_obj.private_numbers()

              obj = {
                  "kty": "RSA",
                  "key_ops": ["sign"],
                  "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                  "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                  "d": to_base64url_uint(numbers.d).decode(),
                  "p": to_base64url_uint(numbers.p).decode(),
                  "q": to_base64url_uint(numbers.q).decode(),
                  "dp": to_base64url_uint(numbers.dmp1).decode(),
                  "dq": to_base64url_uint(numbers.dmq1).decode(),
                  "qi": to_base64url_uint(numbers.iqmp).decode(),
              }

          elif hasattr(key_obj, "verify"):
              # Public key
              numbers = key_obj.public_numbers()

              obj = {
                  "kty": "RSA",
                  "key_ops": ["verify"],
                  "n": to_base64url_uint(numbers.n).decode(),
                  "e": to_base64url_uint(numbers.e).decode(),
              }
          else:
              raise InvalidKeyError("Not a public or private key")

          if as_dict:
              return obj
          else:
              return json.dumps(obj)

      @staticmethod
      def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
          """
          Rebuild an RSA key from an ``"RSA"`` JWK (JSON text or dict).

          A private key is built when ``d``, ``e`` and ``n`` are present. If
          the CRT members (``p``, ``q``, ``dp``, ``dq``, ``qi``) are absent,
          the primes are recovered from ``n``, ``e`` and ``d``.

          :raises InvalidKeyError: for non-object input, a wrong ``kty``,
              multi-prime keys (``oth``), a partial set of CRT members, or
              missing ``n``/``e``.
          """
          try:
              if isinstance(jwk, str):
                  obj = json.loads(jwk)
              elif isinstance(jwk, dict):
                  obj = jwk
              else:
                  raise ValueError
              # Valid JSON that is not an object (e.g. "[]" or "null") cannot
              # be a JWK; without this check obj.get() raises AttributeError.
              if not isinstance(obj, dict):
                  raise ValueError
          except ValueError:
              raise InvalidKeyError("Key is not valid JSON") from None

          if obj.get("kty") != "RSA":
              raise InvalidKeyError("Not an RSA key") from None

          if "d" in obj and "e" in obj and "n" in obj:
              # Private key
              if "oth" in obj:
                  raise InvalidKeyError(
                      "Unsupported RSA private key: > 2 primes not supported"
                  )

              other_props = ["p", "q", "dp", "dq", "qi"]
              props_found = [prop in obj for prop in other_props]
              any_props_found = any(props_found)

              if any_props_found and not all(props_found):
                  raise InvalidKeyError(
                      "RSA key must include all parameters if any are present besides d"
                  ) from None

              public_numbers = RSAPublicNumbers(
                  from_base64url_uint(obj["e"]),
                  from_base64url_uint(obj["n"]),
              )

              if any_props_found:
                  numbers = RSAPrivateNumbers(
                      d=from_base64url_uint(obj["d"]),
                      p=from_base64url_uint(obj["p"]),
                      q=from_base64url_uint(obj["q"]),
                      dmp1=from_base64url_uint(obj["dp"]),
                      dmq1=from_base64url_uint(obj["dq"]),
                      iqmp=from_base64url_uint(obj["qi"]),
                      public_numbers=public_numbers,
                  )
              else:
                  # Only d was supplied: derive p, q and the CRT values.
                  d = from_base64url_uint(obj["d"])
                  p, q = rsa_recover_prime_factors(
                      public_numbers.n, d, public_numbers.e
                  )

                  numbers = RSAPrivateNumbers(
                      d=d,
                      p=p,
                      q=q,
                      dmp1=rsa_crt_dmp1(d, p),
                      dmq1=rsa_crt_dmq1(d, q),
                      iqmp=rsa_crt_iqmp(p, q),
                      public_numbers=public_numbers,
                  )

              return numbers.private_key()
          elif "n" in obj and "e" in obj:
              # Public key
              return RSAPublicNumbers(
                  from_base64url_uint(obj["e"]),
                  from_base64url_uint(obj["n"]),
              ).public_key()
          else:
              raise InvalidKeyError("Not a public or private key")

      def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
          return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

      def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
          try:
              key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
              return True
          except InvalidSignature:
              return False
  class ECAlgorithm(Algorithm):
      """
      Performs signing and verification operations using
      ECDSA and the specified hash function
      """

      SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
      SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
      SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

      def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
          # A cryptography HashAlgorithm class; instantiated on each sign/verify.
          self.hash_alg = hash_alg

      def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
          """
          Return an EC key object suitable for :meth:`sign` / :meth:`verify`.

          EC key objects are returned unchanged. ``str``/``bytes`` input is
          loaded as an OpenSSH public key when it starts with ``ecdsa-sha2-``,
          otherwise as a PEM public key. If that raises ``ValueError``, the
          input is loaded as an unencrypted PEM private key instead. Errors
          from that final attempt propagate unchanged.

          :raises TypeError: if ``key`` is neither a key object nor str/bytes.
          :raises InvalidKeyError: if the loaded key is not an EC key.
          """
          if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
              return key

          if not isinstance(key, (bytes, str)):
              raise TypeError("Expecting a PEM-formatted key.")

          key_bytes = force_bytes(key)

          # Attempt to load key. We don't know if it's
          # a Signing Key or a Verifying Key, so we try
          # the Verifying Key first.
          try:
              if key_bytes.startswith(b"ecdsa-sha2-"):
                  crypto_key = load_ssh_public_key(key_bytes)
              else:
                  crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
          except ValueError:
              crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]

          # Explicit check the key to prevent confusing errors from cryptography
          if not isinstance(
              crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
          ):
              raise InvalidKeyError(
                  "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
              ) from None

          return crypto_key

      def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
          """
          Sign ``msg`` and return the signature in raw form.

          cryptography produces a DER signature, which is converted with
          ``der_to_raw_signature`` for the key's curve.
          """
          der_sig = key.sign(msg, ECDSA(self.hash_alg()))

          return der_to_raw_signature(der_sig, key.curve)

      def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
          """
          Check a raw-form signature ``sig`` over ``msg``.

          Private keys are accepted; their public half is used. A signature
          that ``raw_to_der_signature`` rejects counts as invalid.
          """
          try:
              der_sig = raw_to_der_signature(sig, key.curve)
          except ValueError:
              return False

          try:
              public_key = (
                  key.public_key()
                  if isinstance(key, EllipticCurvePrivateKey)
                  else key
              )
              public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
              return True
          except InvalidSignature:
              return False

      @overload
      @staticmethod
      def to_jwk(
          key_obj: AllowedECKeys, as_dict: Literal[True]
      ) -> JWKDict: ...  # pragma: no cover

      @overload
      @staticmethod
      def to_jwk(
          key_obj: AllowedECKeys, as_dict: Literal[False] = False
      ) -> str: ...  # pragma: no cover

      @staticmethod
      def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
          """
          Serialize an EC key as an ``"EC"`` JWK.

          Coordinates (and ``d`` for private keys) are padded to the curve's
          key size. Supported curves are P-256, P-384, P-521 and secp256k1.

          :return: the JWK as a dict if ``as_dict`` is true, else a JSON string.
          :raises InvalidKeyError: for a non-EC key or an unsupported curve.
          """
          if isinstance(key_obj, EllipticCurvePrivateKey):
              public_numbers = key_obj.public_key().public_numbers()
          elif isinstance(key_obj, EllipticCurvePublicKey):
              public_numbers = key_obj.public_numbers()
          else:
              raise InvalidKeyError("Not a public or private key")

          if isinstance(key_obj.curve, SECP256R1):
              crv = "P-256"
          elif isinstance(key_obj.curve, SECP384R1):
              crv = "P-384"
          elif isinstance(key_obj.curve, SECP521R1):
              crv = "P-521"
          elif isinstance(key_obj.curve, SECP256K1):
              crv = "secp256k1"
          else:
              raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

          obj: dict[str, Any] = {
              "kty": "EC",
              "crv": crv,
              "x": to_base64url_uint(
                  public_numbers.x,
                  bit_length=key_obj.curve.key_size,
              ).decode(),
              "y": to_base64url_uint(
                  public_numbers.y,
                  bit_length=key_obj.curve.key_size,
              ).decode(),
          }

          if isinstance(key_obj, EllipticCurvePrivateKey):
              obj["d"] = to_base64url_uint(
                  key_obj.private_numbers().private_value,
                  bit_length=key_obj.curve.key_size,
              ).decode()

          if as_dict:
              return obj
          else:
              return json.dumps(obj)

      @staticmethod
      def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
          """
          Rebuild an EC key from an ``"EC"`` JWK (JSON text or dict).

          Coordinate lengths are checked against the declared ``crv``. A
          private key is returned when ``d`` is present, else a public key.

          :raises InvalidKeyError: for non-object input, a wrong ``kty``,
              missing coordinates, an unsupported curve, or wrong lengths.
          """
          try:
              if isinstance(jwk, str):
                  obj = json.loads(jwk)
              elif isinstance(jwk, dict):
                  obj = jwk
              else:
                  raise ValueError
              # Valid JSON that is not an object (e.g. "[]" or "null") cannot
              # be a JWK; without this check obj.get() raises AttributeError.
              if not isinstance(obj, dict):
                  raise ValueError
          except ValueError:
              raise InvalidKeyError("Key is not valid JSON") from None

          if obj.get("kty") != "EC":
              raise InvalidKeyError("Not an Elliptic curve key") from None

          if "x" not in obj or "y" not in obj:
              raise InvalidKeyError("Not an Elliptic curve key") from None

          x = base64url_decode(obj.get("x"))
          y = base64url_decode(obj.get("y"))

          curve = obj.get("crv")
          curve_obj: EllipticCurve

          if curve == "P-256":
              if len(x) == len(y) == 32:
                  curve_obj = SECP256R1()
              else:
                  raise InvalidKeyError(
                      "Coords should be 32 bytes for curve P-256"
                  ) from None
          elif curve == "P-384":
              if len(x) == len(y) == 48:
                  curve_obj = SECP384R1()
              else:
                  raise InvalidKeyError(
                      "Coords should be 48 bytes for curve P-384"
                  ) from None
          elif curve == "P-521":
              if len(x) == len(y) == 66:
                  curve_obj = SECP521R1()
              else:
                  raise InvalidKeyError(
                      "Coords should be 66 bytes for curve P-521"
                  ) from None
          elif curve == "secp256k1":
              if len(x) == len(y) == 32:
                  curve_obj = SECP256K1()
              else:
                  raise InvalidKeyError(
                      "Coords should be 32 bytes for curve secp256k1"
                  )
          else:
              raise InvalidKeyError(f"Invalid curve: {curve}")

          public_numbers = EllipticCurvePublicNumbers(
              x=int.from_bytes(x, byteorder="big"),
              y=int.from_bytes(y, byteorder="big"),
              curve=curve_obj,
          )

          if "d" not in obj:
              return public_numbers.public_key()

          d = base64url_decode(obj.get("d"))
          # to_jwk pads d to the same width as the coordinates.
          if len(d) != len(x):
              # Format the message; passing the values as extra exception
              # arguments left the "{}" placeholders unfilled.
              raise InvalidKeyError(
                  f"D should be {len(x)} bytes for curve {curve}"
              )

          return EllipticCurvePrivateNumbers(
              int.from_bytes(d, byteorder="big"), public_numbers
          ).private_key()
  class RSAPSSAlgorithm(RSAAlgorithm):
      """
      RSASSA-PSS signatures with MGF1.

      MGF1 uses the same hash as the signature, and the salt length equals
      that hash's digest size. Key handling and JWK conversion are inherited
      from :class:`RSAAlgorithm`.
      """

      def _pss_padding(self) -> padding.PSS:
          # Built the same way for signing and verifying so both sides agree.
          return padding.PSS(
              mgf=padding.MGF1(self.hash_alg()),
              salt_length=self.hash_alg().digest_size,
          )

      def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
          return key.sign(msg, self._pss_padding(), self.hash_alg())

      def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
          try:
              key.verify(sig, msg, self._pss_padding(), self.hash_alg())
          except InvalidSignature:
              return False
          return True
  class OKPAlgorithm(Algorithm):
      """
      Performs signing and verification operations using EdDSA

      This class requires ``cryptography>=2.6`` to be installed.
      """

      def __init__(self, **kwargs: Any) -> None:
          # Keyword arguments are accepted and ignored. No hash_alg is set, so
          # compute_hash_digest() raises NotImplementedError for this class.
          pass

      def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
          """
          Return an Ed25519/Ed448 key object for :meth:`sign` / :meth:`verify`.

          ``str``/``bytes`` input is loaded as a PEM public key, a PEM private
          key (unencrypted), or an OpenSSH public key, chosen by its header or
          ``ssh-`` prefix. Key objects pass through to the type check.

          :raises InvalidKeyError: if the result is not an Ed25519/Ed448 key
              (this includes text input that matches none of the formats).
          """
          if isinstance(key, (bytes, str)):
              key_str = key.decode("utf-8") if isinstance(key, bytes) else key
              key_bytes = key.encode("utf-8") if isinstance(key, str) else key

              if "-----BEGIN PUBLIC" in key_str:
                  key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
              elif "-----BEGIN PRIVATE" in key_str:
                  key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
              elif key_str[0:4] == "ssh-":
                  key = load_ssh_public_key(key_bytes)  # type: ignore[assignment]

          # Explicit check the key to prevent confusing errors from cryptography
          if not isinstance(
              key,
              (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
          ):
              # The message names the EdDSA key types this check actually accepts
              # (it previously referred to EllipticCurve keys).
              raise InvalidKeyError(
                  "Expecting an Ed25519PrivateKey/Ed25519PublicKey/Ed448PrivateKey/"
                  "Ed448PublicKey. Wrong key provided for EdDSA algorithms"
              )

          return key

      def sign(
          self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
      ) -> bytes:
          """
          Sign a message ``msg`` using the EdDSA private key ``key``

          :param str|bytes msg: Message to sign (str is UTF-8 encoded)
          :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
              or :class:`.Ed448PrivateKey` instance
          :return bytes signature: The signature, as bytes
          """
          msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
          return key.sign(msg_bytes)

      def verify(
          self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
      ) -> bool:
          """
          Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

          :param str|bytes sig: EdDSA signature to check ``msg`` against
          :param str|bytes msg: Message to verify
          :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
              A private or public EdDSA key instance; for a private key its
              public half is used
          :return bool verified: True if signature is valid, False if not.
          """
          try:
              msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
              sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

              public_key = (
                  key.public_key()
                  if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                  else key
              )
              public_key.verify(sig_bytes, msg_bytes)
              return True  # If no exception was raised, the signature is valid.
          except InvalidSignature:
              return False

      @overload
      @staticmethod
      def to_jwk(
          key: AllowedOKPKeys, as_dict: Literal[True]
      ) -> JWKDict: ...  # pragma: no cover

      @overload
      @staticmethod
      def to_jwk(
          key: AllowedOKPKeys, as_dict: Literal[False] = False
      ) -> str: ...  # pragma: no cover

      @staticmethod
      def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
          """
          Serialize an Ed25519/Ed448 key as an ``"OKP"`` JWK.

          Public keys yield ``x``; private keys yield both ``d`` and ``x``
          (raw encodings, base64url-encoded).

          :return: the JWK as a dict if ``as_dict`` is true, else a JSON string.
          :raises InvalidKeyError: if ``key`` is not an Ed25519/Ed448 key.
          """
          if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
              x = key.public_bytes(
                  encoding=Encoding.Raw,
                  format=PublicFormat.Raw,
              )
              crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

              obj = {
                  "x": base64url_encode(force_bytes(x)).decode(),
                  "kty": "OKP",
                  "crv": crv,
              }

              if as_dict:
                  return obj
              else:
                  return json.dumps(obj)

          if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
              d = key.private_bytes(
                  encoding=Encoding.Raw,
                  format=PrivateFormat.Raw,
                  encryption_algorithm=NoEncryption(),
              )

              x = key.public_key().public_bytes(
                  encoding=Encoding.Raw,
                  format=PublicFormat.Raw,
              )

              crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
              obj = {
                  "x": base64url_encode(force_bytes(x)).decode(),
                  "d": base64url_encode(force_bytes(d)).decode(),
                  "kty": "OKP",
                  "crv": crv,
              }

              if as_dict:
                  return obj
              else:
                  return json.dumps(obj)

          raise InvalidKeyError("Not a public or private key")

      @staticmethod
      def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
          """
          Rebuild an Ed25519/Ed448 key from an ``"OKP"`` JWK (JSON text or dict).

          A private key is built from ``d`` when present; otherwise a public
          key is built from ``x``.

          :raises InvalidKeyError: for non-object input, a wrong ``kty`` or
              ``crv``, a missing ``x``, or key bytes rejected by cryptography.
          """
          try:
              if isinstance(jwk, str):
                  obj = json.loads(jwk)
              elif isinstance(jwk, dict):
                  obj = jwk
              else:
                  raise ValueError
              # Valid JSON that is not an object (e.g. "[]" or "null") cannot
              # be a JWK; without this check obj.get() raises AttributeError.
              if not isinstance(obj, dict):
                  raise ValueError
          except ValueError:
              raise InvalidKeyError("Key is not valid JSON") from None

          if obj.get("kty") != "OKP":
              raise InvalidKeyError("Not an Octet Key Pair")

          curve = obj.get("crv")
          if curve != "Ed25519" and curve != "Ed448":
              raise InvalidKeyError(f"Invalid curve: {curve}")

          if "x" not in obj:
              raise InvalidKeyError('OKP should have "x" parameter')
          x = base64url_decode(obj.get("x"))

          try:
              if "d" not in obj:
                  if curve == "Ed25519":
                      return Ed25519PublicKey.from_public_bytes(x)
                  return Ed448PublicKey.from_public_bytes(x)
              d = base64url_decode(obj.get("d"))
              if curve == "Ed25519":
                  return Ed25519PrivateKey.from_private_bytes(d)
              return Ed448PrivateKey.from_private_bytes(d)
          except ValueError as err:
              raise InvalidKeyError("Invalid key parameter") from err