123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611 |
- from __future__ import annotations
- import itertools
- import time
- from functools import partial
- from typing import Any
- from urllib.parse import urlparse
- import click
- from flask import Flask, current_app
- from flask.cli import with_appcontext
- from limits.strategies import RateLimiter
- from rich.console import Console, group
- from rich.live import Live
- from rich.pretty import Pretty
- from rich.prompt import Confirm
- from rich.table import Table
- from rich.theme import Theme
- from rich.tree import Tree
- from typing_extensions import TypedDict
- from werkzeug.exceptions import MethodNotAllowed, NotFound
- from werkzeug.routing import Rule
- from flask_limiter import Limiter
- from flask_limiter.constants import ConfigVars, ExemptionScope, HeaderNames
- from flask_limiter.typing import Callable, Generator, cast
- from flask_limiter.util import get_qualified_name
- from flask_limiter.wrappers import Limit
# Rich theme used by every console in this module. Each key is a markup tag
# used by the render helpers below (e.g. ``[success]``, ``[entity]``,
# ``[callable]``, ``[http]``, ``[option]``), including the limit-source tags
# ``blueprint`` / ``route`` / ``default`` emitted by ``render_limits``.
limiter_theme = Theme(
    dict(
        success="bold green",
        danger="bold red",
        error="bold red",
        blueprint="bold red",
        default="magenta",
        callable="cyan",
        entity="magenta",
        exempt="bold red",
        route="yellow",
        http="bold green",
        option="bold yellow",
    )
)
def render_func(func: Any) -> str | Pretty:
    """Render a configured value that is usually a callable.

    Callables are rendered as ``module.name()`` (lambdas as
    ``<lambda>(module)``) wrapped in ``[callable]`` markup. Any other value is
    wrapped in :class:`rich.pretty.Pretty`.

    Callables that lack ``__name__`` or ``__module__`` (``functools.partial``
    objects, instances defining ``__call__``, some C-level descriptors) fall
    back to their type's name and module instead of raising AttributeError.
    """
    if callable(func):
        name = getattr(func, "__name__", None) or type(func).__name__
        module = getattr(func, "__module__", None) or type(func).__module__
        if name == "<lambda>":
            return f"[callable]<lambda>({module})[/callable]"
        return f"[callable]{module}.{name}()[/callable]"
    return Pretty(func)
def render_storage(ext: Limiter) -> Tree:
    """Describe the storage backend configured on ``ext`` as a tree.

    The root node is the configured storage URI (``N/A`` when unset). When a
    storage instance is present, child nodes show its class name, its
    ``storage`` attribute, the configured storage options and the outcome of
    ``check()`` (``OK`` / ``Error``).
    """
    tree = Tree(ext._storage_uri or "N/A")
    backend = ext.storage
    if not backend:
        return tree
    tree.add(f"[entity]{backend.__class__.__name__}[/entity]")
    tree.add(f"[entity]{backend.storage}[/entity]")  # type: ignore
    tree.add(Pretty(ext._storage_options or {}))
    healthy = backend.check()
    tree.add("[success]OK[/success]" if healthy else "[error]Error[/error]")
    return tree
def render_strategy(strategy: RateLimiter) -> str:
    """Return the class name of the rate limiting strategy in ``[entity]`` markup."""
    class_name = strategy.__class__.__name__
    return "[entity]" + class_name + "[/entity]"
def render_limit_state(
    limiter: Limiter, endpoint: str, limit: Limit, key: str, method: str
) -> str:
    """Test ``limit`` for ``key`` and describe the result as a markup suffix.

    The limit is identified by ``key`` plus the scope the limit computes for
    ``endpoint`` and ``method``. Returns an error suffix if the limiter has no
    storage or the storage check fails. Otherwise returns ``Pass``/``Fail``
    together with the remaining count from the strategy's window stats.
    """
    identifiers = (key, limit.scope_for(endpoint, method))
    if not limiter.storage or not limiter.storage.check():
        return ": [error]Storage not available[/error]"
    strategy = limiter.limiter
    passed = strategy.test(limit.limit, *identifiers)
    remaining = strategy.get_window_stats(limit.limit, *identifiers)[1]
    verdict = "[success]Pass[/success]" if passed else "[error]Fail[/error]"
    return f": {verdict} ({remaining} out of {limit.limit.amount} remaining)"
