commands.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. from __future__ import annotations
  2. import itertools
  3. import time
  4. from functools import partial
  5. from typing import Any
  6. from urllib.parse import urlparse
  7. import click
  8. from flask import Flask, current_app
  9. from flask.cli import with_appcontext
  10. from limits.strategies import RateLimiter
  11. from rich.console import Console, group
  12. from rich.live import Live
  13. from rich.pretty import Pretty
  14. from rich.prompt import Confirm
  15. from rich.table import Table
  16. from rich.theme import Theme
  17. from rich.tree import Tree
  18. from typing_extensions import TypedDict
  19. from werkzeug.exceptions import MethodNotAllowed, NotFound
  20. from werkzeug.routing import Rule
  21. from flask_limiter import Limiter
  22. from flask_limiter.constants import ConfigVars, ExemptionScope, HeaderNames
  23. from flask_limiter.typing import Callable, Generator, cast
  24. from flask_limiter.util import get_qualified_name
  25. from flask_limiter.wrappers import Limit
  26. limiter_theme = Theme(
  27. {
  28. "success": "bold green",
  29. "danger": "bold red",
  30. "error": "bold red",
  31. "blueprint": "bold red",
  32. "default": "magenta",
  33. "callable": "cyan",
  34. "entity": "magenta",
  35. "exempt": "bold red",
  36. "route": "yellow",
  37. "http": "bold green",
  38. "option": "bold yellow",
  39. }
  40. )
  41. def render_func(func: Any) -> str | Pretty:
  42. if callable(func):
  43. if func.__name__ == "<lambda>":
  44. return f"[callable]<lambda>({func.__module__})[/callable]"
  45. return f"[callable]{func.__module__}.{func.__name__}()[/callable]"
  46. return Pretty(func)
  47. def render_storage(ext: Limiter) -> Tree:
  48. render = Tree(ext._storage_uri or "N/A")
  49. if ext.storage:
  50. render.add(f"[entity]{ext.storage.__class__.__name__}[/entity]")
  51. render.add(f"[entity]{ext.storage.storage}[/entity]") # type: ignore
  52. render.add(Pretty(ext._storage_options or {}))
  53. health = ext.storage.check()
  54. if health:
  55. render.add("[success]OK[/success]")
  56. else:
  57. render.add("[error]Error[/error]")
  58. return render
  59. def render_strategy(strategy: RateLimiter) -> str:
  60. return f"[entity]{strategy.__class__.__name__}[/entity]"
  61. def render_limit_state(
  62. limiter: Limiter, endpoint: str, limit: Limit, key: str, method: str
  63. ) -> str:
  64. args = [key, limit.scope_for(endpoint, method)]
  65. if not limiter.storage or (limiter.storage and not limiter.storage.check()):
  66. return ": [error]Storage not available[/error]"
  67. test = limiter.limiter.test(limit.limit, *args)
  68. stats = limiter.limiter.get_window_stats(limit.limit, *args)
  69. if not test:
  70. return (
  71. f": [error]Fail[/error] ({stats[1]} out of {limit.limit.amount} remaining)"
  72. )
  73. else:
  74. return f": [success]Pass[/success] ({stats[1]} out of {limit.limit.amount} remaining)"
  75. def render_limit(limit: Limit, simple: bool = True) -> str:
  76. render = str(limit.limit)
  77. if simple:
  78. return render
  79. options = []
  80. if limit.deduct_when:
  81. options.append(f"deduct_when: {render_func(limit.deduct_when)}")
  82. if limit.exempt_when:
  83. options.append(f"exempt_when: {render_func(limit.exempt_when)}")
  84. if options:
  85. render = f"{render} [option]{{{', '.join(options)}}}[/option]"
  86. return render
  87. def render_limits(
  88. app: Flask,
  89. limiter: Limiter,
  90. limits: tuple[list[Limit], ...],
  91. endpoint: str | None = None,
  92. blueprint: str | None = None,
  93. rule: Rule | None = None,
  94. exemption_scope: ExemptionScope = ExemptionScope.NONE,
  95. test: str | None = None,
  96. method: str = "GET",
  97. label: str | None = "",
  98. ) -> Tree:
  99. _label = None
  100. if rule and endpoint:
  101. _label = f"{endpoint}: {rule}"
  102. label = _label or label or ""
  103. renderable = Tree(label)
  104. entries = []
  105. for limit in limits[0] + limits[1]:
  106. if endpoint:
  107. view_func = app.view_functions.get(endpoint, None)
  108. source = (
  109. "blueprint"
  110. if blueprint
  111. and limit in limiter.limit_manager.blueprint_limits(app, blueprint)
  112. else (
  113. "route"
  114. if limit
  115. in limiter.limit_manager.decorated_limits(
  116. get_qualified_name(view_func) if view_func else ""
  117. )
  118. else "default"
  119. )
  120. )
  121. else:
  122. source = "default"
  123. if limit.per_method and rule and rule.methods:
  124. for method in rule.methods:
  125. rendered = render_limit(limit, False)
  126. entry = f"[{source}]{rendered} [http]({method})[/http][/{source}]"
  127. if test:
  128. entry += render_limit_state(
  129. limiter, endpoint or "", limit, test, method
  130. )
  131. entries.append(entry)
  132. else:
  133. rendered = render_limit(limit, False)
  134. entry = f"[{source}]{rendered}[/{source}]"
  135. if test:
  136. entry += render_limit_state(
  137. limiter, endpoint or "", limit, test, method
  138. )
  139. entries.append(entry)
  140. if not entries and exemption_scope:
  141. renderable.add("[exempt]Exempt[/exempt]")
  142. else:
  143. [renderable.add(entry) for entry in entries]
  144. return renderable
  145. def get_filtered_endpoint(
  146. app: Flask,
  147. console: Console,
  148. endpoint: str | None,
  149. path: str | None,
  150. method: str | None = None,
  151. ) -> str | None:
  152. if not (endpoint or path):
  153. return None
  154. if endpoint:
  155. if endpoint in current_app.view_functions:
  156. return endpoint
  157. else:
  158. console.print(f"[red]Error: {endpoint} not found")
  159. elif path:
  160. adapter = app.url_map.bind("dev.null")
  161. parsed = urlparse(path)
  162. try:
  163. filter_endpoint, _ = adapter.match(
  164. parsed.path, method=method, query_args=parsed.query
  165. )
  166. return cast(str, filter_endpoint)
  167. except NotFound:
  168. console.print(
  169. f"[error]Error: {path} could not be matched to an endpoint[/error]"
  170. )
  171. except MethodNotAllowed:
  172. assert method
  173. console.print(
  174. f"[error]Error: {method.upper()}: {path}"
  175. " could not be matched to an endpoint[/error]"
  176. )
  177. raise SystemExit
  178. @click.group(help="Flask-Limiter maintenance & utility commmands")
  179. def cli() -> None:
  180. pass
  181. @cli.command(help="View the extension configuration")
  182. @with_appcontext
  183. def config() -> None:
  184. with current_app.test_request_context():
  185. console = Console(theme=limiter_theme)
  186. limiters = list(current_app.extensions.get("limiter", set()))
  187. limiter = limiters and list(limiters)[0]
  188. if limiter:
  189. extension_details = Table(title="Flask-Limiter Config")
  190. extension_details.add_column("Notes")
  191. extension_details.add_column("Configuration")
  192. extension_details.add_column("Value")
  193. extension_details.add_row(
  194. "Enabled", ConfigVars.ENABLED, Pretty(limiter.enabled)
  195. )
  196. extension_details.add_row(
  197. "Key Function", ConfigVars.KEY_FUNC, render_func(limiter._key_func)
  198. )
  199. extension_details.add_row(
  200. "Key Prefix", ConfigVars.KEY_PREFIX, Pretty(limiter._key_prefix)
  201. )
  202. limiter_config = Tree(ConfigVars.STRATEGY)
  203. limiter_config_values = Tree(render_strategy(limiter.limiter))
  204. node = limiter_config.add(ConfigVars.STORAGE_URI)
  205. node.add("Instance")
  206. node.add("Backend")
  207. limiter_config.add(ConfigVars.STORAGE_OPTIONS)
  208. limiter_config.add("Status")
  209. limiter_config_values.add(render_storage(limiter))
  210. extension_details.add_row(
  211. "Rate Limiting Config", limiter_config, limiter_config_values
  212. )
  213. if limiter.limit_manager.application_limits:
  214. extension_details.add_row(
  215. "Application Limits",
  216. ConfigVars.APPLICATION_LIMITS,
  217. Pretty(
  218. [
  219. render_limit(limit)
  220. for limit in limiter.limit_manager.application_limits
  221. ]
  222. ),
  223. )
  224. extension_details.add_row(
  225. None,
  226. ConfigVars.APPLICATION_LIMITS_PER_METHOD,
  227. Pretty(limiter._application_limits_per_method),
  228. )
  229. extension_details.add_row(
  230. None,
  231. ConfigVars.APPLICATION_LIMITS_EXEMPT_WHEN,
  232. render_func(limiter._application_limits_exempt_when),
  233. )
  234. extension_details.add_row(
  235. None,
  236. ConfigVars.APPLICATION_LIMITS_DEDUCT_WHEN,
  237. render_func(limiter._application_limits_deduct_when),
  238. )
  239. extension_details.add_row(
  240. None,
  241. ConfigVars.APPLICATION_LIMITS_COST,
  242. Pretty(limiter._application_limits_cost),
  243. )
  244. else:
  245. extension_details.add_row(
  246. "ApplicationLimits Limits",
  247. ConfigVars.APPLICATION_LIMITS,
  248. Pretty([]),
  249. )
  250. if limiter.limit_manager.default_limits:
  251. extension_details.add_row(
  252. "Default Limits",
  253. ConfigVars.DEFAULT_LIMITS,
  254. Pretty(
  255. [
  256. render_limit(limit)
  257. for limit in limiter.limit_manager.default_limits
  258. ]
  259. ),
  260. )
  261. extension_details.add_row(
  262. None,
  263. ConfigVars.DEFAULT_LIMITS_PER_METHOD,
  264. Pretty(limiter._default_limits_per_method),
  265. )
  266. extension_details.add_row(
  267. None,
  268. ConfigVars.DEFAULT_LIMITS_EXEMPT_WHEN,
  269. render_func(limiter._default_limits_exempt_when),
  270. )
  271. extension_details.add_row(
  272. None,
  273. ConfigVars.DEFAULT_LIMITS_DEDUCT_WHEN,
  274. render_func(limiter._default_limits_deduct_when),
  275. )
  276. extension_details.add_row(
  277. None,
  278. ConfigVars.DEFAULT_LIMITS_COST,
  279. render_func(limiter._default_limits_cost),
  280. )
  281. else:
  282. extension_details.add_row(
  283. "Default Limits", ConfigVars.DEFAULT_LIMITS, Pretty([])
  284. )
  285. if limiter._meta_limits:
  286. extension_details.add_row(
  287. "Meta Limits",
  288. ConfigVars.META_LIMITS,
  289. Pretty(
  290. [
  291. render_limit(limit)
  292. for limit in itertools.chain(*limiter._meta_limits)
  293. ]
  294. ),
  295. )
  296. if limiter._headers_enabled:
  297. header_configs = Tree(ConfigVars.HEADERS_ENABLED)
  298. header_configs.add(ConfigVars.HEADER_RESET)
  299. header_configs.add(ConfigVars.HEADER_REMAINING)
  300. header_configs.add(ConfigVars.HEADER_RETRY_AFTER)
  301. header_configs.add(ConfigVars.HEADER_RETRY_AFTER_VALUE)
  302. header_values = Tree(Pretty(limiter._headers_enabled))
  303. header_values.add(Pretty(limiter._header_mapping[HeaderNames.RESET]))
  304. header_values.add(
  305. Pretty(limiter._header_mapping[HeaderNames.REMAINING])
  306. )
  307. header_values.add(
  308. Pretty(limiter._header_mapping[HeaderNames.RETRY_AFTER])
  309. )
  310. header_values.add(Pretty(limiter._retry_after))
  311. extension_details.add_row(
  312. "Header configuration",
  313. header_configs,
  314. header_values,
  315. )
  316. else:
  317. extension_details.add_row(
  318. "Header configuration", ConfigVars.HEADERS_ENABLED, Pretty(False)
  319. )
  320. extension_details.add_row(
  321. "Fail on first breach",
  322. ConfigVars.FAIL_ON_FIRST_BREACH,
  323. Pretty(limiter._fail_on_first_breach),
  324. )
  325. extension_details.add_row(
  326. "On breach callback",
  327. ConfigVars.ON_BREACH,
  328. render_func(limiter._on_breach),
  329. )
  330. console.print(extension_details)
  331. else:
  332. console.print(
  333. f"No Flask-Limiter extension installed on {current_app}",
  334. style="bold red",
  335. )
  336. @cli.command(help="Enumerate details about all routes with rate limits")
  337. @click.option("--endpoint", default=None, help="Endpoint to filter by")
  338. @click.option("--path", default=None, help="Path to filter by")
  339. @click.option("--method", default=None, help="HTTP Method to filter by")
  340. @click.option("--key", default=None, help="Test the limit")
  341. @click.option("--watch/--no-watch", default=False, help="Create a live dashboard")
  342. @with_appcontext
  343. def limits(
  344. endpoint: str | None = None,
  345. path: str | None = None,
  346. method: str = "GET",
  347. key: str | None = None,
  348. watch: bool = False,
  349. ) -> None:
  350. with current_app.test_request_context():
  351. limiters: set[Limiter] = current_app.extensions.get("limiter", set())
  352. limiter: Limiter | None = list(limiters)[0] if limiters else None
  353. console = Console(theme=limiter_theme)
  354. if limiter:
  355. manager = limiter.limit_manager
  356. groups: dict[str, list[Callable[..., Tree]]] = {}
  357. filter_endpoint = get_filtered_endpoint(
  358. current_app, console, endpoint, path, method
  359. )
  360. for rule in sorted(
  361. current_app.url_map.iter_rules(filter_endpoint), key=lambda r: str(r)
  362. ):
  363. rule_endpoint = rule.endpoint
  364. if rule_endpoint == "static":
  365. continue
  366. if len(rule_endpoint.split(".")) > 1:
  367. bp_fullname = ".".join(rule_endpoint.split(".")[:-1])
  368. groups.setdefault(bp_fullname, []).append(
  369. partial(
  370. render_limits,
  371. current_app,
  372. limiter,
  373. manager.resolve_limits(
  374. current_app, rule_endpoint, bp_fullname
  375. ),
  376. rule_endpoint,
  377. bp_fullname,
  378. rule,
  379. exemption_scope=manager.exemption_scope(
  380. current_app, rule_endpoint, bp_fullname
  381. ),
  382. method=method,
  383. test=key,
  384. )
  385. )
  386. else:
  387. groups.setdefault("root", []).append(
  388. partial(
  389. render_limits,
  390. current_app,
  391. limiter,
  392. manager.resolve_limits(current_app, rule_endpoint, ""),
  393. rule_endpoint,
  394. None,
  395. rule,
  396. exemption_scope=manager.exemption_scope(
  397. current_app, rule_endpoint, None
  398. ),
  399. method=method,
  400. test=key,
  401. )
  402. )
  403. @group()
  404. def console_renderable() -> Generator: # type: ignore
  405. if (
  406. limiter
  407. and limiter.limit_manager.application_limits
  408. and not (endpoint or path)
  409. ):
  410. yield render_limits(
  411. current_app,
  412. limiter,
  413. (list(itertools.chain(*limiter._meta_limits)), []),
  414. test=key,
  415. method=method,
  416. label="[gold3]Meta Limits[/gold3]",
  417. )
  418. yield render_limits(
  419. current_app,
  420. limiter,
  421. (limiter.limit_manager.application_limits, []),
  422. test=key,
  423. method=method,
  424. label="[gold3]Application Limits[/gold3]",
  425. )
  426. for name in groups:
  427. if name == "root":
  428. group_tree = Tree(f"[gold3]{current_app.name}[/gold3]")
  429. else:
  430. group_tree = Tree(f"[blue]{name}[/blue]")
  431. [group_tree.add(renderable()) for renderable in groups[name]]
  432. yield group_tree
  433. if not watch:
  434. console.print(console_renderable())
  435. else: # noqa
  436. with Live(
  437. console_renderable(),
  438. console=console,
  439. refresh_per_second=0.4,
  440. screen=True,
  441. ) as live:
  442. while True:
  443. try:
  444. live.update(console_renderable())
  445. time.sleep(0.4)
  446. except KeyboardInterrupt:
  447. break
  448. else:
  449. console.print(
  450. f"No Flask-Limiter extension installed on {current_app}",
  451. style="bold red",
  452. )
  453. @cli.command(help="Clear limits for a specific key")
  454. @click.option("--endpoint", default=None, help="Endpoint to filter by")
  455. @click.option("--path", default=None, help="Path to filter by")
  456. @click.option("--method", default=None, help="HTTP Method to filter by")
  457. @click.option("--key", default=None, required=True, help="Key to reset the limits for")
  458. @click.option("-y", is_flag=True, help="Skip prompt for confirmation")
  459. @with_appcontext
  460. def clear(
  461. key: str,
  462. endpoint: str | None = None,
  463. path: str | None = None,
  464. method: str = "GET",
  465. y: bool = False,
  466. ) -> None:
  467. with current_app.test_request_context():
  468. limiters = list(current_app.extensions.get("limiter", set()))
  469. limiter: Limiter | None = limiters[0] if limiters else None
  470. console = Console(theme=limiter_theme)
  471. if limiter:
  472. manager = limiter.limit_manager
  473. filter_endpoint = get_filtered_endpoint(
  474. current_app, console, endpoint, path, method
  475. )
  476. class Details(TypedDict):
  477. rule: Rule
  478. limits: tuple[list[Limit], ...]
  479. rule_limits: dict[str, Details] = {}
  480. for rule in sorted(
  481. current_app.url_map.iter_rules(filter_endpoint), key=lambda r: str(r)
  482. ):
  483. rule_endpoint = rule.endpoint
  484. if rule_endpoint == "static":
  485. continue
  486. if len(rule_endpoint.split(".")) > 1:
  487. bp_fullname = ".".join(rule_endpoint.split(".")[:-1])
  488. rule_limits[rule_endpoint] = Details(
  489. rule=rule,
  490. limits=manager.resolve_limits(
  491. current_app, rule_endpoint, bp_fullname
  492. ),
  493. )
  494. else:
  495. rule_limits[rule_endpoint] = Details(
  496. rule=rule,
  497. limits=manager.resolve_limits(current_app, rule_endpoint, ""),
  498. )
  499. application_limits = None
  500. if not filter_endpoint:
  501. application_limits = limiter.limit_manager.application_limits
  502. if not y: # noqa
  503. if application_limits:
  504. console.print(
  505. render_limits(
  506. current_app,
  507. limiter,
  508. (application_limits, []),
  509. label="Application Limits",
  510. test=key,
  511. )
  512. )
  513. for endpoint, details in rule_limits.items():
  514. if details["limits"]:
  515. console.print(
  516. render_limits(
  517. current_app,
  518. limiter,
  519. details["limits"],
  520. endpoint,
  521. rule=details["rule"],
  522. test=key,
  523. )
  524. )
  525. if y or Confirm.ask(
  526. f"Proceed with resetting limits for key: [danger]{key}[/danger]?"
  527. ):
  528. if application_limits:
  529. node = Tree("Application Limits")
  530. for limit in application_limits:
  531. limiter.limiter.clear(
  532. limit.limit,
  533. key,
  534. limit.scope_for("", method),
  535. )
  536. node.add(f"{render_limit(limit)}: [success]Cleared[/success]")
  537. console.print(node)
  538. for endpoint, details in rule_limits.items():
  539. if details["limits"]:
  540. node = Tree(endpoint)
  541. default, decorated = details["limits"]
  542. for limit in default + decorated:
  543. if (
  544. limit.per_method
  545. and details["rule"]
  546. and details["rule"].methods
  547. and not method
  548. ):
  549. for rule_method in details["rule"].methods:
  550. limiter.limiter.clear(
  551. limit.limit,
  552. key,
  553. limit.scope_for(endpoint, rule_method),
  554. )
  555. else:
  556. limiter.limiter.clear(
  557. limit.limit,
  558. key,
  559. limit.scope_for(endpoint, method),
  560. )
  561. node.add(
  562. f"{render_limit(limit)}: [success]Cleared[/success]"
  563. )
  564. console.print(node)
  565. else:
  566. console.print(
  567. f"No Flask-Limiter extension installed on {current_app}",
  568. style="bold red",
  569. )
  570. if __name__ == "__main__": # noqa
  571. cli()