12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875 |
- from __future__ import annotations
- import hashlib
- import hmac
- import json
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
- from .exceptions import InvalidKeyError
- from .types import HashlibHash, JWKDict
- from .utils import (
- base64url_decode,
- base64url_encode,
- der_to_raw_signature,
- force_bytes,
- from_base64url_uint,
- is_pem_format,
- is_ssh_key,
- raw_to_der_signature,
- to_base64url_uint,
- )
- try:
- from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives.asymmetric import padding
- from cryptography.hazmat.primitives.asymmetric.ec import (
- ECDSA,
- SECP256K1,
- SECP256R1,
- SECP384R1,
- SECP521R1,
- EllipticCurve,
- EllipticCurvePrivateKey,
- EllipticCurvePrivateNumbers,
- EllipticCurvePublicKey,
- EllipticCurvePublicNumbers,
- )
- from cryptography.hazmat.primitives.asymmetric.ed448 import (
- Ed448PrivateKey,
- Ed448PublicKey,
- )
- from cryptography.hazmat.primitives.asymmetric.ed25519 import (
- Ed25519PrivateKey,
- Ed25519PublicKey,
- )
- from cryptography.hazmat.primitives.asymmetric.rsa import (
- RSAPrivateKey,
- RSAPrivateNumbers,
- RSAPublicKey,
- RSAPublicNumbers,
- rsa_crt_dmp1,
- rsa_crt_dmq1,
- rsa_crt_iqmp,
- rsa_recover_prime_factors,
- )
- from cryptography.hazmat.primitives.serialization import (
- Encoding,
- NoEncryption,
- PrivateFormat,
- PublicFormat,
- load_pem_private_key,
- load_pem_public_key,
- load_ssh_public_key,
- )
- has_crypto = True
- except ModuleNotFoundError:
- has_crypto = False
if TYPE_CHECKING:
    # Static-typing-only aliases for the key types accepted by the algorithm
    # classes below. This branch is never executed at runtime, so it may name
    # the ``cryptography`` key classes even when that package is not installed.

    # One alias per key family (private or public half).
    AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
    AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
    AllowedOKPKeys = (
        Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
    )

    # Any supported key, regardless of family.
    AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys

    # Supported keys split by role instead of by family.
    AllowedPrivateKeys = (
        RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
    )
    AllowedPublicKeys = (
        RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
    )
# Algorithm names whose implementation needs the optional ``cryptography``
# package. "none" and the HS* (HMAC) family are absent because they are built
# on the standard library (hashlib / hmac) only.
requires_cryptography = (
    {"RS256", "RS384", "RS512"}  # RSASSA-PKCS1-v1_5
    | {"ES256", "ES256K", "ES384", "ES521", "ES512"}  # ECDSA
    | {"PS256", "PS384", "PS512"}  # RSASSA-PSS
    | {"EdDSA"}  # Ed25519 / Ed448
)
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build a fresh mapping from algorithm name to Algorithm instance for every
    algorithm implemented by this library.

    ``"none"`` and the HMAC family (HS256/HS384/HS512) are always present.
    The RSA, ECDSA, RSA-PSS and EdDSA entries are added only when the optional
    ``cryptography`` package could be imported (``has_crypto``).
    """
    algorithms: dict[str, Algorithm] = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }
    if not has_crypto:
        return algorithms

    # Everything below is backed by ``cryptography``.
    algorithms["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    algorithms["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    algorithms["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    algorithms["ES256"] = ECAlgorithm(ECAlgorithm.SHA256)
    algorithms["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256)
    algorithms["ES384"] = ECAlgorithm(ECAlgorithm.SHA384)
    algorithms["ES521"] = ECAlgorithm(ECAlgorithm.SHA512)
    # Backward compat for #219 fix
    algorithms["ES512"] = ECAlgorithm(ECAlgorithm.SHA512)
    algorithms["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    algorithms["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    algorithms["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    algorithms["EdDSA"] = OKPAlgorithm()
    return algorithms
class Algorithm(ABC):
    """
    Abstract base class for a JWS signing algorithm.

    Concrete subclasses implement key preparation, signing, verification and
    conversion to and from JWK. Subclasses that set a ``hash_alg`` attribute
    can also use :meth:`compute_hash_digest`.
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Return the digest of ``bytestr`` under this algorithm's ``hash_alg``.

        If ``cryptography`` is available and ``hash_alg`` is a subclass of its
        ``hashes.HashAlgorithm``, the digest is computed through that API.
        Otherwise ``hash_alg`` is called like a hashlib constructor.

        :raises NotImplementedError: if the algorithm has no ``hash_alg``.
        """
        # The base class does not declare hash_alg; getattr keeps type
        # checkers satisfied while still finding it on subclasses.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        is_crypto_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_crypto_hash:
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate ``key`` and convert it into the form that :meth:`sign` and
        :meth:`verify` expect.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Produce a signature over ``msg`` using the prepared ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Report whether ``sig`` is a valid signature of ``msg`` under ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize ``key_obj`` as a JWK: a dict when ``as_dict`` is true,
        otherwise a JSON string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Rebuild a key object from a JWK given as a JSON string or a dict.
        """