def render_limit(limit: Limit, simple: bool = True) -> str:
    """Render a limit as text.

    With ``simple=True`` only ``str(limit.limit)`` is returned. Otherwise any
    configured ``deduct_when`` / ``exempt_when`` callbacks are appended inside
    ``[option]{...}[/option]`` markup.
    """
    text = str(limit.limit)
    if simple:
        return text
    hooks = (
        ("deduct_when", limit.deduct_when),
        ("exempt_when", limit.exempt_when),
    )
    extras = [f"{label}: {render_func(hook)}" for label, hook in hooks if hook]
    if not extras:
        return text
    joined = ", ".join(extras)
    return f"{text} [option]{{{joined}}}[/option]"
def render_limits(
    app: Flask,
    limiter: Limiter,
    limits: tuple[list[Limit], ...],
    endpoint: str | None = None,
    blueprint: str | None = None,
    rule: Rule | None = None,
    exemption_scope: ExemptionScope = ExemptionScope.NONE,
    test: str | None = None,
    method: str = "GET",
    label: str | None = "",
) -> Tree:
    """Render a group of limits as a rich tree.

    :param app: the Flask application the limits belong to.
    :param limiter: the extension instance, used to classify and test limits.
    :param limits: tuple of limit lists. The first two entries are concatenated
     and rendered.
    :param endpoint: endpoint name. When given, each limit is tagged as coming
     from its blueprint, from a route decorator, or from the defaults.
    :param blueprint: blueprint the endpoint belongs to, if any.
    :param rule: URL rule. Used for the tree label and to expand per-method
     limits into one entry per HTTP method the rule accepts.
    :param exemption_scope: when truthy and no entries were produced, the tree
     shows a single ``Exempt`` node.
    :param test: if set, the key each limit is tested against. The state is
     appended to every entry.
    :param method: HTTP method used when testing limits that are not expanded
     per method.
    :param label: tree label used unless both ``endpoint`` and ``rule`` are set.
    :return: the populated :class:`rich.tree.Tree`.
    """
    if rule and endpoint:
        label = f"{endpoint}: {rule}"
    renderable = Tree(label or "")

    # The view function only depends on the endpoint, so resolve it once
    # instead of once per limit.
    view_func = app.view_functions.get(endpoint, None) if endpoint else None
    qualified_name = get_qualified_name(view_func) if view_func else ""

    entries: list[str] = []
    for limit in limits[0] + limits[1]:
        if not endpoint:
            source = "default"
        elif blueprint and limit in limiter.limit_manager.blueprint_limits(
            app, blueprint
        ):
            source = "blueprint"
        elif limit in limiter.limit_manager.decorated_limits(qualified_name):
            source = "route"
        else:
            source = "default"

        rendered = render_limit(limit, False)
        if limit.per_method and rule and rule.methods:
            # Use a distinct loop name: rebinding ``method`` here would make
            # every following non-per-method limit be tested with the last
            # rule method instead of the requested one. Sorting keeps the
            # output stable across runs.
            for rule_method in sorted(rule.methods):
                entry = (
                    f"[{source}]{rendered} [http]({rule_method})[/http][/{source}]"
                )
                if test:
                    entry += render_limit_state(
                        limiter, endpoint or "", limit, test, rule_method
                    )
                entries.append(entry)
        else:
            entry = f"[{source}]{rendered}[/{source}]"
            if test:
                entry += render_limit_state(
                    limiter, endpoint or "", limit, test, method
                )
            entries.append(entry)

    if not entries and exemption_scope:
        renderable.add("[exempt]Exempt[/exempt]")
    else:
        for entry in entries:
            renderable.add(entry)
    return renderable
