tokens.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import uuid
  2. from datetime import datetime
  3. from datetime import timedelta
  4. from datetime import timezone
  5. from hmac import compare_digest
  6. from json import JSONEncoder
  7. from typing import Any
  8. from typing import Iterable
  9. from typing import List
  10. from typing import Type
  11. from typing import Union
  12. import jwt
  13. from flask_jwt_extended.exceptions import CSRFError
  14. from flask_jwt_extended.exceptions import JWTDecodeError
  15. from flask_jwt_extended.typing import ExpiresDelta
  16. from flask_jwt_extended.typing import Fresh
  17. def _encode_jwt(
  18. algorithm: str,
  19. audience: Union[str, Iterable[str]],
  20. claim_overrides: dict,
  21. csrf: bool,
  22. expires_delta: ExpiresDelta,
  23. fresh: Fresh,
  24. header_overrides: dict,
  25. identity: Any,
  26. identity_claim_key: str,
  27. issuer: str,
  28. json_encoder: Type[JSONEncoder],
  29. secret: str,
  30. token_type: str,
  31. nbf: bool,
  32. ) -> str:
  33. now = datetime.now(timezone.utc)
  34. if isinstance(fresh, timedelta):
  35. fresh = datetime.timestamp(now + fresh)
  36. token_data = {
  37. "fresh": fresh,
  38. "iat": now,
  39. "jti": str(uuid.uuid4()),
  40. "type": token_type,
  41. identity_claim_key: identity,
  42. }
  43. if nbf:
  44. token_data["nbf"] = now
  45. if csrf:
  46. token_data["csrf"] = str(uuid.uuid4())
  47. if audience:
  48. token_data["aud"] = audience
  49. if issuer:
  50. token_data["iss"] = issuer
  51. if expires_delta:
  52. token_data["exp"] = now + expires_delta
  53. if claim_overrides:
  54. token_data.update(claim_overrides)
  55. return jwt.encode(
  56. token_data,
  57. secret,
  58. algorithm,
  59. json_encoder=json_encoder, # type: ignore
  60. headers=header_overrides,
  61. )
  62. def _decode_jwt(
  63. algorithms: List,
  64. allow_expired: bool,
  65. audience: Union[str, Iterable[str]],
  66. csrf_value: str,
  67. encoded_token: str,
  68. identity_claim_key: str,
  69. issuer: str,
  70. leeway: int,
  71. secret: str,
  72. verify_aud: bool,
  73. verify_sub: bool,
  74. ) -> dict:
  75. options = {"verify_aud": verify_aud, "verify_sub": verify_sub}
  76. if allow_expired:
  77. options["verify_exp"] = False
  78. # This call verifies the ext, iat, and nbf claims
  79. # This optionally verifies the exp and aud claims if enabled
  80. decoded_token = jwt.decode(
  81. encoded_token,
  82. secret,
  83. algorithms=algorithms,
  84. audience=audience,
  85. issuer=issuer,
  86. leeway=leeway,
  87. options=options,
  88. )
  89. # Make sure that any custom claims we expect in the token are present
  90. if identity_claim_key not in decoded_token:
  91. raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))
  92. if "type" not in decoded_token:
  93. decoded_token["type"] = "access"
  94. if "fresh" not in decoded_token:
  95. decoded_token["fresh"] = False
  96. if "jti" not in decoded_token:
  97. decoded_token["jti"] = None
  98. if csrf_value:
  99. if "csrf" not in decoded_token:
  100. raise JWTDecodeError("Missing claim: csrf")
  101. if not compare_digest(decoded_token["csrf"], csrf_value):
  102. raise CSRFError("CSRF double submit tokens do not match")
  103. return decoded_token