jwt_manager.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. import datetime
  2. from typing import Any
  3. from typing import Callable
  4. from typing import Optional
  5. import jwt
  6. from flask import Flask
  7. from jwt import DecodeError
  8. from jwt import ExpiredSignatureError
  9. from jwt import InvalidAudienceError
  10. from jwt import InvalidIssuerError
  11. from jwt import InvalidTokenError
  12. from jwt import MissingRequiredClaimError
  13. from flask_jwt_extended.config import config
  14. from flask_jwt_extended.default_callbacks import default_additional_claims_callback
  15. from flask_jwt_extended.default_callbacks import default_blocklist_callback
  16. from flask_jwt_extended.default_callbacks import default_decode_key_callback
  17. from flask_jwt_extended.default_callbacks import default_encode_key_callback
  18. from flask_jwt_extended.default_callbacks import default_expired_token_callback
  19. from flask_jwt_extended.default_callbacks import default_invalid_token_callback
  20. from flask_jwt_extended.default_callbacks import default_jwt_headers_callback
  21. from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback
  22. from flask_jwt_extended.default_callbacks import default_revoked_token_callback
  23. from flask_jwt_extended.default_callbacks import default_token_verification_callback
  24. from flask_jwt_extended.default_callbacks import (
  25. default_token_verification_failed_callback,
  26. )
  27. from flask_jwt_extended.default_callbacks import default_unauthorized_callback
  28. from flask_jwt_extended.default_callbacks import default_user_identity_callback
  29. from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback
  30. from flask_jwt_extended.exceptions import CSRFError
  31. from flask_jwt_extended.exceptions import FreshTokenRequired
  32. from flask_jwt_extended.exceptions import InvalidHeaderError
  33. from flask_jwt_extended.exceptions import InvalidQueryParamError
  34. from flask_jwt_extended.exceptions import JWTDecodeError
  35. from flask_jwt_extended.exceptions import NoAuthorizationError
  36. from flask_jwt_extended.exceptions import RevokedTokenError
  37. from flask_jwt_extended.exceptions import UserClaimsVerificationError
  38. from flask_jwt_extended.exceptions import UserLookupError
  39. from flask_jwt_extended.exceptions import WrongTokenError
  40. from flask_jwt_extended.tokens import _decode_jwt
  41. from flask_jwt_extended.tokens import _encode_jwt
  42. from flask_jwt_extended.typing import ExpiresDelta
  43. from flask_jwt_extended.typing import Fresh
  44. from flask_jwt_extended.utils import current_user_context_processor
  45. class JWTManager(object):
  46. """
  47. An object used to hold JWT settings and callback functions for the
  48. Flask-JWT-Extended extension.
  49. Instances of :class:`JWTManager` are *not* bound to specific apps, so
  50. you can create one in the main body of your code and then bind it
  51. to your app in a factory function.
  52. """
  53. def __init__(
  54. self, app: Optional[Flask] = None, add_context_processor: bool = False
  55. ) -> None:
  56. """
  57. Create the JWTManager instance. You can either pass a flask application
  58. in directly here to register this extension with the flask app, or
  59. call init_app after creating this object (in a factory pattern).
  60. :param app:
  61. The Flask Application object
  62. :param add_context_processor:
  63. Controls if `current_user` is should be added to flasks template
  64. context (and thus be available for use in Jinja templates). Defaults
  65. to ``False``.
  66. """
  67. # Register the default error handler callback methods. These can be
  68. # overridden with the appropriate loader decorators
  69. self._decode_key_callback = default_decode_key_callback
  70. self._encode_key_callback = default_encode_key_callback
  71. self._expired_token_callback = default_expired_token_callback
  72. self._invalid_token_callback = default_invalid_token_callback
  73. self._jwt_additional_header_callback = default_jwt_headers_callback
  74. self._needs_fresh_token_callback = default_needs_fresh_token_callback
  75. self._revoked_token_callback = default_revoked_token_callback
  76. self._token_in_blocklist_callback = default_blocklist_callback
  77. self._token_verification_callback = default_token_verification_callback
  78. self._unauthorized_callback = default_unauthorized_callback
  79. self._user_claims_callback = default_additional_claims_callback
  80. self._user_identity_callback = default_user_identity_callback
  81. self._user_lookup_callback: Optional[Callable] = None
  82. self._user_lookup_error_callback = default_user_lookup_error_callback
  83. self._token_verification_failed_callback = (
  84. default_token_verification_failed_callback
  85. )
  86. # Register this extension with the flask app now (if it is provided)
  87. if app is not None:
  88. self.init_app(app, add_context_processor)
  89. def init_app(self, app: Flask, add_context_processor: bool = False) -> None:
  90. """
  91. Register this extension with the flask app.
  92. :param app:
  93. The Flask Application object
  94. :param add_context_processor:
  95. Controls if `current_user` is should be added to flasks template
  96. context (and thus be available for use in Jinja templates). Defaults
  97. to ``False``.
  98. """
  99. # Save this so we can use it later in the extension
  100. if not hasattr(app, "extensions"): # pragma: no cover
  101. app.extensions = {}
  102. app.extensions["flask-jwt-extended"] = self
  103. if add_context_processor:
  104. app.context_processor(current_user_context_processor)
  105. # Set all the default configurations for this extension
  106. self._set_default_configuration_options(app)
  107. self._set_error_handler_callbacks(app)
  108. def _set_error_handler_callbacks(self, app: Flask) -> None:
  109. @app.errorhandler(CSRFError)
  110. def handle_csrf_error(e):
  111. return self._unauthorized_callback(str(e))
  112. @app.errorhandler(DecodeError)
  113. def handle_decode_error(e):
  114. return self._invalid_token_callback(str(e))
  115. @app.errorhandler(ExpiredSignatureError)
  116. def handle_expired_error(e):
  117. return self._expired_token_callback(e.jwt_header, e.jwt_data)
  118. @app.errorhandler(FreshTokenRequired)
  119. def handle_fresh_token_required(e):
  120. return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data)
  121. @app.errorhandler(MissingRequiredClaimError)
  122. def handle_missing_required_claim_error(e):
  123. return self._invalid_token_callback(str(e))
  124. @app.errorhandler(InvalidAudienceError)
  125. def handle_invalid_audience_error(e):
  126. return self._invalid_token_callback(str(e))
  127. @app.errorhandler(InvalidIssuerError)
  128. def handle_invalid_issuer_error(e):
  129. return self._invalid_token_callback(str(e))
  130. @app.errorhandler(InvalidHeaderError)
  131. def handle_invalid_header_error(e):
  132. return self._invalid_token_callback(str(e))
  133. @app.errorhandler(InvalidTokenError)
  134. def handle_invalid_token_error(e):
  135. return self._invalid_token_callback(str(e))
  136. @app.errorhandler(JWTDecodeError)
  137. def handle_jwt_decode_error(e):
  138. return self._invalid_token_callback(str(e))
  139. @app.errorhandler(NoAuthorizationError)
  140. def handle_auth_error(e):
  141. return self._unauthorized_callback(str(e))
  142. @app.errorhandler(InvalidQueryParamError)
  143. def handle_invalid_query_param_error(e):
  144. return self._invalid_token_callback(str(e))
  145. @app.errorhandler(RevokedTokenError)
  146. def handle_revoked_token_error(e):
  147. return self._revoked_token_callback(e.jwt_header, e.jwt_data)
  148. @app.errorhandler(UserClaimsVerificationError)
  149. def handle_failed_token_verification(e):
  150. return self._token_verification_failed_callback(e.jwt_header, e.jwt_data)
  151. @app.errorhandler(UserLookupError)
  152. def handler_user_lookup_error(e):
  153. return self._user_lookup_error_callback(e.jwt_header, e.jwt_data)
  154. @app.errorhandler(WrongTokenError)
  155. def handle_wrong_token_error(e):
  156. return self._invalid_token_callback(str(e))
  157. @staticmethod
  158. def _set_default_configuration_options(app: Flask) -> None:
  159. app.config.setdefault(
  160. "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15)
  161. )
  162. app.config.setdefault("JWT_ACCESS_COOKIE_NAME", "access_token_cookie")
  163. app.config.setdefault("JWT_ACCESS_COOKIE_PATH", "/")
  164. app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_NAME", "csrf_access_token")
  165. app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_PATH", "/")
  166. app.config.setdefault("JWT_ACCESS_CSRF_FIELD_NAME", "csrf_token")
  167. app.config.setdefault("JWT_ACCESS_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
  168. app.config.setdefault("JWT_ALGORITHM", "HS256")
  169. app.config.setdefault("JWT_COOKIE_CSRF_PROTECT", True)
  170. app.config.setdefault("JWT_COOKIE_DOMAIN", None)
  171. app.config.setdefault("JWT_COOKIE_SAMESITE", None)
  172. app.config.setdefault("JWT_COOKIE_SECURE", False)
  173. app.config.setdefault("JWT_CSRF_CHECK_FORM", False)
  174. app.config.setdefault("JWT_CSRF_IN_COOKIES", True)
  175. app.config.setdefault("JWT_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
  176. app.config.setdefault("JWT_DECODE_ALGORITHMS", None)
  177. app.config.setdefault("JWT_DECODE_AUDIENCE", None)
  178. app.config.setdefault("JWT_DECODE_ISSUER", None)
  179. app.config.setdefault("JWT_DECODE_LEEWAY", 0)
  180. app.config.setdefault("JWT_ENCODE_AUDIENCE", None)
  181. app.config.setdefault("JWT_ENCODE_ISSUER", None)
  182. app.config.setdefault("JWT_ERROR_MESSAGE_KEY", "msg")
  183. app.config.setdefault("JWT_HEADER_NAME", "Authorization")
  184. app.config.setdefault("JWT_HEADER_TYPE", "Bearer")
  185. app.config.setdefault("JWT_IDENTITY_CLAIM", "sub")
  186. app.config.setdefault("JWT_JSON_KEY", "access_token")
  187. app.config.setdefault("JWT_PRIVATE_KEY", None)
  188. app.config.setdefault("JWT_PUBLIC_KEY", None)
  189. app.config.setdefault("JWT_QUERY_STRING_NAME", "jwt")
  190. app.config.setdefault("JWT_QUERY_STRING_VALUE_PREFIX", "")
  191. app.config.setdefault("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie")
  192. app.config.setdefault("JWT_REFRESH_COOKIE_PATH", "/")
  193. app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_NAME", "csrf_refresh_token")
  194. app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_PATH", "/")
  195. app.config.setdefault("JWT_REFRESH_CSRF_FIELD_NAME", "csrf_token")
  196. app.config.setdefault("JWT_REFRESH_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
  197. app.config.setdefault("JWT_REFRESH_JSON_KEY", "refresh_token")
  198. app.config.setdefault("JWT_REFRESH_TOKEN_EXPIRES", datetime.timedelta(days=30))
  199. app.config.setdefault("JWT_SECRET_KEY", None)
  200. app.config.setdefault("JWT_SESSION_COOKIE", True)
  201. app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",))
  202. app.config.setdefault("JWT_VERIFY_SUB", True)
  203. app.config.setdefault("JWT_ENCODE_NBF", True)
  204. def additional_claims_loader(self, callback: Callable) -> Callable:
  205. """
  206. This decorator sets the callback function used to add additional claims
  207. when creating a JWT. The claims returned by this function will be merged
  208. with any claims passed in via the ``additional_claims`` argument to
  209. :func:`~flask_jwt_extended.create_access_token` or
  210. :func:`~flask_jwt_extended.create_refresh_token`.
  211. The decorated function must take **one** argument.
  212. The argument is the identity that was used when creating a JWT.
  213. The decorated function must return a dictionary of claims to add to the JWT.
  214. """
  215. self._user_claims_callback = callback
  216. return callback
  217. def additional_headers_loader(self, callback: Callable) -> Callable:
  218. """
  219. This decorator sets the callback function used to add additional headers
  220. when creating a JWT. The headers returned by this function will be merged
  221. with any headers passed in via the ``additional_headers`` argument to
  222. :func:`~flask_jwt_extended.create_access_token` or
  223. :func:`~flask_jwt_extended.create_refresh_token`.
  224. The decorated function must take **one** argument.
  225. The argument is the identity that was used when creating a JWT.
  226. The decorated function must return a dictionary of headers to add to the JWT.
  227. """
  228. self._jwt_additional_header_callback = callback
  229. return callback
  230. def decode_key_loader(self, callback: Callable) -> Callable:
  231. """
  232. This decorator sets the callback function for dynamically setting the JWT
  233. decode key based on the **UNVERIFIED** contents of the token. Think
  234. carefully before using this functionality, in most cases you probably
  235. don't need it.
  236. The decorated function must take **two** arguments.
  237. The first argument is a dictionary containing the header data of the
  238. unverified JWT.
  239. The second argument is a dictionary containing the payload data of the
  240. unverified JWT.
  241. The decorated function must return a *string* that is used to decode and
  242. verify the token.
  243. """
  244. self._decode_key_callback = callback
  245. return callback
  246. def encode_key_loader(self, callback: Callable) -> Callable:
  247. """
  248. This decorator sets the callback function for dynamically setting the JWT
  249. encode key based on the tokens identity. Think carefully before using this
  250. functionality, in most cases you probably don't need it.
  251. The decorated function must take **one** argument.
  252. The argument is the identity used to create this JWT.
  253. The decorated function must return a *string* which is the secrete key used to
  254. encode the JWT.
  255. """
  256. self._encode_key_callback = callback
  257. return callback
  258. def expired_token_loader(self, callback: Callable) -> Callable:
  259. """
  260. This decorator sets the callback function for returning a custom
  261. response when an expired JWT is encountered.
  262. The decorated function must take **two** arguments.
  263. The first argument is a dictionary containing the header data of the JWT.
  264. The second argument is a dictionary containing the payload data of the JWT.
  265. The decorated function must return a Flask Response.
  266. """
  267. self._expired_token_callback = callback
  268. return callback
  269. def invalid_token_loader(self, callback: Callable) -> Callable:
  270. """
  271. This decorator sets the callback function for returning a custom
  272. response when an invalid JWT is encountered.
  273. This decorator sets the callback function that will be used if an
  274. invalid JWT attempts to access a protected endpoint.
  275. The decorated function must take **one** argument.
  276. The argument is a string which contains the reason why a token is invalid.
  277. The decorated function must return a Flask Response.
  278. """
  279. self._invalid_token_callback = callback
  280. return callback
  281. def needs_fresh_token_loader(self, callback: Callable) -> Callable:
  282. """
  283. This decorator sets the callback function for returning a custom
  284. response when a valid and non-fresh token is used on an endpoint
  285. that is marked as ``fresh=True``.
  286. The decorated function must take **two** arguments.
  287. The first argument is a dictionary containing the header data of the JWT.
  288. The second argument is a dictionary containing the payload data of the JWT.
  289. The decorated function must return a Flask Response.
  290. """
  291. self._needs_fresh_token_callback = callback
  292. return callback
  293. def revoked_token_loader(self, callback: Callable) -> Callable:
  294. """
  295. This decorator sets the callback function for returning a custom
  296. response when a revoked token is encountered.
  297. The decorated function must take **two** arguments.
  298. The first argument is a dictionary containing the header data of the JWT.
  299. The second argument is a dictionary containing the payload data of the JWT.
  300. The decorated function must return a Flask Response.
  301. """
  302. self._revoked_token_callback = callback
  303. return callback
  304. def token_in_blocklist_loader(self, callback: Callable) -> Callable:
  305. """
  306. This decorator sets the callback function used to check if a JWT has
  307. been revoked.
  308. The decorated function must take **two** arguments.
  309. The first argument is a dictionary containing the header data of the JWT.
  310. The second argument is a dictionary containing the payload data of the JWT.
  311. The decorated function must be return ``True`` if the token has been
  312. revoked, ``False`` otherwise.
  313. """
  314. self._token_in_blocklist_callback = callback
  315. return callback
  316. def token_verification_failed_loader(self, callback: Callable) -> Callable:
  317. """
  318. This decorator sets the callback function used to return a custom
  319. response when the claims verification check fails.
  320. The decorated function must take **two** arguments.
  321. The first argument is a dictionary containing the header data of the JWT.
  322. The second argument is a dictionary containing the payload data of the JWT.
  323. The decorated function must return a Flask Response.
  324. """
  325. self._token_verification_failed_callback = callback
  326. return callback
  327. def token_verification_loader(self, callback: Callable) -> Callable:
  328. """
  329. This decorator sets the callback function used for custom verification
  330. of a valid JWT.
  331. The decorated function must take **two** arguments.
  332. The first argument is a dictionary containing the header data of the JWT.
  333. The second argument is a dictionary containing the payload data of the JWT.
  334. The decorated function must return ``True`` if the token is valid, or
  335. ``False`` otherwise.
  336. """
  337. self._token_verification_callback = callback
  338. return callback
  339. def unauthorized_loader(self, callback: Callable) -> Callable:
  340. """
  341. This decorator sets the callback function used to return a custom
  342. response when no JWT is present.
  343. The decorated function must take **one** argument.
  344. The argument is a string that explains why the JWT could not be found.
  345. The decorated function must return a Flask Response.
  346. """
  347. self._unauthorized_callback = callback
  348. return callback
  349. def user_identity_loader(self, callback: Callable) -> Callable:
  350. """
  351. This decorator sets the callback function used to convert an identity to
  352. a string when creating JWTs. This is useful for using objects (such as
  353. SQLAlchemy instances) as the identity when creating your tokens.
  354. The decorated function must take **one** argument.
  355. The argument is the identity that was used when creating a JWT.
  356. The decorated function must return a string.
  357. """
  358. self._user_identity_callback = callback
  359. return callback
  360. def user_lookup_loader(self, callback: Callable) -> Callable:
  361. """
  362. This decorator sets the callback function used to convert a JWT into
  363. a python object that can be used in a protected endpoint. This is useful
  364. for automatically loading a SQLAlchemy instance based on the contents
  365. of the JWT.
  366. The object returned from this function can be accessed via
  367. :attr:`~flask_jwt_extended.current_user` or
  368. :meth:`~flask_jwt_extended.get_current_user`
  369. The decorated function must take **two** arguments.
  370. The first argument is a dictionary containing the header data of the JWT.
  371. The second argument is a dictionary containing the payload data of the JWT.
  372. The decorated function can return any python object, which can then be
  373. accessed in a protected endpoint. If an object cannot be loaded, for
  374. example if a user has been deleted from your database, ``None`` must be
  375. returned to indicate that an error occurred loading the user.
  376. """
  377. self._user_lookup_callback = callback
  378. return callback
  379. def user_lookup_error_loader(self, callback: Callable) -> Callable:
  380. """
  381. This decorator sets the callback function used to return a custom
  382. response when loading a user via
  383. :meth:`~flask_jwt_extended.JWTManager.user_lookup_loader` fails.
  384. The decorated function must take **two** arguments.
  385. The first argument is a dictionary containing the header data of the JWT.
  386. The second argument is a dictionary containing the payload data of the JWT.
  387. The decorated function must return a Flask Response.
  388. """
  389. self._user_lookup_error_callback = callback
  390. return callback
  391. def _encode_jwt_from_config(
  392. self,
  393. identity: Any,
  394. token_type: str,
  395. claims=None,
  396. fresh: Fresh = False,
  397. expires_delta: Optional[ExpiresDelta] = None,
  398. headers=None,
  399. ) -> str:
  400. header_overrides = self._jwt_additional_header_callback(identity)
  401. if headers is not None:
  402. header_overrides.update(headers)
  403. claim_overrides = self._user_claims_callback(identity)
  404. if claims is not None:
  405. claim_overrides.update(claims)
  406. if expires_delta is None:
  407. if token_type == "access":
  408. expires_delta = config.access_expires
  409. else:
  410. expires_delta = config.refresh_expires
  411. return _encode_jwt(
  412. algorithm=config.algorithm,
  413. audience=config.encode_audience,
  414. claim_overrides=claim_overrides,
  415. csrf=config.cookie_csrf_protect,
  416. expires_delta=expires_delta,
  417. fresh=fresh,
  418. header_overrides=header_overrides,
  419. identity=self._user_identity_callback(identity),
  420. identity_claim_key=config.identity_claim_key,
  421. issuer=config.encode_issuer,
  422. json_encoder=config.json_encoder,
  423. secret=self._encode_key_callback(identity),
  424. token_type=token_type,
  425. nbf=config.encode_nbf,
  426. )
  427. def _decode_jwt_from_config(
  428. self, encoded_token: str, csrf_value=None, allow_expired: bool = False
  429. ) -> dict:
  430. unverified_claims = jwt.decode(
  431. encoded_token,
  432. algorithms=config.decode_algorithms,
  433. options={"verify_signature": False},
  434. )
  435. unverified_headers = jwt.get_unverified_header(encoded_token)
  436. secret = self._decode_key_callback(unverified_headers, unverified_claims)
  437. kwargs = {
  438. "algorithms": config.decode_algorithms,
  439. "audience": config.decode_audience,
  440. "csrf_value": csrf_value,
  441. "encoded_token": encoded_token,
  442. "identity_claim_key": config.identity_claim_key,
  443. "issuer": config.decode_issuer,
  444. "leeway": config.leeway,
  445. "secret": secret,
  446. "verify_aud": config.decode_audience is not None,
  447. "verify_sub": config.verify_sub,
  448. }
  449. try:
  450. return _decode_jwt(**kwargs, allow_expired=allow_expired)
  451. except ExpiredSignatureError as e:
  452. # TODO: If we ever do another breaking change, don't raise this pyjwt
  453. # error directly, instead raise a custom error of ours from this
  454. # error.
  455. e.jwt_header = unverified_headers # type: ignore
  456. e.jwt_data = _decode_jwt(**kwargs, allow_expired=True) # type: ignore
  457. raise