ruler.py 9.0 KB


  1. """
  2. class Ruler
  3. Helper class, used by [[MarkdownIt#core]], [[MarkdownIt#block]] and
  4. [[MarkdownIt#inline]] to manage sequences of functions (rules):
  5. - keep rules in defined order
  6. - assign the name to each rule
  7. - enable/disable rules
  8. - add/replace rules
  9. - allow assigning rules to additional named chains (in the same ruler)
  10. - caching lists of active rules
  11. You will not need to use this class directly unless you write plugins. For
  12. simple rule control use [[MarkdownIt.disable]], [[MarkdownIt.enable]] and
  13. [[MarkdownIt.use]].
  14. """
  15. from __future__ import annotations
  16. from collections.abc import Iterable
  17. from dataclasses import dataclass, field
  18. from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
  19. import warnings
  20. from markdown_it._compat import DATACLASS_KWARGS
  21. from .utils import EnvType
  22. if TYPE_CHECKING:
  23. from markdown_it import MarkdownIt
  24. class StateBase:
  25. def __init__(self, src: str, md: MarkdownIt, env: EnvType):
  26. self.src = src
  27. self.env = env
  28. self.md = md
  29. @property
  30. def src(self) -> str:
  31. return self._src
  32. @src.setter
  33. def src(self, value: str) -> None:
  34. self._src = value
  35. self._srcCharCode: tuple[int, ...] | None = None
  36. @property
  37. def srcCharCode(self) -> tuple[int, ...]:
  38. warnings.warn(
  39. "StateBase.srcCharCode is deprecated. Use StateBase.src instead.",
  40. DeprecationWarning,
  41. stacklevel=2,
  42. )
  43. if self._srcCharCode is None:
  44. self._srcCharCode = tuple(ord(c) for c in self._src)
  45. return self._srcCharCode
class RuleOptionsType(TypedDict, total=False):
    """Options accepted when a rule is added to or replaced in a ruler.

    ``total=False`` makes every key optional.
    """

    # Names of the additional chains the rule should also be listed in.
    alt: list[str]
