rewriter.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. from __future__ import annotations
  2. from typing import Any
  3. from typing import Callable
  4. from typing import Iterator
  5. from typing import List
  6. from typing import Tuple
  7. from typing import Type
  8. from typing import TYPE_CHECKING
  9. from typing import Union
  10. from .. import util
  11. from ..operations import ops
  12. if TYPE_CHECKING:
  13. from ..operations.ops import AddColumnOp
  14. from ..operations.ops import AlterColumnOp
  15. from ..operations.ops import CreateTableOp
  16. from ..operations.ops import DowngradeOps
  17. from ..operations.ops import MigrateOperation
  18. from ..operations.ops import MigrationScript
  19. from ..operations.ops import ModifyTableOps
  20. from ..operations.ops import OpContainer
  21. from ..operations.ops import UpgradeOps
  22. from ..runtime.migration import MigrationContext
  23. from ..script.revision import _GetRevArg
  24. ProcessRevisionDirectiveFn = Callable[
  25. ["MigrationContext", "_GetRevArg", List["MigrationScript"]], None
  26. ]
  27. class Rewriter:
  28. """A helper object that allows easy 'rewriting' of ops streams.
  29. The :class:`.Rewriter` object is intended to be passed along
  30. to the
  31. :paramref:`.EnvironmentContext.configure.process_revision_directives`
  32. parameter in an ``env.py`` script. Once constructed, any number
  33. of "rewrites" functions can be associated with it, which will be given
  34. the opportunity to modify the structure without having to have explicit
  35. knowledge of the overall structure.
  36. The function is passed the :class:`.MigrationContext` object and
  37. ``revision`` tuple that are passed to the :paramref:`.Environment
  38. Context.configure.process_revision_directives` function normally,
  39. and the third argument is an individual directive of the type
  40. noted in the decorator. The function has the choice of returning
  41. a single op directive, which normally can be the directive that
  42. was actually passed, or a new directive to replace it, or a list
  43. of zero or more directives to replace it.
  44. .. seealso::
  45. :ref:`autogen_rewriter` - usage example
  46. """
  47. _traverse = util.Dispatcher()
  48. _chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()
  49. def __init__(self) -> None:
  50. self.dispatch = util.Dispatcher()
  51. def chain(
  52. self,
  53. other: Union[
  54. ProcessRevisionDirectiveFn,
  55. Rewriter,
  56. ],
  57. ) -> Rewriter:
  58. """Produce a "chain" of this :class:`.Rewriter` to another.
  59. This allows two or more rewriters to operate serially on a stream,
  60. e.g.::
  61. writer1 = autogenerate.Rewriter()
  62. writer2 = autogenerate.Rewriter()
  63. @writer1.rewrites(ops.AddColumnOp)
  64. def add_column_nullable(context, revision, op):
  65. op.column.nullable = True
  66. return op
  67. @writer2.rewrites(ops.AddColumnOp)
  68. def add_column_idx(context, revision, op):
  69. idx_op = ops.CreateIndexOp(
  70. "ixc", op.table_name, [op.column.name]
  71. )
  72. return [op, idx_op]
  73. writer = writer1.chain(writer2)
  74. :param other: a :class:`.Rewriter` instance
  75. :return: a new :class:`.Rewriter` that will run the operations
  76. of this writer, then the "other" writer, in succession.
  77. """
  78. wr = self.__class__.__new__(self.__class__)
  79. wr.__dict__.update(self.__dict__)
  80. wr._chained += (other,)
  81. return wr
  82. def rewrites(
  83. self,
  84. operator: Union[
  85. Type[AddColumnOp],
  86. Type[MigrateOperation],
  87. Type[AlterColumnOp],
  88. Type[CreateTableOp],
  89. Type[ModifyTableOps],
  90. ],
  91. ) -> Callable[..., Any]:
  92. """Register a function as rewriter for a given type.
  93. The function should receive three arguments, which are
  94. the :class:`.MigrationContext`, a ``revision`` tuple, and
  95. an op directive of the type indicated. E.g.::
  96. @writer1.rewrites(ops.AddColumnOp)
  97. def add_column_nullable(context, revision, op):
  98. op.column.nullable = True
  99. return op
  100. """
  101. return self.dispatch.dispatch_for(operator)
  102. def _rewrite(
  103. self,
  104. context: MigrationContext,
  105. revision: _GetRevArg,
  106. directive: MigrateOperation,
  107. ) -> Iterator[MigrateOperation]:
  108. try:
  109. _rewriter = self.dispatch.dispatch(directive)
  110. except ValueError:
  111. _rewriter = None
  112. yield directive
  113. else:
  114. if self in directive._mutations:
  115. yield directive
  116. else:
  117. for r_directive in util.to_list(
  118. _rewriter(context, revision, directive), []
  119. ):
  120. r_directive._mutations = r_directive._mutations.union(
  121. [self]
  122. )
  123. yield r_directive
  124. def __call__(
  125. self,
  126. context: MigrationContext,
  127. revision: _GetRevArg,
  128. directives: List[MigrationScript],
  129. ) -> None:
  130. self.process_revision_directives(context, revision, directives)
  131. for process_revision_directives in self._chained:
  132. process_revision_directives(context, revision, directives)
  133. @_traverse.dispatch_for(ops.MigrationScript)
  134. def _traverse_script(
  135. self,
  136. context: MigrationContext,
  137. revision: _GetRevArg,
  138. directive: MigrationScript,
  139. ) -> None:
  140. upgrade_ops_list: List[UpgradeOps] = []
  141. for upgrade_ops in directive.upgrade_ops_list:
  142. ret = self._traverse_for(context, revision, upgrade_ops)
  143. if len(ret) != 1:
  144. raise ValueError(
  145. "Can only return single object for UpgradeOps traverse"
  146. )
  147. upgrade_ops_list.append(ret[0])
  148. directive.upgrade_ops = upgrade_ops_list # type: ignore
  149. downgrade_ops_list: List[DowngradeOps] = []
  150. for downgrade_ops in directive.downgrade_ops_list:
  151. ret = self._traverse_for(context, revision, downgrade_ops)
  152. if len(ret) != 1:
  153. raise ValueError(
  154. "Can only return single object for DowngradeOps traverse"
  155. )
  156. downgrade_ops_list.append(ret[0])
  157. directive.downgrade_ops = downgrade_ops_list # type: ignore
  158. @_traverse.dispatch_for(ops.OpContainer)
  159. def _traverse_op_container(
  160. self,
  161. context: MigrationContext,
  162. revision: _GetRevArg,
  163. directive: OpContainer,
  164. ) -> None:
  165. self._traverse_list(context, revision, directive.ops)
  166. @_traverse.dispatch_for(ops.MigrateOperation)
  167. def _traverse_any_directive(
  168. self,
  169. context: MigrationContext,
  170. revision: _GetRevArg,
  171. directive: MigrateOperation,
  172. ) -> None:
  173. pass
  174. def _traverse_for(
  175. self,
  176. context: MigrationContext,
  177. revision: _GetRevArg,
  178. directive: MigrateOperation,
  179. ) -> Any:
  180. directives = list(self._rewrite(context, revision, directive))
  181. for directive in directives:
  182. traverser = self._traverse.dispatch(directive)
  183. traverser(self, context, revision, directive)
  184. return directives
  185. def _traverse_list(
  186. self,
  187. context: MigrationContext,
  188. revision: _GetRevArg,
  189. directives: Any,
  190. ) -> None:
  191. dest = []
  192. for directive in directives:
  193. dest.extend(self._traverse_for(context, revision, directive))
  194. directives[:] = dest
  195. def process_revision_directives(
  196. self,
  197. context: MigrationContext,
  198. revision: _GetRevArg,
  199. directives: List[MigrationScript],
  200. ) -> None:
  201. self._traverse_list(context, revision, directives)