# _patching.py
# Source code: https://github.com/hamdanal/rich-argparse
# MIT license: Copyright (c) Ali Hamdan <ali.hamdan.dev@gmail.com>
# for internal use only
  4. from __future__ import annotations
  5. from rich_argparse._argparse import RichHelpFormatter
  6. def patch_default_formatter_class(
  7. cls=None, /, *, formatter_class=RichHelpFormatter, method_name="__init__"
  8. ):
  9. """Patch the default `formatter_class` parameter of an argument parser constructor.
  10. Parameters
  11. ----------
  12. cls : (type, optional)
  13. The class to patch. If not provided, a decorator is returned.
  14. formatter_class : (type, optional)
  15. The new formatter class to use. Defaults to ``RichHelpFormatter``.
  16. method_name : (str, optional)
  17. The method name to patch. Defaults to ``__init__``.
  18. Examples
  19. --------
  20. Can be used as a normal function to patch an existing class::
  21. # Patch the default formatter class of `argparse.ArgumentParser`
  22. patch_default_formatter_class(argparse.ArgumentParser)
  23. # Patch the default formatter class of django commands
  24. from django.core.management.base import BaseCommand, DjangoHelpFormatter
  25. class DjangoRichHelpFormatter(DjangoHelpFormatter, RichHelpFormatter): ...
  26. patch_default_formatter_class(
  27. BaseCommand, formatter_class=DjangoRichHelpFormatter, method_name="create_parser"
  28. )
  29. Or as a decorator to patch a new class::
  30. @patch_default_formatter_class
  31. class MyArgumentParser(argparse.ArgumentParser):
  32. pass
  33. @patch_default_formatter_class(formatter_class=RawDescriptionRichHelpFormatter)
  34. class MyOtherArgumentParser(argparse.ArgumentParser):
  35. pass
  36. """
  37. import functools
  38. def decorator(cls, /):
  39. method = getattr(cls, method_name)
  40. if not callable(method):
  41. raise TypeError(f"'{cls.__name__}.{method_name}' is not callable")
  42. @functools.wraps(method)
  43. def wrapper(*args, **kwargs):
  44. kwargs.setdefault("formatter_class", formatter_class)
  45. return method(*args, **kwargs)
  46. setattr(cls, method_name, wrapper)
  47. return cls
  48. if cls is None:
  49. return decorator
  50. return decorator(cls)