123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- from __future__ import annotations
- import json
- import warnings
- from calendar import timegm
- from collections.abc import Iterable, Sequence
- from datetime import datetime, timedelta, timezone
- from typing import TYPE_CHECKING, Any
- from . import api_jws
- from .exceptions import (
- DecodeError,
- ExpiredSignatureError,
- ImmatureSignatureError,
- InvalidAudienceError,
- InvalidIssuedAtError,
- InvalidIssuerError,
- InvalidJTIError,
- InvalidSubjectError,
- MissingRequiredClaimError,
- )
- from .warnings import RemovedInPyjwt3Warning
- if TYPE_CHECKING:
- from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
- from .api_jwk import PyJWK
class PyJWT:
    """Encode and decode JSON Web Tokens on top of the ``api_jws`` layer.

    Instance ``options`` are the defaults from :meth:`_get_default_options`
    overlaid with the ``options`` given to the constructor.  Per-call
    ``options`` passed to :meth:`decode` / :meth:`decode_complete` are in
    turn overlaid on the instance options when claims are validated.
    """

    def __init__(self, options: dict[str, Any] | None = None) -> None:
        """Create an encoder/decoder.

        :param options: Overrides for the default validation options; keys
            that are not given keep their default value.
        """
        if options is None:
            options = {}
        self.options: dict[str, Any] = {**self._get_default_options(), **options}

    @staticmethod
    def _get_default_options() -> dict[str, bool | list[str]]:
        """Return a fresh dict of the default validation options.

        Every ``verify_*`` flag is on and no claims are required.
        """
        return {
            "verify_signature": True,
            "verify_exp": True,
            "verify_nbf": True,
            "verify_iat": True,
            "verify_aud": True,
            "verify_iss": True,
            "verify_sub": True,
            "verify_jti": True,
            "require": [],
        }

    def encode(
        self,
        payload: dict[str, Any],
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = None,
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        sort_headers: bool = True,
    ) -> str:
        """Serialize ``payload`` to JSON and sign it via ``api_jws.encode``.

        ``datetime`` values in the ``exp``, ``iat`` and ``nbf`` claims are
        converted to integer seconds since the epoch.  The conversion goes
        through ``utctimetuple()`` + ``timegm``, so naive datetimes are
        interpreted as UTC.  The caller's dict is not mutated.

        :param payload: Claims to encode; must be a ``dict``.
        :param key: Signing key, passed through to ``api_jws.encode``.
        :param algorithm: Algorithm name, passed through to ``api_jws.encode``.
        :param headers: Extra JOSE headers, passed through.
        :param json_encoder: Optional ``JSONEncoder`` subclass used for the
            payload and passed through for the header.
        :param sort_headers: Passed through to ``api_jws.encode``.
        :returns: The encoded token as returned by ``api_jws.encode``.
        :raises TypeError: If ``payload`` is not a ``dict``.
        """
        # Check that we get a dict
        if not isinstance(payload, dict):
            raise TypeError(
                "Expecting a dict object, as JWT only supports "
                "JSON objects as payloads."
            )

        # Work on a shallow copy so the time-claim conversion below does not
        # leak back into the caller's dict.
        payload = payload.copy()
        for time_claim in ["exp", "iat", "nbf"]:
            # Convert datetime to a intDate value in known time-format claims
            if isinstance(payload.get(time_claim), datetime):
                payload[time_claim] = timegm(payload[time_claim].utctimetuple())

        json_payload = self._encode_payload(
            payload,
            headers=headers,
            json_encoder=json_encoder,
        )

        return api_jws.encode(
            json_payload,
            key,
            algorithm,
            headers,
            json_encoder,
            sort_headers=sort_headers,
        )

    def _encode_payload(
        self,
        payload: dict[str, Any],
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
    ) -> bytes:
        """
        Encode a given payload to the bytes to be signed.

        This method is intended to be overridden by subclasses that need to
        encode the payload in a different way, e.g. compress the payload.

        The default implementation emits compact JSON (no whitespace after
        separators) encoded as UTF-8; ``headers`` is unused here.
        """
        return json.dumps(
            payload,
            separators=(",", ":"),
            cls=json_encoder,
        ).encode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
        # deprecated arg, remove in pyjwt3
        verify: bool | None = None,
        # could be used as passthrough to api_jws, consider removal in pyjwt3
        detached_payload: bytes | None = None,
        # passthrough arguments to _validate_claims
        # consider putting in options
        audience: str | Iterable[str] | None = None,
        issuer: str | Sequence[str] | None = None,
        subject: str | None = None,
        leeway: float | timedelta = 0,
        # kwargs
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Decode ``jwt``, validate its claims and return the full structure.

        :param jwt: The encoded token.
        :param key: Verification key, passed through to ``api_jws``.
        :param algorithms: Allowed algorithms, passed through to ``api_jws``.
        :param options: Per-call options.  ``verify_signature`` defaults to
            ``True``.  When it is ``False``, every other ``verify_*`` flag
            that is not explicitly given defaults to ``False`` as well.
        :param verify: Deprecated and ignored; only used to warn when it
            disagrees with ``options["verify_signature"]``.
        :param detached_payload: Passed through to ``api_jws.decode_complete``.
        :param audience: Expected audience(s), see :meth:`_validate_aud`.
        :param issuer: Expected issuer(s), see :meth:`_validate_iss`.
        :param subject: Expected subject, see :meth:`_validate_sub`.
        :param leeway: Clock-skew allowance in seconds (or a ``timedelta``)
            for the ``exp``/``nbf``/``iat`` checks.
        :returns: The dict returned by ``api_jws.decode_complete`` with its
            ``"payload"`` entry replaced by the decoded claims dict.
        :raises DecodeError: If the payload is not a JSON object.
        :raises: Claim-validation errors from the ``_validate_*`` methods;
            signature/format errors come from ``api_jws`` (not visible here).
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        options = dict(options or {})  # shallow-copy or initialize an empty dict
        options.setdefault("verify_signature", True)

        # If the user has set the legacy `verify` argument, and it doesn't match
        # what the relevant `options` entry for the argument is, inform the user
        # that they're likely making a mistake.
        if verify is not None and verify != options["verify_signature"]:
            warnings.warn(
                "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
                "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
                "This invocation has a mismatch between the kwarg and the option entry.",
                category=DeprecationWarning,
                stacklevel=2,
            )

        # With signature verification off, claim checks are off by default too;
        # setdefault keeps any flag the caller explicitly re-enabled.
        if not options["verify_signature"]:
            options.setdefault("verify_exp", False)
            options.setdefault("verify_nbf", False)
            options.setdefault("verify_iat", False)
            options.setdefault("verify_aud", False)
            options.setdefault("verify_iss", False)
            options.setdefault("verify_sub", False)
            options.setdefault("verify_jti", False)

        decoded = api_jws.decode_complete(
            jwt,
            key=key,
            algorithms=algorithms,
            options=options,
            detached_payload=detached_payload,
        )

        payload = self._decode_payload(decoded)

        # Per-call options win over the instance-level options.
        merged_options = {**self.options, **options}
        self._validate_claims(
            payload,
            merged_options,
            audience=audience,
            issuer=issuer,
            leeway=leeway,
            subject=subject,
        )

        decoded["payload"] = payload
        return decoded

    def _decode_payload(self, decoded: dict[str, Any]) -> Any:
        """
        Decode the payload from a JWS dictionary (payload, signature, header).

        This method is intended to be overridden by subclasses that need to
        decode the payload in a different way, e.g. decompress compressed
        payloads.

        :raises DecodeError: If the payload is not valid JSON or is not a
            JSON object.
        """
        try:
            payload = json.loads(decoded["payload"])
        except ValueError as e:
            raise DecodeError(f"Invalid payload string: {e}") from e
        if not isinstance(payload, dict):
            raise DecodeError("Invalid payload string: must be a json object")
        return payload

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
        # deprecated arg, remove in pyjwt3
        verify: bool | None = None,
        # could be used as passthrough to api_jws, consider removal in pyjwt3
        detached_payload: bytes | None = None,
        # passthrough arguments to _validate_claims
        # consider putting in options
        audience: str | Iterable[str] | None = None,
        subject: str | None = None,
        issuer: str | Sequence[str] | None = None,
        leeway: float | timedelta = 0,
        # kwargs
        **kwargs: Any,
    ) -> Any:
        """Decode and validate ``jwt`` and return only its claims dict.

        Thin wrapper over :meth:`decode_complete`; see it for the meaning of
        each argument.  Note that ``subject`` and ``issuer`` are in a
        different positional order here than in :meth:`decode_complete`, so
        pass them by keyword.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt,
            key,
            algorithms,
            options,
            verify=verify,
            detached_payload=detached_payload,
            audience=audience,
            subject=subject,
            issuer=issuer,
            leeway=leeway,
        )
        return decoded["payload"]

    def _validate_claims(
        self,
        payload: dict[str, Any],
        options: dict[str, Any],
        audience: str | Iterable[str] | None = None,
        issuer: str | Sequence[str] | None = None,
        subject: str | None = None,
        leeway: float | timedelta = 0,
    ) -> None:
        """Run every enabled claim check against ``payload``.

        ``options`` must contain ``require`` and all ``verify_*`` keys (as
        produced by merging over :meth:`_get_default_options`).  The time
        checks only run when the respective claim is present.

        :raises TypeError: If ``audience`` is neither ``None``, a string nor
            an iterable.
        """
        if isinstance(leeway, timedelta):
            leeway = leeway.total_seconds()

        if audience is not None and not isinstance(audience, (str, Iterable)):
            raise TypeError("audience must be a string, iterable or None")

        self._validate_required_claims(payload, options)

        now = datetime.now(tz=timezone.utc).timestamp()

        if "iat" in payload and options["verify_iat"]:
            self._validate_iat(payload, now, leeway)

        if "nbf" in payload and options["verify_nbf"]:
            self._validate_nbf(payload, now, leeway)

        if "exp" in payload and options["verify_exp"]:
            self._validate_exp(payload, now, leeway)

        if options["verify_iss"]:
            self._validate_iss(payload, issuer)

        if options["verify_aud"]:
            self._validate_aud(
                payload, audience, strict=options.get("strict_aud", False)
            )

        if options["verify_sub"]:
            self._validate_sub(payload, subject)

        if options["verify_jti"]:
            self._validate_jti(payload)

    def _validate_required_claims(
        self,
        payload: dict[str, Any],
        options: dict[str, Any],
    ) -> None:
        """Ensure every claim named in ``options["require"]`` is non-null.

        :raises MissingRequiredClaimError: For the first claim that is
            missing or ``None``.
        """
        for claim in options["require"]:
            if payload.get(claim) is None:
                raise MissingRequiredClaimError(claim)

    def _validate_sub(
        self, payload: dict[str, Any], subject: str | None = None
    ) -> None:
        """
        Checks whether "sub", if present in the payload, is valid or not.

        This is an Optional claim

        :param payload(dict): The payload which needs to be validated
        :param subject(str): The subject of the token
        :raises InvalidSubjectError: If "sub" is not a string, or does not
            equal ``subject`` when one is given.
        """
        if "sub" not in payload:
            return

        if not isinstance(payload["sub"], str):
            raise InvalidSubjectError("Subject must be a string")

        if subject is not None:
            if payload.get("sub") != subject:
                raise InvalidSubjectError("Invalid subject")

    def _validate_jti(self, payload: dict[str, Any]) -> None:
        """
        Checks whether "jti", if present in the payload, is valid or not.

        This is an Optional claim

        :param payload(dict): The payload which needs to be validated
        :raises InvalidJTIError: If "jti" is present but not a string.
        """
        if "jti" not in payload:
            return

        if not isinstance(payload.get("jti"), str):
            raise InvalidJTIError("JWT ID must be a string")

    def _validate_iat(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        """Reject tokens whose "iat" lies more than ``leeway`` in the future.

        :raises InvalidIssuedAtError: If "iat" cannot be converted to int.
        :raises ImmatureSignatureError: If ``iat > now + leeway``.
        """
        try:
            iat = int(payload["iat"])
        # TypeError: null/list/object claim; OverflowError: JSON Infinity.
        except (TypeError, ValueError, OverflowError):
            raise InvalidIssuedAtError(
                "Issued At claim (iat) must be an integer."
            ) from None
        if iat > (now + leeway):
            raise ImmatureSignatureError("The token is not yet valid (iat)")

    def _validate_nbf(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        """Reject tokens whose "nbf" lies more than ``leeway`` in the future.

        :raises DecodeError: If "nbf" cannot be converted to int.
        :raises ImmatureSignatureError: If ``nbf > now + leeway``.
        """
        try:
            nbf = int(payload["nbf"])
        # TypeError: null/list/object claim; OverflowError: JSON Infinity.
        except (TypeError, ValueError, OverflowError):
            raise DecodeError("Not Before claim (nbf) must be an integer.") from None

        if nbf > (now + leeway):
            raise ImmatureSignatureError("The token is not yet valid (nbf)")

    def _validate_exp(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        """Reject tokens whose "exp" is at or before ``now - leeway``.

        :raises DecodeError: If "exp" cannot be converted to int.
        :raises ExpiredSignatureError: If ``exp <= now - leeway``.
        """
        try:
            exp = int(payload["exp"])
        # TypeError: null/list/object claim; OverflowError: JSON Infinity.
        except (TypeError, ValueError, OverflowError):
            raise DecodeError(
                "Expiration Time claim (exp) must be an integer."
            ) from None

        if exp <= (now - leeway):
            raise ExpiredSignatureError("Signature has expired")

    def _validate_aud(
        self,
        payload: dict[str, Any],
        audience: str | Iterable[str] | None,
        *,
        strict: bool = False,
    ) -> None:
        """Check the "aud" claim against the expected ``audience``.

        Without an expected audience, a token carrying a non-empty "aud" is
        rejected.  In non-strict mode the check passes if any expected
        audience appears in the token's audience list.  In strict mode both
        sides must be single strings and equal.

        :raises InvalidAudienceError: On any mismatch or malformed claim.
        :raises MissingRequiredClaimError: If an audience is expected but the
            token has no (or an empty) "aud" claim.
        """
        if audience is None:
            if "aud" not in payload or not payload["aud"]:
                return
            # Application did not specify an audience, but
            # the token has the 'aud' claim
            raise InvalidAudienceError("Invalid audience")

        if "aud" not in payload or not payload["aud"]:
            # Application specified an audience, but it could not be
            # verified since the token does not contain a claim.
            raise MissingRequiredClaimError("aud")

        audience_claims = payload["aud"]

        # In strict mode, we forbid list matching: the supplied audience
        # must be a string, and it must exactly match the audience claim.
        if strict:
            # Only a single audience is allowed in strict mode.
            if not isinstance(audience, str):
                raise InvalidAudienceError("Invalid audience (strict)")

            # Only a single audience claim is allowed in strict mode.
            if not isinstance(audience_claims, str):
                raise InvalidAudienceError("Invalid claim format in token (strict)")

            if audience != audience_claims:
                raise InvalidAudienceError("Audience doesn't match (strict)")

            return

        if isinstance(audience_claims, str):
            audience_claims = [audience_claims]
        if not isinstance(audience_claims, list):
            raise InvalidAudienceError("Invalid claim format in token")
        if any(not isinstance(c, str) for c in audience_claims):
            raise InvalidAudienceError("Invalid claim format in token")

        if isinstance(audience, str):
            audience = [audience]

        if all(aud not in audience_claims for aud in audience):
            raise InvalidAudienceError("Audience doesn't match")

    def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
        """Check the "iss" claim when an expected ``issuer`` is given.

        A string ``issuer`` must match exactly; any other value is treated as
        a container of acceptable issuers.

        :raises MissingRequiredClaimError: If an issuer is expected but the
            token has no "iss" claim.
        :raises InvalidIssuerError: If the claim does not match.
        """
        if issuer is None:
            return

        if "iss" not in payload:
            raise MissingRequiredClaimError("iss")

        if isinstance(issuer, str):
            if payload["iss"] != issuer:
                raise InvalidIssuerError("Invalid issuer")
        else:
            if payload["iss"] not in issuer:
                raise InvalidIssuerError("Invalid issuer")
# Shared module-level instance built with the default options; the public
# module functions below are simply its bound methods.
_jwt_global_obj = PyJWT()
encode, decode_complete, decode = (
    _jwt_global_obj.encode,
    _jwt_global_obj.decode_complete,
    _jwt_global_obj.decode,
)
|