def get_filtered_endpoint(
    app: Flask,
    console: Console,
    endpoint: str | None,
    path: str | None,
    method: str | None = None,
) -> str | None:
    """Resolve the endpoint a CLI command should be restricted to.

    :param app: application whose view functions / URL map are consulted.
    :param console: console used to report resolution errors.
    :param endpoint: explicit endpoint name. Takes precedence over ``path``.
    :param path: URL path (optionally with a query string) to match against
     ``app.url_map``.
    :param method: HTTP method used when matching ``path``. When ``None``,
     the URL adapter's default method is used.
    :return: the endpoint name, or ``None`` when neither ``endpoint`` nor
     ``path`` was given.
    :raises SystemExit: with exit status 1 when the endpoint is unknown or the
     path cannot be matched.
    """
    if not (endpoint or path):
        return None
    if endpoint:
        if endpoint in app.view_functions:
            return endpoint
        console.print(f"[red]Error: {endpoint} not found")
    elif path:
        adapter = app.url_map.bind("dev.null")
        parsed = urlparse(path)
        try:
            filter_endpoint, _ = adapter.match(
                parsed.path, method=method, query_args=parsed.query
            )
            return cast(str, filter_endpoint)
        except NotFound:
            console.print(
                f"[error]Error: {path} could not be matched to an endpoint[/error]"
            )
        except MethodNotAllowed:
            # ``method`` may be None here: the adapter then matched with its
            # default method (GET), which can still be disallowed. Report that
            # instead of failing on an assertion.
            console.print(
                f"[error]Error: {(method or 'GET').upper()}: {path}"
                " could not be matched to an endpoint[/error]"
            )
    # Non-zero status so scripts can detect the failure.
    raise SystemExit(1)
@click.group(help="Flask-Limiter maintenance & utility commands")
def cli() -> None:
    """Command group for the Flask-Limiter CLI.

    The group itself does nothing. The subcommands registered on it do the
    work.
    """
