api_jws.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. from __future__ import annotations
  2. import binascii
  3. import json
  4. import warnings
  5. from collections.abc import Sequence
  6. from typing import TYPE_CHECKING, Any
  7. from .algorithms import (
  8. Algorithm,
  9. get_default_algorithms,
  10. has_crypto,
  11. requires_cryptography,
  12. )
  13. from .api_jwk import PyJWK
  14. from .exceptions import (
  15. DecodeError,
  16. InvalidAlgorithmError,
  17. InvalidSignatureError,
  18. InvalidTokenError,
  19. )
  20. from .utils import base64url_decode, base64url_encode
  21. from .warnings import RemovedInPyjwt3Warning
  22. if TYPE_CHECKING:
  23. from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
  24. class PyJWS:
  25. header_typ = "JWT"
  26. def __init__(
  27. self,
  28. algorithms: Sequence[str] | None = None,
  29. options: dict[str, Any] | None = None,
  30. ) -> None:
  31. self._algorithms = get_default_algorithms()
  32. self._valid_algs = (
  33. set(algorithms) if algorithms is not None else set(self._algorithms)
  34. )
  35. # Remove algorithms that aren't on the whitelist
  36. for key in list(self._algorithms.keys()):
  37. if key not in self._valid_algs:
  38. del self._algorithms[key]
  39. if options is None:
  40. options = {}
  41. self.options = {**self._get_default_options(), **options}
  42. @staticmethod
  43. def _get_default_options() -> dict[str, bool]:
  44. return {"verify_signature": True}
  45. def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
  46. """
  47. Registers a new Algorithm for use when creating and verifying tokens.
  48. """
  49. if alg_id in self._algorithms:
  50. raise ValueError("Algorithm already has a handler.")
  51. if not isinstance(alg_obj, Algorithm):
  52. raise TypeError("Object is not of type `Algorithm`")
  53. self._algorithms[alg_id] = alg_obj
  54. self._valid_algs.add(alg_id)
  55. def unregister_algorithm(self, alg_id: str) -> None:
  56. """
  57. Unregisters an Algorithm for use when creating and verifying tokens
  58. Throws KeyError if algorithm is not registered.
  59. """
  60. if alg_id not in self._algorithms:
  61. raise KeyError(
  62. "The specified algorithm could not be removed"
  63. " because it is not registered."
  64. )
  65. del self._algorithms[alg_id]
  66. self._valid_algs.remove(alg_id)
  67. def get_algorithms(self) -> list[str]:
  68. """
  69. Returns a list of supported values for the 'alg' parameter.
  70. """
  71. return list(self._valid_algs)
  72. def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
  73. """
  74. For a given string name, return the matching Algorithm object.
  75. Example usage:
  76. >>> jws_obj.get_algorithm_by_name("RS256")
  77. """
  78. try:
  79. return self._algorithms[alg_name]
  80. except KeyError as e:
  81. if not has_crypto and alg_name in requires_cryptography:
  82. raise NotImplementedError(
  83. f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
  84. ) from e
  85. raise NotImplementedError("Algorithm not supported") from e
  86. def encode(
  87. self,
  88. payload: bytes,
  89. key: AllowedPrivateKeys | PyJWK | str | bytes,
  90. algorithm: str | None = None,
  91. headers: dict[str, Any] | None = None,
  92. json_encoder: type[json.JSONEncoder] | None = None,
  93. is_payload_detached: bool = False,
  94. sort_headers: bool = True,
  95. ) -> str:
  96. segments = []
  97. # declare a new var to narrow the type for type checkers
  98. if algorithm is None:
  99. if isinstance(key, PyJWK):
  100. algorithm_ = key.algorithm_name
  101. else:
  102. algorithm_ = "HS256"
  103. else:
  104. algorithm_ = algorithm
  105. # Prefer headers values if present to function parameters.
  106. if headers:
  107. headers_alg = headers.get("alg")
  108. if headers_alg:
  109. algorithm_ = headers["alg"]
  110. headers_b64 = headers.get("b64")
  111. if headers_b64 is False:
  112. is_payload_detached = True
  113. # Header
  114. header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
  115. if headers:
  116. self._validate_headers(headers)
  117. header.update(headers)
  118. if not header["typ"]:
  119. del header["typ"]
  120. if is_payload_detached:
  121. header["b64"] = False
  122. elif "b64" in header:
  123. # True is the standard value for b64, so no need for it
  124. del header["b64"]
  125. json_header = json.dumps(
  126. header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
  127. ).encode()
  128. segments.append(base64url_encode(json_header))
  129. if is_payload_detached:
  130. msg_payload = payload
  131. else:
  132. msg_payload = base64url_encode(payload)
  133. segments.append(msg_payload)
  134. # Segments
  135. signing_input = b".".join(segments)
  136. alg_obj = self.get_algorithm_by_name(algorithm_)
  137. if isinstance(key, PyJWK):
  138. key = key.key
  139. key = alg_obj.prepare_key(key)
  140. signature = alg_obj.sign(signing_input, key)
  141. segments.append(base64url_encode(signature))
  142. # Don't put the payload content inside the encoded token when detached
  143. if is_payload_detached:
  144. segments[1] = b""
  145. encoded_string = b".".join(segments)
  146. return encoded_string.decode("utf-8")
  147. def decode_complete(
  148. self,
  149. jwt: str | bytes,
  150. key: AllowedPublicKeys | PyJWK | str | bytes = "",
  151. algorithms: Sequence[str] | None = None,
  152. options: dict[str, Any] | None = None,
  153. detached_payload: bytes | None = None,
  154. **kwargs,
  155. ) -> dict[str, Any]:
  156. if kwargs:
  157. warnings.warn(
  158. "passing additional kwargs to decode_complete() is deprecated "
  159. "and will be removed in pyjwt version 3. "
  160. f"Unsupported kwargs: {tuple(kwargs.keys())}",
  161. RemovedInPyjwt3Warning,
  162. stacklevel=2,
  163. )
  164. if options is None:
  165. options = {}
  166. merged_options = {**self.options, **options}
  167. verify_signature = merged_options["verify_signature"]
  168. if verify_signature and not algorithms and not isinstance(key, PyJWK):
  169. raise DecodeError(
  170. 'It is required that you pass in a value for the "algorithms" argument when calling decode().'
  171. )
  172. payload, signing_input, header, signature = self._load(jwt)
  173. if header.get("b64", True) is False:
  174. if detached_payload is None:
  175. raise DecodeError(
  176. 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
  177. )
  178. payload = detached_payload
  179. signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])
  180. if verify_signature:
  181. self._verify_signature(signing_input, header, signature, key, algorithms)
  182. return {
  183. "payload": payload,
  184. "header": header,
  185. "signature": signature,
  186. }
  187. def decode(
  188. self,
  189. jwt: str | bytes,
  190. key: AllowedPublicKeys | PyJWK | str | bytes = "",
  191. algorithms: Sequence[str] | None = None,
  192. options: dict[str, Any] | None = None,
  193. detached_payload: bytes | None = None,
  194. **kwargs,
  195. ) -> Any:
  196. if kwargs:
  197. warnings.warn(
  198. "passing additional kwargs to decode() is deprecated "
  199. "and will be removed in pyjwt version 3. "
  200. f"Unsupported kwargs: {tuple(kwargs.keys())}",
  201. RemovedInPyjwt3Warning,
  202. stacklevel=2,
  203. )
  204. decoded = self.decode_complete(
  205. jwt, key, algorithms, options, detached_payload=detached_payload
  206. )
  207. return decoded["payload"]
  208. def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
  209. """Returns back the JWT header parameters as a dict()
  210. Note: The signature is not verified so the header parameters
  211. should not be fully trusted until signature verification is complete
  212. """
  213. headers = self._load(jwt)[2]
  214. self._validate_headers(headers)
  215. return headers
  216. def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
  217. if isinstance(jwt, str):
  218. jwt = jwt.encode("utf-8")
  219. if not isinstance(jwt, bytes):
  220. raise DecodeError(f"Invalid token type. Token must be a {bytes}")
  221. try:
  222. signing_input, crypto_segment = jwt.rsplit(b".", 1)
  223. header_segment, payload_segment = signing_input.split(b".", 1)
  224. except ValueError as err:
  225. raise DecodeError("Not enough segments") from err
  226. try:
  227. header_data = base64url_decode(header_segment)
  228. except (TypeError, binascii.Error) as err:
  229. raise DecodeError("Invalid header padding") from err
  230. try:
  231. header = json.loads(header_data)
  232. except ValueError as e:
  233. raise DecodeError(f"Invalid header string: {e}") from e
  234. if not isinstance(header, dict):
  235. raise DecodeError("Invalid header string: must be a json object")
  236. try:
  237. payload = base64url_decode(payload_segment)
  238. except (TypeError, binascii.Error) as err:
  239. raise DecodeError("Invalid payload padding") from err
  240. try:
  241. signature = base64url_decode(crypto_segment)
  242. except (TypeError, binascii.Error) as err:
  243. raise DecodeError("Invalid crypto padding") from err
  244. return (payload, signing_input, header, signature)
  245. def _verify_signature(
  246. self,
  247. signing_input: bytes,
  248. header: dict[str, Any],
  249. signature: bytes,
  250. key: AllowedPublicKeys | PyJWK | str | bytes = "",
  251. algorithms: Sequence[str] | None = None,
  252. ) -> None:
  253. if algorithms is None and isinstance(key, PyJWK):
  254. algorithms = [key.algorithm_name]
  255. try:
  256. alg = header["alg"]
  257. except KeyError:
  258. raise InvalidAlgorithmError("Algorithm not specified") from None
  259. if not alg or (algorithms is not None and alg not in algorithms):
  260. raise InvalidAlgorithmError("The specified alg value is not allowed")
  261. if isinstance(key, PyJWK):
  262. alg_obj = key.Algorithm
  263. prepared_key = key.key
  264. else:
  265. try:
  266. alg_obj = self.get_algorithm_by_name(alg)
  267. except NotImplementedError as e:
  268. raise InvalidAlgorithmError("Algorithm not supported") from e
  269. prepared_key = alg_obj.prepare_key(key)
  270. if not alg_obj.verify(signing_input, prepared_key, signature):
  271. raise InvalidSignatureError("Signature verification failed")
  272. def _validate_headers(self, headers: dict[str, Any]) -> None:
  273. if "kid" in headers:
  274. self._validate_kid(headers["kid"])
  275. def _validate_kid(self, kid: Any) -> None:
  276. if not isinstance(kid, str):
  277. raise InvalidTokenError("Key ID header parameter must be a string")
  278. _jws_global_obj = PyJWS()
  279. encode = _jws_global_obj.encode
  280. decode_complete = _jws_global_obj.decode_complete
  281. decode = _jws_global_obj.decode
  282. register_algorithm = _jws_global_obj.register_algorithm
  283. unregister_algorithm = _jws_global_obj.unregister_algorithm
  284. get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
  285. get_unverified_header = _jws_global_obj.get_unverified_header