manager.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. from __future__ import annotations
  2. import itertools
  3. import logging
  4. from collections.abc import Iterable
  5. import flask
  6. from ordered_set import OrderedSet
  7. from .constants import ExemptionScope
  8. from .util import get_qualified_name
  9. from .wrappers import Limit, LimitGroup
  10. class LimitManager:
  11. def __init__(
  12. self,
  13. application_limits: list[LimitGroup],
  14. default_limits: list[LimitGroup],
  15. decorated_limits: dict[str, OrderedSet[LimitGroup]],
  16. blueprint_limits: dict[str, OrderedSet[LimitGroup]],
  17. route_exemptions: dict[str, ExemptionScope],
  18. blueprint_exemptions: dict[str, ExemptionScope],
  19. ) -> None:
  20. self._application_limits = application_limits
  21. self._default_limits = default_limits
  22. self._decorated_limits = decorated_limits
  23. self._blueprint_limits = blueprint_limits
  24. self._route_exemptions = route_exemptions
  25. self._blueprint_exemptions = blueprint_exemptions
  26. self._endpoint_hints: dict[str, OrderedSet[str]] = {}
  27. self._logger = logging.getLogger("flask-limiter")
  28. @property
  29. def application_limits(self) -> list[Limit]:
  30. return list(itertools.chain(*self._application_limits))
  31. @property
  32. def default_limits(self) -> list[Limit]:
  33. return list(itertools.chain(*self._default_limits))
  34. def set_application_limits(self, limits: list[LimitGroup]) -> None:
  35. self._application_limits = limits
  36. def set_default_limits(self, limits: list[LimitGroup]) -> None:
  37. self._default_limits = limits
  38. def add_decorated_limit(
  39. self, route: str, limit: LimitGroup | None, override: bool = False
  40. ) -> None:
  41. if limit:
  42. if not override:
  43. self._decorated_limits.setdefault(route, OrderedSet()).add(limit)
  44. else:
  45. self._decorated_limits[route] = OrderedSet([limit])
  46. def add_blueprint_limit(self, blueprint: str, limit: LimitGroup | None) -> None:
  47. if limit:
  48. self._blueprint_limits.setdefault(blueprint, OrderedSet()).add(limit)
  49. def add_route_exemption(self, route: str, scope: ExemptionScope) -> None:
  50. self._route_exemptions[route] = scope
  51. def add_blueprint_exemption(self, blueprint: str, scope: ExemptionScope) -> None:
  52. self._blueprint_exemptions[blueprint] = scope
  53. def add_endpoint_hint(self, endpoint: str, callable: str) -> None:
  54. self._endpoint_hints.setdefault(endpoint, OrderedSet()).add(callable)
  55. def has_hints(self, endpoint: str) -> bool:
  56. return bool(self._endpoint_hints.get(endpoint))
  57. def resolve_limits(
  58. self,
  59. app: flask.Flask,
  60. endpoint: str | None = None,
  61. blueprint: str | None = None,
  62. callable_name: str | None = None,
  63. in_middleware: bool = False,
  64. marked_for_limiting: bool = False,
  65. ) -> tuple[list[Limit], ...]:
  66. before_request_context = in_middleware and marked_for_limiting
  67. decorated_limits = []
  68. hinted_limits = []
  69. if endpoint:
  70. if not in_middleware:
  71. if not callable_name:
  72. view_func = app.view_functions.get(endpoint, None)
  73. name = get_qualified_name(view_func) if view_func else ""
  74. else:
  75. name = callable_name
  76. decorated_limits.extend(self.decorated_limits(name))
  77. for hint in self._endpoint_hints.get(endpoint, OrderedSet()):
  78. hinted_limits.extend(self.decorated_limits(hint))
  79. if blueprint:
  80. if not before_request_context and (
  81. not decorated_limits
  82. or all(not limit.override_defaults for limit in decorated_limits)
  83. ):
  84. decorated_limits.extend(self.blueprint_limits(app, blueprint))
  85. exemption_scope = self.exemption_scope(app, endpoint, blueprint)
  86. all_limits = (
  87. self.application_limits
  88. if in_middleware and not (exemption_scope & ExemptionScope.APPLICATION)
  89. else []
  90. )
  91. # all_limits += decorated_limits
  92. explicit_limits_exempt = all(limit.method_exempt for limit in decorated_limits)
  93. # all the decorated limits explicitly declared
  94. # that they don't override the defaults - so, they should
  95. # be included.
  96. combined_defaults = all(
  97. not limit.override_defaults for limit in decorated_limits
  98. )
  99. # previous requests to this endpoint have exercised decorated
  100. # rate limits on callables that are not view functions. check
  101. # if all of them declared that they don't override defaults
  102. # and if so include the default limits.
  103. hinted_limits_request_defaults = (
  104. all(not limit.override_defaults for limit in hinted_limits)
  105. if hinted_limits
  106. else False
  107. )
  108. if (
  109. (explicit_limits_exempt or combined_defaults)
  110. and (
  111. not (before_request_context or exemption_scope & ExemptionScope.DEFAULT)
  112. )
  113. ) or hinted_limits_request_defaults:
  114. all_limits += self.default_limits
  115. return all_limits, decorated_limits
  116. def exemption_scope(
  117. self, app: flask.Flask, endpoint: str | None, blueprint: str | None
  118. ) -> ExemptionScope:
  119. view_func = app.view_functions.get(endpoint or "", None)
  120. name = get_qualified_name(view_func) if view_func else ""
  121. route_exemption_scope = self._route_exemptions.get(name, ExemptionScope.NONE)
  122. blueprint_instance = app.blueprints.get(blueprint) if blueprint else None
  123. if not blueprint_instance:
  124. return route_exemption_scope
  125. else:
  126. assert blueprint
  127. (
  128. blueprint_exemption_scope,
  129. ancestor_exemption_scopes,
  130. ) = self._blueprint_exemption_scope(app, blueprint)
  131. if (
  132. blueprint_exemption_scope
  133. & ~(ExemptionScope.DEFAULT | ExemptionScope.APPLICATION)
  134. or ancestor_exemption_scopes
  135. ):
  136. for exemption in ancestor_exemption_scopes.values():
  137. blueprint_exemption_scope |= exemption
  138. return route_exemption_scope | blueprint_exemption_scope
  139. def decorated_limits(self, callable_name: str) -> list[Limit]:
  140. limits = []
  141. if not self._route_exemptions.get(callable_name, ExemptionScope.NONE):
  142. if callable_name in self._decorated_limits:
  143. for group in self._decorated_limits[callable_name]:
  144. try:
  145. for limit in group:
  146. limits.append(limit)
  147. except ValueError as e:
  148. self._logger.error(
  149. f"failed to load ratelimit for function {callable_name}: {e}",
  150. )
  151. return limits
  152. def blueprint_limits(self, app: flask.Flask, blueprint: str) -> list[Limit]:
  153. limits: list[Limit] = []
  154. blueprint_instance = app.blueprints.get(blueprint) if blueprint else None
  155. if blueprint_instance:
  156. blueprint_name = blueprint_instance.name
  157. blueprint_ancestory = set(blueprint.split(".") if blueprint else [])
  158. self_exemption, ancestor_exemptions = self._blueprint_exemption_scope(
  159. app, blueprint
  160. )
  161. if not (
  162. self_exemption & ~(ExemptionScope.DEFAULT | ExemptionScope.APPLICATION)
  163. ):
  164. blueprint_self_limits = self._blueprint_limits.get(
  165. blueprint_name, OrderedSet()
  166. )
  167. blueprint_limits: Iterable[LimitGroup] = (
  168. itertools.chain(
  169. *(
  170. self._blueprint_limits.get(member, [])
  171. for member in blueprint_ancestory.intersection(
  172. self._blueprint_limits
  173. ).difference(ancestor_exemptions)
  174. )
  175. )
  176. if not (
  177. blueprint_self_limits
  178. and all(
  179. limit.override_defaults for limit in blueprint_self_limits
  180. )
  181. )
  182. and not self._blueprint_exemptions.get(
  183. blueprint_name, ExemptionScope.NONE
  184. )
  185. & ExemptionScope.ANCESTORS
  186. else blueprint_self_limits
  187. )
  188. if blueprint_limits:
  189. for limit_group in blueprint_limits:
  190. try:
  191. limits.extend(
  192. [
  193. Limit(
  194. limit.limit,
  195. limit.key_func,
  196. limit.scope,
  197. limit.per_method,
  198. limit.methods,
  199. limit.error_message,
  200. limit.exempt_when,
  201. limit.override_defaults,
  202. limit.deduct_when,
  203. limit.on_breach,
  204. limit.cost,
  205. limit.shared,
  206. )
  207. for limit in limit_group
  208. ]
  209. )
  210. except ValueError as e:
  211. self._logger.error(
  212. f"failed to load ratelimit for blueprint {blueprint_name}: {e}",
  213. )
  214. return limits
  215. def _blueprint_exemption_scope(
  216. self, app: flask.Flask, blueprint_name: str
  217. ) -> tuple[ExemptionScope, dict[str, ExemptionScope]]:
  218. name = app.blueprints[blueprint_name].name
  219. exemption = self._blueprint_exemptions.get(name, ExemptionScope.NONE) & ~(
  220. ExemptionScope.ANCESTORS
  221. )
  222. ancestory = set(blueprint_name.split("."))
  223. ancestor_exemption = {
  224. k
  225. for k, f in self._blueprint_exemptions.items()
  226. if f & ExemptionScope.DESCENDENTS
  227. }.intersection(ancestory)
  228. return exemption, {
  229. k: self._blueprint_exemptions.get(k, ExemptionScope.NONE)
  230. for k in ancestor_exemption
  231. }