@cli.command(help="View the extension configuration")
@with_appcontext
def config() -> None:
    """Print a table describing the Flask-Limiter configuration of the app.

    Runs inside a test request context of the current app. Only one limiter is
    reported: the first item of ``current_app.extensions["limiter"]``. If no
    limiter is registered, an error message is printed instead.
    """
    with current_app.test_request_context():
        console = Console(theme=limiter_theme)
        limiters = list(current_app.extensions.get("limiter", set()))
        limiter = limiters[0] if limiters else None
        if not limiter:
            console.print(
                f"No Flask-Limiter extension installed on {current_app}",
                style="bold red",
            )
            return

        extension_details = Table(title="Flask-Limiter Config")
        extension_details.add_column("Notes")
        extension_details.add_column("Configuration")
        extension_details.add_column("Value")
        extension_details.add_row(
            "Enabled", ConfigVars.ENABLED, Pretty(limiter.enabled)
        )
        extension_details.add_row(
            "Key Function", ConfigVars.KEY_FUNC, render_func(limiter._key_func)
        )
        extension_details.add_row(
            "Key Prefix", ConfigVars.KEY_PREFIX, Pretty(limiter._key_prefix)
        )

        # Two parallel trees: the config variable names in one column and
        # their values (strategy, then storage details) in the next.
        limiter_config = Tree(ConfigVars.STRATEGY)
        limiter_config_values = Tree(render_strategy(limiter.limiter))
        node = limiter_config.add(ConfigVars.STORAGE_URI)
        node.add("Instance")
        node.add("Backend")
        limiter_config.add(ConfigVars.STORAGE_OPTIONS)
        limiter_config.add("Status")
        limiter_config_values.add(render_storage(limiter))
        extension_details.add_row(
            "Rate Limiting Config", limiter_config, limiter_config_values
        )

        manager = limiter.limit_manager
        if manager.application_limits:
            extension_details.add_row(
                "Application Limits",
                ConfigVars.APPLICATION_LIMITS,
                Pretty([render_limit(limit) for limit in manager.application_limits]),
            )
            extension_details.add_row(
                None,
                ConfigVars.APPLICATION_LIMITS_PER_METHOD,
                Pretty(limiter._application_limits_per_method),
            )
            extension_details.add_row(
                None,
                ConfigVars.APPLICATION_LIMITS_EXEMPT_WHEN,
                render_func(limiter._application_limits_exempt_when),
            )
            extension_details.add_row(
                None,
                ConfigVars.APPLICATION_LIMITS_DEDUCT_WHEN,
                render_func(limiter._application_limits_deduct_when),
            )
            # render_func shows a callable cost by name (and falls back to
            # Pretty for plain values), matching the default-limits cost row.
            extension_details.add_row(
                None,
                ConfigVars.APPLICATION_LIMITS_COST,
                render_func(limiter._application_limits_cost),
            )
        else:
            extension_details.add_row(
                "Application Limits",
                ConfigVars.APPLICATION_LIMITS,
                Pretty([]),
            )

        if manager.default_limits:
            extension_details.add_row(
                "Default Limits",
                ConfigVars.DEFAULT_LIMITS,
                Pretty([render_limit(limit) for limit in manager.default_limits]),
            )
            extension_details.add_row(
                None,
                ConfigVars.DEFAULT_LIMITS_PER_METHOD,
                Pretty(limiter._default_limits_per_method),
            )
            extension_details.add_row(
                None,
                ConfigVars.DEFAULT_LIMITS_EXEMPT_WHEN,
                render_func(limiter._default_limits_exempt_when),
            )
            extension_details.add_row(
                None,
                ConfigVars.DEFAULT_LIMITS_DEDUCT_WHEN,
                render_func(limiter._default_limits_deduct_when),
            )
            extension_details.add_row(
                None,
                ConfigVars.DEFAULT_LIMITS_COST,
                render_func(limiter._default_limits_cost),
            )
        else:
            extension_details.add_row(
                "Default Limits", ConfigVars.DEFAULT_LIMITS, Pretty([])
            )

        if limiter._meta_limits:
            extension_details.add_row(
                "Meta Limits",
                ConfigVars.META_LIMITS,
                Pretty(
                    [
                        render_limit(limit)
                        for limit in itertools.chain(*limiter._meta_limits)
                    ]
                ),
            )

        if limiter._headers_enabled:
            header_configs = Tree(ConfigVars.HEADERS_ENABLED)
            header_configs.add(ConfigVars.HEADER_RESET)
            header_configs.add(ConfigVars.HEADER_REMAINING)
            header_configs.add(ConfigVars.HEADER_RETRY_AFTER)
            header_configs.add(ConfigVars.HEADER_RETRY_AFTER_VALUE)
            header_values = Tree(Pretty(limiter._headers_enabled))
            header_values.add(Pretty(limiter._header_mapping[HeaderNames.RESET]))
            header_values.add(Pretty(limiter._header_mapping[HeaderNames.REMAINING]))
            header_values.add(
                Pretty(limiter._header_mapping[HeaderNames.RETRY_AFTER])
            )
            header_values.add(Pretty(limiter._retry_after))
            extension_details.add_row(
                "Header configuration",
                header_configs,
                header_values,
            )
        else:
            extension_details.add_row(
                "Header configuration", ConfigVars.HEADERS_ENABLED, Pretty(False)
            )

        extension_details.add_row(
            "Fail on first breach",
            ConfigVars.FAIL_ON_FIRST_BREACH,
            Pretty(limiter._fail_on_first_breach),
        )
        extension_details.add_row(
            "On breach callback",
            ConfigVars.ON_BREACH,
            render_func(limiter._on_breach),
        )
        console.print(extension_details)