class NoneAlgorithm(Algorithm):
    """
    The unsecured ``"none"`` algorithm, for tokens that are neither signed
    nor verified.

    ``sign`` yields an empty byte string, ``verify`` never accepts a
    signature, and the only acceptable key is ``None`` (or ``""``).
    """

    def prepare_key(self, key: str | None) -> None:
        # An empty string is treated exactly like "no key"; anything else is
        # rejected.
        if key == "" or key is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        # Nothing to sign with: the signature is always empty.
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        # No signature is ever considered valid for "none".
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        # There is no key to serialize.
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        # There is no key to deserialize.
        raise NotImplementedError()
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.

    ``hash_alg`` is a hashlib constructor (normally one of the ``SHA*`` class
    attributes) and is passed straight to :func:`hmac.new`.
    """

    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Return ``key`` as bytes for use as the HMAC secret.

        :raises InvalidKeyError: if the key looks like PEM data (asymmetric key
            or x509 certificate) or an SSH public key. Such public material
            must not double as an HMAC shared secret.
        """
        key_bytes = force_bytes(key)

        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key_bytes

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize an HMAC secret as an ``"oct"`` JWK whose ``k`` member is
        the base64url-encoded secret. Returns a dict if ``as_dict`` is true,
        otherwise a JSON string.
        """
        jwk = {
            "k": base64url_encode(force_bytes(key_obj)).decode(),
            "kty": "oct",
        }

        if as_dict:
            return jwk
        else:
            return json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Decode the secret from an ``"oct"`` JWK given as a JSON string or a dict.

        :raises InvalidKeyError: if the input is not a JSON object, is not an
            ``"oct"`` key, or has no ``k`` member.
        """
        try:
            if isinstance(jwk, str):
                obj: JWKDict = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if not isinstance(obj, dict):
            # Parsed fine, but to a list/number/string instead of an object;
            # obj.get() below would otherwise raise AttributeError.
            raise InvalidKeyError("Key must be a JSON object")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        if "k" not in obj:
            # Report a malformed JWK the same way the other key types do
            # instead of leaking a bare KeyError.
            raise InvalidKeyError('HMAC key should have "k" parameter')

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` under ``key`` with ``hash_alg``."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """
        Recompute the HMAC and compare it to ``sig`` with
        :func:`hmac.compare_digest`.
        """
        return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            # A HashAlgorithm *class*; it is instantiated for each operation.
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
            """
            Return an RSA key object for ``key``.

            RSA key objects are returned unchanged. For ``str``/``bytes``
            input, data starting with ``ssh-rsa`` is loaded as an OpenSSH
            public key and anything else as an unencrypted PEM private key.
            If that raises ValueError, the data is loaded as a PEM public key.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the data cannot be parsed, or parses to
                a key that is not an RSA key.
            """
            if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            loaded: Any
            try:
                if key_bytes.startswith(b"ssh-rsa"):
                    loaded = load_ssh_public_key(key_bytes)
                else:
                    loaded = load_pem_private_key(key_bytes, password=None)
            except ValueError:
                # Loading failed; fall back to parsing a PEM public key.
                try:
                    loaded = load_pem_public_key(key_bytes)
                except (ValueError, UnsupportedAlgorithm):
                    raise InvalidKeyError(
                        "Could not parse the provided public key."
                    ) from None

            # The PEM loaders accept any key type (EC, EdDSA, ...). Reject
            # non-RSA keys here, as the EC and OKP algorithms do, instead of
            # failing later with a confusing error inside sign()/verify().
            if not isinstance(loaded, (RSAPrivateKey, RSAPublicKey)):
                raise InvalidKeyError(
                    "Expecting a RSAPrivateKey/RSAPublicKey. Wrong key provided for RSA algorithms"
                )
            return loaded

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an RSA key as a JWK with ``kty`` "RSA".

            Objects exposing ``private_numbers`` are treated as private keys.
            They emit ``n``, ``e``, ``d``, ``p``, ``q``, ``dp``, ``dq``, ``qi``
            and ``key_ops`` ["sign"]. Objects exposing ``verify`` are treated
            as public keys and emit ``n``, ``e`` and ``key_ops`` ["verify"].

            :raises InvalidKeyError: if ``key_obj`` is neither.
            """
            obj: dict[str, Any]

            if hasattr(key_obj, "private_numbers"):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif hasattr(key_obj, "verify"):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
            """
            Rebuild an RSA key from a JWK given as a JSON string or a dict.

            A JWK carrying ``d`` becomes a private key. If any of ``p``, ``q``,
            ``dp``, ``dq`` and ``qi`` is present, all of them must be. When
            none are present, the primes are recovered from ``n``, ``e`` and
            ``d``. A JWK carrying only ``n`` and ``e`` becomes a public key.

            :raises InvalidKeyError: on malformed input, a non-RSA ``kty``,
                multi-prime keys (``oth``) or a partial set of CRT parameters.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if not isinstance(obj, dict):
                # Parsed fine, but to a list/number/string instead of an object.
                raise InvalidKeyError("Key must be a JSON object")

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key") from None

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    ) from None

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    # Only d was supplied: derive the primes and CRT values.
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                return RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                ).public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with PKCS#1 v1.5 padding and ``hash_alg``."""
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if ``sig`` verifies, False on InvalidSignature."""
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function.

        Signatures are exchanged in the raw form produced by
        ``der_to_raw_signature``. The DER encoding used by ``cryptography``
        is converted internally.
        """

        SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
        SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
        SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
            # A HashAlgorithm *class*; it is instantiated for each operation.
            self.hash_alg = hash_alg

        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
            """
            Return an elliptic-curve key object for ``key``.

            EC key objects are returned unchanged. ``str``/``bytes`` input is
            loaded as an OpenSSH ``ecdsa-sha2-*`` public key or a PEM public
            key, falling back to an unencrypted PEM private key on ValueError.

            :raises TypeError: if ``key`` is neither a key object nor str/bytes.
            :raises InvalidKeyError: if the loaded key is not an EC key.
            """
            if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
                return key

            if not isinstance(key, (bytes, str)):
                raise TypeError("Expecting a PEM-formatted key.")

            key_bytes = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key_bytes.startswith(b"ecdsa-sha2-"):
                    crypto_key = load_ssh_public_key(key_bytes)
                else:
                    crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
            except ValueError:
                crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]

            # Explicit check the key to prevent confusing errors from cryptography
            if not isinstance(
                crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
            ):
                raise InvalidKeyError(
                    "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
                ) from None

            return crypto_key

        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
            """Sign ``msg`` with ECDSA and return the raw (non-DER) signature."""
            der_sig = key.sign(msg, ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
            """
            Verify a raw ECDSA signature. A private key is accepted and its
            public half is used. Returns False if ``sig`` cannot be converted
            to DER or does not verify.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                public_key = (
                    key.public_key()
                    if isinstance(key, EllipticCurvePrivateKey)
                    else key
                )
                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key_obj: AllowedECKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an EC key as a JWK (``kty`` "EC").

            Supports the P-256, P-384, P-521 and secp256k1 curves. ``x``,
            ``y`` (and ``d`` for private keys) are encoded with
            ``bit_length`` set to the curve's key size.

            :raises InvalidKeyError: for a non-EC object or an unsupported curve.
            """
            if isinstance(key_obj, EllipticCurvePrivateKey):
                public_numbers = key_obj.public_key().public_numbers()
            elif isinstance(key_obj, EllipticCurvePublicKey):
                public_numbers = key_obj.public_numbers()
            else:
                raise InvalidKeyError("Not a public or private key")

            if isinstance(key_obj.curve, SECP256R1):
                crv = "P-256"
            elif isinstance(key_obj.curve, SECP384R1):
                crv = "P-384"
            elif isinstance(key_obj.curve, SECP521R1):
                crv = "P-521"
            elif isinstance(key_obj.curve, SECP256K1):
                crv = "secp256k1"
            else:
                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

            obj: dict[str, Any] = {
                "kty": "EC",
                "crv": crv,
                "x": to_base64url_uint(
                    public_numbers.x,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
                "y": to_base64url_uint(
                    public_numbers.y,
                    bit_length=key_obj.curve.key_size,
                ).decode(),
            }

            if isinstance(key_obj, EllipticCurvePrivateKey):
                obj["d"] = to_base64url_uint(
                    key_obj.private_numbers().private_value,
                    bit_length=key_obj.curve.key_size,
                ).decode()

            if as_dict:
                return obj
            else:
                return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
            """
            Rebuild an EC key from a JWK given as a JSON string or a dict.

            The coordinate lengths must match the curve named by ``crv``. A
            ``d`` member (same length as ``x``) yields a private key;
            otherwise a public key is returned.

            :raises InvalidKeyError: on malformed input, a non-EC ``kty``,
                missing coordinates, an unsupported curve or wrong lengths.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if not isinstance(obj, dict):
                # Parsed fine, but to a list/number/string instead of an object.
                raise InvalidKeyError("Key must be a JSON object")

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key") from None

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key") from None

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            curve_obj: EllipticCurve

            if curve == "P-256":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve P-256"
                    ) from None
            elif curve == "P-384":
                if len(x) == len(y) == 48:
                    curve_obj = SECP384R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 48 bytes for curve P-384"
                    ) from None
            elif curve == "P-521":
                if len(x) == len(y) == 66:
                    curve_obj = SECP521R1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 66 bytes for curve P-521"
                    ) from None
            elif curve == "secp256k1":
                if len(x) == len(y) == 32:
                    curve_obj = SECP256K1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve secp256k1"
                    )
            else:
                raise InvalidKeyError(f"Invalid curve: {curve}")

            public_numbers = EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                return public_numbers.public_key()

            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                # Previously the placeholders were never filled in (the values
                # were passed as extra exception args), producing a garbled
                # message.
                raise InvalidKeyError(
                    f"D should be {len(x)} bytes for curve {curve}"
                )

            return EllipticCurvePrivateNumbers(
                int.from_bytes(d, byteorder="big"), public_numbers
            ).private_key()

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1.

        MGF1 uses ``hash_alg`` and the salt length equals the digest size.
        Key handling is inherited from :class:`RSAAlgorithm`.
        """

        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
            """Sign ``msg`` with PSS padding (MGF1, salt = digest size)."""
            return key.sign(
                msg,
                padding.PSS(
                    mgf=padding.MGF1(self.hash_alg()),
                    salt_length=self.hash_alg().digest_size,
                ),
                self.hash_alg(),
            )

        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
            """Return True if the PSS signature verifies, else False."""
            try:
                key.verify(
                    sig,
                    msg,
                    padding.PSS(
                        mgf=padding.MGF1(self.hash_alg()),
                        salt_length=self.hash_alg().digest_size,
                    ),
                    self.hash_alg(),
                )
                return True
            except InvalidSignature:
                return False

    class OKPAlgorithm(Algorithm):
        """
        Performs signing and verification operations using EdDSA
        (Ed25519 or Ed448 keys).

        This class requires ``cryptography>=2.6`` to be installed.
        """

        def __init__(self, **kwargs: Any) -> None:
            pass

        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
            """
            Return an Ed25519/Ed448 key object for ``key``.

            ``str``/``bytes`` input is loaded as a PEM public key, a PEM
            private key or an OpenSSH public key, chosen by its text header.
            Anything that does not end up as an Ed25519/Ed448 key is rejected.

            :raises InvalidKeyError: if the result is not an EdDSA key.
            """
            if isinstance(key, (bytes, str)):
                # key_str is only used to sniff the format, so undecodable
                # bytes are replaced rather than aborting with
                # UnicodeDecodeError; such input then fails the type check
                # below with InvalidKeyError.
                key_str = (
                    key.decode("utf-8", errors="replace")
                    if isinstance(key, bytes)
                    else key
                )
                key_bytes = key.encode("utf-8") if isinstance(key, str) else key

                if "-----BEGIN PUBLIC" in key_str:
                    key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
                elif "-----BEGIN PRIVATE" in key_str:
                    key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
                elif key_str[0:4] == "ssh-":
                    key = load_ssh_public_key(key_bytes)  # type: ignore[assignment]

            # Explicit check the key to prevent confusing errors from cryptography
            if not isinstance(
                key,
                (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
            ):
                raise InvalidKeyError(
                    "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
                )

            return key

        def sign(
            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
        ) -> bytes:
            """
            Sign a message ``msg`` using the EdDSA private key ``key``.

            :param str|bytes msg: Message to sign (str is UTF-8 encoded)
            :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
                or :class:`.Ed448PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
            return key.sign(msg_bytes)

        def verify(
            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
        ) -> bool:
            """
            Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``.

            :param str|bytes msg: Message that was signed (str is UTF-8 encoded)
            :param str|bytes sig: EdDSA signature to check ``msg`` against
            :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
                A private or public EdDSA key instance; for a private key its
                public half is used
            :return bool verified: True if signature is valid, False if not.
            """
            try:
                msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
                sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

                public_key = (
                    key.public_key()
                    if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                    else key
                )
                public_key.verify(sig_bytes, msg_bytes)
                return True  # If no exception was raised, the signature is valid.
            except InvalidSignature:
                return False

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[True]
        ) -> JWKDict: ...  # pragma: no cover

        @overload
        @staticmethod
        def to_jwk(
            key: AllowedOKPKeys, as_dict: Literal[False] = False
        ) -> str: ...  # pragma: no cover

        @staticmethod
        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
            """
            Serialize an Ed25519/Ed448 key as an ``"OKP"`` JWK.

            Public keys emit ``x``; private keys emit both ``x`` and ``d``
            (raw key bytes, base64url-encoded).

            :raises InvalidKeyError: if ``key`` is not an EdDSA key.
            """
            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
                x = key.public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )
                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
                d = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )

                x = key.public_key().public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
                obj = {
                    "x": base64url_encode(force_bytes(x)).decode(),
                    "d": base64url_encode(force_bytes(d)).decode(),
                    "kty": "OKP",
                    "crv": crv,
                }

                if as_dict:
                    return obj
                else:
                    return json.dumps(obj)

            raise InvalidKeyError("Not a public or private key")

        @staticmethod
        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
            """
            Rebuild an Ed25519/Ed448 key from an ``"OKP"`` JWK given as a JSON
            string or a dict. A ``d`` member yields a private key, otherwise
            the public key is built from ``x``.

            :raises InvalidKeyError: on malformed input, a non-OKP ``kty``, an
                unsupported curve, a missing ``x`` or invalid key bytes.
            """
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON") from None

            if not isinstance(obj, dict):
                # Parsed fine, but to a list/number/string instead of an object.
                raise InvalidKeyError("Key must be a JSON object")

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve != "Ed25519" and curve != "Ed448":
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            try:
                if "d" not in obj:
                    if curve == "Ed25519":
                        return Ed25519PublicKey.from_public_bytes(x)
                    return Ed448PublicKey.from_public_bytes(x)
                d = base64url_decode(obj.get("d"))
                if curve == "Ed25519":
                    return Ed25519PrivateKey.from_private_bytes(d)
                return Ed448PrivateKey.from_private_bytes(d)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err
|