# Type variable that ``Rule`` and ``Ruler`` below are generic over.
RuleFuncTv = TypeVar("RuleFuncTv")
"""A rule function, whose signature is dependent on the state type."""
@dataclass(**DATACLASS_KWARGS)
class Rule(Generic[RuleFuncTv]):
    """One named entry in a ruler's ordered rule list."""

    # Name the rule is looked up by.
    name: str
    # Whether the rule is active; disabled rules are left out of the
    # compiled chains.
    enabled: bool
    # The rule function itself; excluded from the dataclass ``repr``.
    fn: RuleFuncTv = field(repr=False)
    # Names of the additional chains this rule is also part of.
    alt: list[str]
  56. class Ruler(Generic[RuleFuncTv]):
  57. def __init__(self) -> None:
  58. # List of added rules.
  59. self.__rules__: list[Rule[RuleFuncTv]] = []
  60. # Cached rule chains.
  61. # First level - chain name, '' for default.
  62. # Second level - diginal anchor for fast filtering by charcodes.
  63. self.__cache__: dict[str, list[RuleFuncTv]] | None = None
  64. def __find__(self, name: str) -> int:
  65. """Find rule index by name"""
  66. for i, rule in enumerate(self.__rules__):
  67. if rule.name == name:
  68. return i
  69. return -1
  70. def __compile__(self) -> None:
  71. """Build rules lookup cache"""
  72. chains = {""}
  73. # collect unique names
  74. for rule in self.__rules__:
  75. if not rule.enabled:
  76. continue
  77. for name in rule.alt:
  78. chains.add(name)
  79. self.__cache__ = {}
  80. for chain in chains:
  81. self.__cache__[chain] = []
  82. for rule in self.__rules__:
  83. if not rule.enabled:
  84. continue
  85. if chain and (chain not in rule.alt):
  86. continue
  87. self.__cache__[chain].append(rule.fn)
  88. def at(
  89. self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
  90. ) -> None:
  91. """Replace rule by name with new function & options.
  92. :param ruleName: rule name to replace.
  93. :param fn: new rule function.
  94. :param options: new rule options (not mandatory).
  95. :raises: KeyError if name not found
  96. """
  97. index = self.__find__(ruleName)
  98. options = options or {}
  99. if index == -1:
  100. raise KeyError(f"Parser rule not found: {ruleName}")
  101. self.__rules__[index].fn = fn
  102. self.__rules__[index].alt = options.get("alt", [])
  103. self.__cache__ = None
  104. def before(
  105. self,
  106. beforeName: str,
  107. ruleName: str,
  108. fn: RuleFuncTv,
  109. options: RuleOptionsType | None = None,
  110. ) -> None:
  111. """Add new rule to chain before one with given name.
  112. :param beforeName: new rule will be added before this one.
  113. :param ruleName: new rule will be added before this one.
  114. :param fn: new rule function.
  115. :param options: new rule options (not mandatory).
  116. :raises: KeyError if name not found
  117. """
  118. index = self.__find__(beforeName)
  119. options = options or {}
  120. if index == -1:
  121. raise KeyError(f"Parser rule not found: {beforeName}")
  122. self.__rules__.insert(
  123. index, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
  124. )
  125. self.__cache__ = None
  126. def after(
  127. self,
  128. afterName: str,
  129. ruleName: str,
  130. fn: RuleFuncTv,
  131. options: RuleOptionsType | None = None,
  132. ) -> None:
  133. """Add new rule to chain after one with given name.
  134. :param afterName: new rule will be added after this one.
  135. :param ruleName: new rule will be added after this one.
  136. :param fn: new rule function.
  137. :param options: new rule options (not mandatory).
  138. :raises: KeyError if name not found
  139. """
  140. index = self.__find__(afterName)
  141. options = options or {}
  142. if index == -1:
  143. raise KeyError(f"Parser rule not found: {afterName}")
  144. self.__rules__.insert(
  145. index + 1, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
  146. )
  147. self.__cache__ = None
  148. def push(
  149. self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
  150. ) -> None:
  151. """Push new rule to the end of chain.
  152. :param ruleName: new rule will be added to the end of chain.
  153. :param fn: new rule function.
  154. :param options: new rule options (not mandatory).
  155. """
  156. self.__rules__.append(
  157. Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
  158. )
  159. self.__cache__ = None
  160. def enable(
  161. self, names: str | Iterable[str], ignoreInvalid: bool = False
  162. ) -> list[str]:
  163. """Enable rules with given names.
  164. :param names: name or list of rule names to enable.
  165. :param ignoreInvalid: ignore errors when rule not found
  166. :raises: KeyError if name not found and not ignoreInvalid
  167. :return: list of found rule names
  168. """
  169. if isinstance(names, str):
  170. names = [names]
  171. result: list[str] = []
  172. for name in names:
  173. idx = self.__find__(name)
  174. if (idx < 0) and ignoreInvalid:
  175. continue
  176. if (idx < 0) and not ignoreInvalid:
  177. raise KeyError(f"Rules manager: invalid rule name {name}")
  178. self.__rules__[idx].enabled = True
  179. result.append(name)
  180. self.__cache__ = None
  181. return result
  182. def enableOnly(
  183. self, names: str | Iterable[str], ignoreInvalid: bool = False
  184. ) -> list[str]:
  185. """Enable rules with given names, and disable everything else.
  186. :param names: name or list of rule names to enable.
  187. :param ignoreInvalid: ignore errors when rule not found
  188. :raises: KeyError if name not found and not ignoreInvalid
  189. :return: list of found rule names
  190. """
  191. if isinstance(names, str):
  192. names = [names]
  193. for rule in self.__rules__:
  194. rule.enabled = False
  195. return self.enable(names, ignoreInvalid)
  196. def disable(
  197. self, names: str | Iterable[str], ignoreInvalid: bool = False
  198. ) -> list[str]:
  199. """Disable rules with given names.
  200. :param names: name or list of rule names to enable.
  201. :param ignoreInvalid: ignore errors when rule not found
  202. :raises: KeyError if name not found and not ignoreInvalid
  203. :return: list of found rule names
  204. """
  205. if isinstance(names, str):
  206. names = [names]
  207. result = []
  208. for name in names:
  209. idx = self.__find__(name)
  210. if (idx < 0) and ignoreInvalid:
  211. continue
  212. if (idx < 0) and not ignoreInvalid:
  213. raise KeyError(f"Rules manager: invalid rule name {name}")
  214. self.__rules__[idx].enabled = False
  215. result.append(name)
  216. self.__cache__ = None
  217. return result
  218. def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
  219. """Return array of active functions (rules) for given chain name.
  220. It analyzes rules configuration, compiles caches if not exists and returns result.
  221. Default chain name is `''` (empty string). It can't be skipped.
  222. That's done intentionally, to keep signature monomorphic for high speed.
  223. """
  224. if self.__cache__ is None:
  225. self.__compile__()
  226. assert self.__cache__ is not None
  227. # Chain can be empty, if rules disabled. But we still have to return Array.
  228. return self.__cache__.get(chainName, []) or []
  229. def get_all_rules(self) -> list[str]:
  230. """Return all available rule names."""
  231. return [r.name for r in self.__rules__]
  232. def get_active_rules(self) -> list[str]:
  233. """Return the active rule names."""
  234. return [r.name for r in self.__rules__ if r.enabled]