@cli.command(help="Enumerate details about all routes with rate limits")
@click.option("--endpoint", default=None, help="Endpoint to filter by")
@click.option("--path", default=None, help="Path to filter by")
@click.option("--method", default=None, help="HTTP Method to filter by")
@click.option("--key", default=None, help="Test the limit")
@click.option("--watch/--no-watch", default=False, help="Create a live dashboard")
@with_appcontext
def limits(
    endpoint: str | None = None,
    path: str | None = None,
    method: str = "GET",
    key: str | None = None,
    watch: bool = False,
) -> None:
    """List the rate limits that apply to each (non-static) route.

    Routes can be narrowed with ``--endpoint`` / ``--path`` / ``--method``.
    With ``--key``, each limit is also tested against that key. With
    ``--watch``, the output is re-rendered in a full-screen live view until
    interrupted with Ctrl-C. Meta and application limits are shown only when
    no endpoint/path filter is given.
    """
    with current_app.test_request_context():
        limiters: set[Limiter] = current_app.extensions.get("limiter", set())
        limiter: Limiter | None = list(limiters)[0] if limiters else None
        console = Console(theme=limiter_theme)
        if not limiter:
            console.print(
                f"No Flask-Limiter extension installed on {current_app}",
                style="bold red",
            )
            return

        manager = limiter.limit_manager
        # Renderers are stored (not yet called) so that watch mode can
        # re-evaluate the limit state on every refresh.
        groups: dict[str, list[Callable[..., Tree]]] = {}
        filter_endpoint = get_filtered_endpoint(
            current_app, console, endpoint, path, method
        )
        for rule in sorted(current_app.url_map.iter_rules(filter_endpoint), key=str):
            rule_endpoint = rule.endpoint
            if rule_endpoint == "static":
                continue
            # Dotted endpoints belong to a (possibly nested) blueprint named by
            # everything before the last dot. Undotted ones are grouped under
            # "root" and resolved with an empty blueprint name.
            blueprint: str | None
            if "." in rule_endpoint:
                blueprint = rule_endpoint.rsplit(".", 1)[0]
                group_name, resolve_bp = blueprint, blueprint
            else:
                blueprint = None
                group_name, resolve_bp = "root", ""
            groups.setdefault(group_name, []).append(
                partial(
                    render_limits,
                    current_app,
                    limiter,
                    manager.resolve_limits(current_app, rule_endpoint, resolve_bp),
                    rule_endpoint,
                    blueprint,
                    rule,
                    exemption_scope=manager.exemption_scope(
                        current_app, rule_endpoint, blueprint
                    ),
                    method=method,
                    test=key,
                )
            )

        @group()
        def console_renderable() -> Generator:  # type: ignore
            # ``limiter and`` keeps the Optional narrowed inside this closure.
            if (
                limiter
                and limiter.limit_manager.application_limits
                and not (endpoint or path)
            ):
                yield render_limits(
                    current_app,
                    limiter,
                    (list(itertools.chain(*limiter._meta_limits)), []),
                    test=key,
                    method=method,
                    label="[gold3]Meta Limits[/gold3]",
                )
                yield render_limits(
                    current_app,
                    limiter,
                    (limiter.limit_manager.application_limits, []),
                    test=key,
                    method=method,
                    label="[gold3]Application Limits[/gold3]",
                )
            for name, renderers in groups.items():
                if name == "root":
                    group_tree = Tree(f"[gold3]{current_app.name}[/gold3]")
                else:
                    group_tree = Tree(f"[blue]{name}[/blue]")
                for renderer in renderers:
                    group_tree.add(renderer())
                yield group_tree

        if not watch:
            console.print(console_renderable())
            return

        with Live(
            console_renderable(),
            console=console,
            refresh_per_second=0.4,
            screen=True,
        ) as live:
            try:
                while True:
                    live.update(console_renderable())
                    time.sleep(0.4)
            except KeyboardInterrupt:
                pass
