api_jwt.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. from __future__ import annotations
  2. import json
  3. import warnings
  4. from calendar import timegm
  5. from collections.abc import Iterable, Sequence
  6. from datetime import datetime, timedelta, timezone
  7. from typing import TYPE_CHECKING, Any
  8. from . import api_jws
  9. from .exceptions import (
  10. DecodeError,
  11. ExpiredSignatureError,
  12. ImmatureSignatureError,
  13. InvalidAudienceError,
  14. InvalidIssuedAtError,
  15. InvalidIssuerError,
  16. InvalidJTIError,
  17. InvalidSubjectError,
  18. MissingRequiredClaimError,
  19. )
  20. from .warnings import RemovedInPyjwt3Warning
  21. if TYPE_CHECKING:
  22. from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
  23. from .api_jwk import PyJWK
  24. class PyJWT:
  25. def __init__(self, options: dict[str, Any] | None = None) -> None:
  26. if options is None:
  27. options = {}
  28. self.options: dict[str, Any] = {**self._get_default_options(), **options}
  29. @staticmethod
  30. def _get_default_options() -> dict[str, bool | list[str]]:
  31. return {
  32. "verify_signature": True,
  33. "verify_exp": True,
  34. "verify_nbf": True,
  35. "verify_iat": True,
  36. "verify_aud": True,
  37. "verify_iss": True,
  38. "verify_sub": True,
  39. "verify_jti": True,
  40. "require": [],
  41. }
  42. def encode(
  43. self,
  44. payload: dict[str, Any],
  45. key: AllowedPrivateKeys | PyJWK | str | bytes,
  46. algorithm: str | None = None,
  47. headers: dict[str, Any] | None = None,
  48. json_encoder: type[json.JSONEncoder] | None = None,
  49. sort_headers: bool = True,
  50. ) -> str:
  51. # Check that we get a dict
  52. if not isinstance(payload, dict):
  53. raise TypeError(
  54. "Expecting a dict object, as JWT only supports "
  55. "JSON objects as payloads."
  56. )
  57. # Payload
  58. payload = payload.copy()
  59. for time_claim in ["exp", "iat", "nbf"]:
  60. # Convert datetime to a intDate value in known time-format claims
  61. if isinstance(payload.get(time_claim), datetime):
  62. payload[time_claim] = timegm(payload[time_claim].utctimetuple())
  63. json_payload = self._encode_payload(
  64. payload,
  65. headers=headers,
  66. json_encoder=json_encoder,
  67. )
  68. return api_jws.encode(
  69. json_payload,
  70. key,
  71. algorithm,
  72. headers,
  73. json_encoder,
  74. sort_headers=sort_headers,
  75. )
  76. def _encode_payload(
  77. self,
  78. payload: dict[str, Any],
  79. headers: dict[str, Any] | None = None,
  80. json_encoder: type[json.JSONEncoder] | None = None,
  81. ) -> bytes:
  82. """
  83. Encode a given payload to the bytes to be signed.
  84. This method is intended to be overridden by subclasses that need to
  85. encode the payload in a different way, e.g. compress the payload.
  86. """
  87. return json.dumps(
  88. payload,
  89. separators=(",", ":"),
  90. cls=json_encoder,
  91. ).encode("utf-8")
  92. def decode_complete(
  93. self,
  94. jwt: str | bytes,
  95. key: AllowedPublicKeys | PyJWK | str | bytes = "",
  96. algorithms: Sequence[str] | None = None,
  97. options: dict[str, Any] | None = None,
  98. # deprecated arg, remove in pyjwt3
  99. verify: bool | None = None,
  100. # could be used as passthrough to api_jws, consider removal in pyjwt3
  101. detached_payload: bytes | None = None,
  102. # passthrough arguments to _validate_claims
  103. # consider putting in options
  104. audience: str | Iterable[str] | None = None,
  105. issuer: str | Sequence[str] | None = None,
  106. subject: str | None = None,
  107. leeway: float | timedelta = 0,
  108. # kwargs
  109. **kwargs: Any,
  110. ) -> dict[str, Any]:
  111. if kwargs:
  112. warnings.warn(
  113. "passing additional kwargs to decode_complete() is deprecated "
  114. "and will be removed in pyjwt version 3. "
  115. f"Unsupported kwargs: {tuple(kwargs.keys())}",
  116. RemovedInPyjwt3Warning,
  117. stacklevel=2,
  118. )
  119. options = dict(options or {}) # shallow-copy or initialize an empty dict
  120. options.setdefault("verify_signature", True)
  121. # If the user has set the legacy `verify` argument, and it doesn't match
  122. # what the relevant `options` entry for the argument is, inform the user
  123. # that they're likely making a mistake.
  124. if verify is not None and verify != options["verify_signature"]:
  125. warnings.warn(
  126. "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
  127. "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
  128. "This invocation has a mismatch between the kwarg and the option entry.",
  129. category=DeprecationWarning,
  130. stacklevel=2,
  131. )
  132. if not options["verify_signature"]:
  133. options.setdefault("verify_exp", False)
  134. options.setdefault("verify_nbf", False)
  135. options.setdefault("verify_iat", False)
  136. options.setdefault("verify_aud", False)
  137. options.setdefault("verify_iss", False)
  138. options.setdefault("verify_sub", False)
  139. options.setdefault("verify_jti", False)
  140. decoded = api_jws.decode_complete(
  141. jwt,
  142. key=key,
  143. algorithms=algorithms,
  144. options=options,
  145. detached_payload=detached_payload,
  146. )
  147. payload = self._decode_payload(decoded)
  148. merged_options = {**self.options, **options}
  149. self._validate_claims(
  150. payload,
  151. merged_options,
  152. audience=audience,
  153. issuer=issuer,
  154. leeway=leeway,
  155. subject=subject,
  156. )
  157. decoded["payload"] = payload
  158. return decoded
  159. def _decode_payload(self, decoded: dict[str, Any]) -> Any:
  160. """
  161. Decode the payload from a JWS dictionary (payload, signature, header).
  162. This method is intended to be overridden by subclasses that need to
  163. decode the payload in a different way, e.g. decompress compressed
  164. payloads.
  165. """
  166. try:
  167. payload = json.loads(decoded["payload"])
  168. except ValueError as e:
  169. raise DecodeError(f"Invalid payload string: {e}") from e
  170. if not isinstance(payload, dict):
  171. raise DecodeError("Invalid payload string: must be a json object")
  172. return payload
  173. def decode(
  174. self,
  175. jwt: str | bytes,
  176. key: AllowedPublicKeys | PyJWK | str | bytes = "",
  177. algorithms: Sequence[str] | None = None,
  178. options: dict[str, Any] | None = None,
  179. # deprecated arg, remove in pyjwt3
  180. verify: bool | None = None,
  181. # could be used as passthrough to api_jws, consider removal in pyjwt3
  182. detached_payload: bytes | None = None,
  183. # passthrough arguments to _validate_claims
  184. # consider putting in options
  185. audience: str | Iterable[str] | None = None,
  186. subject: str | None = None,
  187. issuer: str | Sequence[str] | None = None,
  188. leeway: float | timedelta = 0,
  189. # kwargs
  190. **kwargs: Any,
  191. ) -> Any:
  192. if kwargs:
  193. warnings.warn(
  194. "passing additional kwargs to decode() is deprecated "
  195. "and will be removed in pyjwt version 3. "
  196. f"Unsupported kwargs: {tuple(kwargs.keys())}",
  197. RemovedInPyjwt3Warning,
  198. stacklevel=2,
  199. )
  200. decoded = self.decode_complete(
  201. jwt,
  202. key,
  203. algorithms,
  204. options,
  205. verify=verify,
  206. detached_payload=detached_payload,
  207. audience=audience,
  208. subject=subject,
  209. issuer=issuer,
  210. leeway=leeway,
  211. )
  212. return decoded["payload"]
  213. def _validate_claims(
  214. self,
  215. payload: dict[str, Any],
  216. options: dict[str, Any],
  217. audience=None,
  218. issuer=None,
  219. subject: str | None = None,
  220. leeway: float | timedelta = 0,
  221. ) -> None:
  222. if isinstance(leeway, timedelta):
  223. leeway = leeway.total_seconds()
  224. if audience is not None and not isinstance(audience, (str, Iterable)):
  225. raise TypeError("audience must be a string, iterable or None")
  226. self._validate_required_claims(payload, options)
  227. now = datetime.now(tz=timezone.utc).timestamp()
  228. if "iat" in payload and options["verify_iat"]:
  229. self._validate_iat(payload, now, leeway)
  230. if "nbf" in payload and options["verify_nbf"]:
  231. self._validate_nbf(payload, now, leeway)
  232. if "exp" in payload and options["verify_exp"]:
  233. self._validate_exp(payload, now, leeway)
  234. if options["verify_iss"]:
  235. self._validate_iss(payload, issuer)
  236. if options["verify_aud"]:
  237. self._validate_aud(
  238. payload, audience, strict=options.get("strict_aud", False)
  239. )
  240. if options["verify_sub"]:
  241. self._validate_sub(payload, subject)
  242. if options["verify_jti"]:
  243. self._validate_jti(payload)
  244. def _validate_required_claims(
  245. self,
  246. payload: dict[str, Any],
  247. options: dict[str, Any],
  248. ) -> None:
  249. for claim in options["require"]:
  250. if payload.get(claim) is None:
  251. raise MissingRequiredClaimError(claim)
  252. def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
  253. """
  254. Checks whether "sub" if in the payload is valid ot not.
  255. This is an Optional claim
  256. :param payload(dict): The payload which needs to be validated
  257. :param subject(str): The subject of the token
  258. """
  259. if "sub" not in payload:
  260. return
  261. if not isinstance(payload["sub"], str):
  262. raise InvalidSubjectError("Subject must be a string")
  263. if subject is not None:
  264. if payload.get("sub") != subject:
  265. raise InvalidSubjectError("Invalid subject")
  266. def _validate_jti(self, payload: dict[str, Any]) -> None:
  267. """
  268. Checks whether "jti" if in the payload is valid ot not
  269. This is an Optional claim
  270. :param payload(dict): The payload which needs to be validated
  271. """
  272. if "jti" not in payload:
  273. return
  274. if not isinstance(payload.get("jti"), str):
  275. raise InvalidJTIError("JWT ID must be a string")
  276. def _validate_iat(
  277. self,
  278. payload: dict[str, Any],
  279. now: float,
  280. leeway: float,
  281. ) -> None:
  282. try:
  283. iat = int(payload["iat"])
  284. except ValueError:
  285. raise InvalidIssuedAtError(
  286. "Issued At claim (iat) must be an integer."
  287. ) from None
  288. if iat > (now + leeway):
  289. raise ImmatureSignatureError("The token is not yet valid (iat)")
  290. def _validate_nbf(
  291. self,
  292. payload: dict[str, Any],
  293. now: float,
  294. leeway: float,
  295. ) -> None:
  296. try:
  297. nbf = int(payload["nbf"])
  298. except ValueError:
  299. raise DecodeError("Not Before claim (nbf) must be an integer.") from None
  300. if nbf > (now + leeway):
  301. raise ImmatureSignatureError("The token is not yet valid (nbf)")
  302. def _validate_exp(
  303. self,
  304. payload: dict[str, Any],
  305. now: float,
  306. leeway: float,
  307. ) -> None:
  308. try:
  309. exp = int(payload["exp"])
  310. except ValueError:
  311. raise DecodeError(
  312. "Expiration Time claim (exp) must be an integer."
  313. ) from None
  314. if exp <= (now - leeway):
  315. raise ExpiredSignatureError("Signature has expired")
  316. def _validate_aud(
  317. self,
  318. payload: dict[str, Any],
  319. audience: str | Iterable[str] | None,
  320. *,
  321. strict: bool = False,
  322. ) -> None:
  323. if audience is None:
  324. if "aud" not in payload or not payload["aud"]:
  325. return
  326. # Application did not specify an audience, but
  327. # the token has the 'aud' claim
  328. raise InvalidAudienceError("Invalid audience")
  329. if "aud" not in payload or not payload["aud"]:
  330. # Application specified an audience, but it could not be
  331. # verified since the token does not contain a claim.
  332. raise MissingRequiredClaimError("aud")
  333. audience_claims = payload["aud"]
  334. # In strict mode, we forbid list matching: the supplied audience
  335. # must be a string, and it must exactly match the audience claim.
  336. if strict:
  337. # Only a single audience is allowed in strict mode.
  338. if not isinstance(audience, str):
  339. raise InvalidAudienceError("Invalid audience (strict)")
  340. # Only a single audience claim is allowed in strict mode.
  341. if not isinstance(audience_claims, str):
  342. raise InvalidAudienceError("Invalid claim format in token (strict)")
  343. if audience != audience_claims:
  344. raise InvalidAudienceError("Audience doesn't match (strict)")
  345. return
  346. if isinstance(audience_claims, str):
  347. audience_claims = [audience_claims]
  348. if not isinstance(audience_claims, list):
  349. raise InvalidAudienceError("Invalid claim format in token")
  350. if any(not isinstance(c, str) for c in audience_claims):
  351. raise InvalidAudienceError("Invalid claim format in token")
  352. if isinstance(audience, str):
  353. audience = [audience]
  354. if all(aud not in audience_claims for aud in audience):
  355. raise InvalidAudienceError("Audience doesn't match")
  356. def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
  357. if issuer is None:
  358. return
  359. if "iss" not in payload:
  360. raise MissingRequiredClaimError("iss")
  361. if isinstance(issuer, str):
  362. if payload["iss"] != issuer:
  363. raise InvalidIssuerError("Invalid issuer")
  364. else:
  365. if payload["iss"] not in issuer:
  366. raise InvalidIssuerError("Invalid issuer")
  367. _jwt_global_obj = PyJWT()
  368. encode = _jwt_global_obj.encode
  369. decode_complete = _jwt_global_obj.decode_complete
  370. decode = _jwt_global_obj.decode