api_jwk.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from __future__ import annotations
  2. import json
  3. import time
  4. from typing import Any
  5. from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
  6. from .exceptions import (
  7. InvalidKeyError,
  8. MissingCryptographyError,
  9. PyJWKError,
  10. PyJWKSetError,
  11. PyJWTError,
  12. )
  13. from .types import JWKDict
  14. class PyJWK:
  15. def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
  16. self._algorithms = get_default_algorithms()
  17. self._jwk_data = jwk_data
  18. kty = self._jwk_data.get("kty", None)
  19. if not kty:
  20. raise InvalidKeyError(f"kty is not found: {self._jwk_data}")
  21. if not algorithm and isinstance(self._jwk_data, dict):
  22. algorithm = self._jwk_data.get("alg", None)
  23. if not algorithm:
  24. # Determine alg with kty (and crv).
  25. crv = self._jwk_data.get("crv", None)
  26. if kty == "EC":
  27. if crv == "P-256" or not crv:
  28. algorithm = "ES256"
  29. elif crv == "P-384":
  30. algorithm = "ES384"
  31. elif crv == "P-521":
  32. algorithm = "ES512"
  33. elif crv == "secp256k1":
  34. algorithm = "ES256K"
  35. else:
  36. raise InvalidKeyError(f"Unsupported crv: {crv}")
  37. elif kty == "RSA":
  38. algorithm = "RS256"
  39. elif kty == "oct":
  40. algorithm = "HS256"
  41. elif kty == "OKP":
  42. if not crv:
  43. raise InvalidKeyError(f"crv is not found: {self._jwk_data}")
  44. if crv == "Ed25519":
  45. algorithm = "EdDSA"
  46. else:
  47. raise InvalidKeyError(f"Unsupported crv: {crv}")
  48. else:
  49. raise InvalidKeyError(f"Unsupported kty: {kty}")
  50. if not has_crypto and algorithm in requires_cryptography:
  51. raise MissingCryptographyError(
  52. f"{algorithm} requires 'cryptography' to be installed."
  53. )
  54. self.algorithm_name = algorithm
  55. if algorithm in self._algorithms:
  56. self.Algorithm = self._algorithms[algorithm]
  57. else:
  58. raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
  59. self.key = self.Algorithm.from_jwk(self._jwk_data)
  60. @staticmethod
  61. def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK:
  62. return PyJWK(obj, algorithm)
  63. @staticmethod
  64. def from_json(data: str, algorithm: None = None) -> PyJWK:
  65. obj = json.loads(data)
  66. return PyJWK.from_dict(obj, algorithm)
  67. @property
  68. def key_type(self) -> str | None:
  69. return self._jwk_data.get("kty", None)
  70. @property
  71. def key_id(self) -> str | None:
  72. return self._jwk_data.get("kid", None)
  73. @property
  74. def public_key_use(self) -> str | None:
  75. return self._jwk_data.get("use", None)
  76. class PyJWKSet:
  77. def __init__(self, keys: list[JWKDict]) -> None:
  78. self.keys = []
  79. if not keys:
  80. raise PyJWKSetError("The JWK Set did not contain any keys")
  81. if not isinstance(keys, list):
  82. raise PyJWKSetError("Invalid JWK Set value")
  83. for key in keys:
  84. try:
  85. self.keys.append(PyJWK(key))
  86. except PyJWTError as error:
  87. if isinstance(error, MissingCryptographyError):
  88. raise error
  89. # skip unusable keys
  90. continue
  91. if len(self.keys) == 0:
  92. raise PyJWKSetError(
  93. "The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
  94. )
  95. @staticmethod
  96. def from_dict(obj: dict[str, Any]) -> PyJWKSet:
  97. keys = obj.get("keys", [])
  98. return PyJWKSet(keys)
  99. @staticmethod
  100. def from_json(data: str) -> PyJWKSet:
  101. obj = json.loads(data)
  102. return PyJWKSet.from_dict(obj)
  103. def __getitem__(self, kid: str) -> PyJWK:
  104. for key in self.keys:
  105. if key.key_id == kid:
  106. return key
  107. raise KeyError(f"keyset has no key for kid: {kid}")
  108. class PyJWTSetWithTimestamp:
  109. def __init__(self, jwk_set: PyJWKSet):
  110. self.jwk_set = jwk_set
  111. self.timestamp = time.monotonic()
  112. def get_jwk_set(self) -> PyJWKSet:
  113. return self.jwk_set
  114. def get_timestamp(self) -> float:
  115. return self.timestamp