@cli.command(help="Clear limits for a specific key")
@click.option("--endpoint", default=None, help="Endpoint to filter by")
@click.option("--path", default=None, help="Path to filter by")
@click.option("--method", default=None, help="HTTP Method to filter by")
@click.option("--key", default=None, required=True, help="Key to reset the limits for")
@click.option("-y", is_flag=True, help="Skip prompt for confirmation")
@with_appcontext
def clear(
    key: str,
    endpoint: str | None = None,
    path: str | None = None,
    method: str = "GET",
    y: bool = False,
) -> None:
    """Reset the rate limit counters for ``key``.

    Routes can be narrowed with ``--endpoint`` / ``--path`` / ``--method``.
    Application limits are included only when no endpoint/path filter is
    given. Unless ``-y`` is passed, the affected limits are printed with their
    current state for ``key`` and confirmation is requested before clearing.
    """
    with current_app.test_request_context():
        limiters = list(current_app.extensions.get("limiter", set()))
        limiter: Limiter | None = limiters[0] if limiters else None
        console = Console(theme=limiter_theme)
        if not limiter:
            console.print(
                f"No Flask-Limiter extension installed on {current_app}",
                style="bold red",
            )
            return

        manager = limiter.limit_manager
        filter_endpoint = get_filtered_endpoint(
            current_app, console, endpoint, path, method
        )

        class Details(TypedDict):
            rule: Rule
            limits: tuple[list[Limit], ...]

        rule_limits: dict[str, Details] = {}
        for rule in sorted(current_app.url_map.iter_rules(filter_endpoint), key=str):
            rule_endpoint = rule.endpoint
            if rule_endpoint == "static":
                continue
            # Blueprint name is everything before the last dot ("" for
            # endpoints registered directly on the app).
            blueprint = rule_endpoint.rsplit(".", 1)[0] if "." in rule_endpoint else ""
            rule_limits[rule_endpoint] = Details(
                rule=rule,
                limits=manager.resolve_limits(current_app, rule_endpoint, blueprint),
            )

        application_limits = None if filter_endpoint else manager.application_limits

        if not y:
            if application_limits:
                console.print(
                    render_limits(
                        current_app,
                        limiter,
                        (application_limits, []),
                        label="Application Limits",
                        test=key,
                    )
                )
            for rule_endpoint, details in rule_limits.items():
                # ``limits`` is a tuple of lists and therefore always truthy;
                # only show endpoints that actually have limits.
                if any(details["limits"]):
                    console.print(
                        render_limits(
                            current_app,
                            limiter,
                            details["limits"],
                            rule_endpoint,
                            rule=details["rule"],
                            test=key,
                        )
                    )

        if not (
            y
            or Confirm.ask(
                f"Proceed with resetting limits for key: [danger]{key}[/danger]?"
            )
        ):
            return

        if application_limits:
            node = Tree("Application Limits")
            for limit in application_limits:
                limiter.limiter.clear(limit.limit, key, limit.scope_for("", method))
                node.add(f"{render_limit(limit)}: [success]Cleared[/success]")
            console.print(node)

        for rule_endpoint, details in rule_limits.items():
            if not any(details["limits"]):
                continue
            endpoint_rule = details["rule"]
            node = Tree(rule_endpoint)
            default, decorated = details["limits"]
            for limit in default + decorated:
                if (
                    limit.per_method
                    and endpoint_rule
                    and endpoint_rule.methods
                    and not method
                ):
                    # No method requested: clear the per-method counter for
                    # every method the rule accepts.
                    for rule_method in endpoint_rule.methods:
                        limiter.limiter.clear(
                            limit.limit,
                            key,
                            limit.scope_for(rule_endpoint, rule_method),
                        )
                else:
                    limiter.limiter.clear(
                        limit.limit,
                        key,
                        limit.scope_for(rule_endpoint, method),
                    )
                node.add(f"{render_limit(limit)}: [success]Cleared[/success]")
            console.print(node)
# Allow running this module directly as a script (``python commands.py ...``);
# the click group parses ``sys.argv`` itself.
if __name__ == "__main__":  # noqa
    cli()
|