api.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. from __future__ import annotations
  2. import contextlib
  3. from typing import Any
  4. from typing import Dict
  5. from typing import Iterator
  6. from typing import List
  7. from typing import Optional
  8. from typing import Sequence
  9. from typing import Set
  10. from typing import TYPE_CHECKING
  11. from typing import Union
  12. from sqlalchemy import inspect
  13. from . import compare
  14. from . import render
  15. from .. import util
  16. from ..operations import ops
  17. from ..util import sqla_compat
  18. """Provide the 'autogenerate' feature which can produce migration operations
  19. automatically."""
  20. if TYPE_CHECKING:
  21. from sqlalchemy.engine import Connection
  22. from sqlalchemy.engine import Dialect
  23. from sqlalchemy.engine import Inspector
  24. from sqlalchemy.sql.schema import MetaData
  25. from sqlalchemy.sql.schema import SchemaItem
  26. from sqlalchemy.sql.schema import Table
  27. from ..config import Config
  28. from ..operations.ops import DowngradeOps
  29. from ..operations.ops import MigrationScript
  30. from ..operations.ops import UpgradeOps
  31. from ..runtime.environment import NameFilterParentNames
  32. from ..runtime.environment import NameFilterType
  33. from ..runtime.environment import ProcessRevisionDirectiveFn
  34. from ..runtime.environment import RenderItemFn
  35. from ..runtime.migration import MigrationContext
  36. from ..script.base import Script
  37. from ..script.base import ScriptDirectory
  38. from ..script.revision import _GetRevArg
  39. def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
  40. """Compare a database schema to that given in a
  41. :class:`~sqlalchemy.schema.MetaData` instance.
  42. The database connection is presented in the context
  43. of a :class:`.MigrationContext` object, which
  44. provides database connectivity as well as optional
  45. comparison functions to use for datatypes and
  46. server defaults - see the "autogenerate" arguments
  47. at :meth:`.EnvironmentContext.configure`
  48. for details on these.
  49. The return format is a list of "diff" directives,
  50. each representing individual differences::
  51. from alembic.migration import MigrationContext
  52. from alembic.autogenerate import compare_metadata
  53. from sqlalchemy import (
  54. create_engine,
  55. MetaData,
  56. Column,
  57. Integer,
  58. String,
  59. Table,
  60. text,
  61. )
  62. import pprint
  63. engine = create_engine("sqlite://")
  64. with engine.begin() as conn:
  65. conn.execute(
  66. text(
  67. '''
  68. create table foo (
  69. id integer not null primary key,
  70. old_data varchar,
  71. x integer
  72. )
  73. '''
  74. )
  75. )
  76. conn.execute(text("create table bar (data varchar)"))
  77. metadata = MetaData()
  78. Table(
  79. "foo",
  80. metadata,
  81. Column("id", Integer, primary_key=True),
  82. Column("data", Integer),
  83. Column("x", Integer, nullable=False),
  84. )
  85. Table("bat", metadata, Column("info", String))
  86. mc = MigrationContext.configure(engine.connect())
  87. diff = compare_metadata(mc, metadata)
  88. pprint.pprint(diff, indent=2, width=20)
  89. Output::
  90. [
  91. (
  92. "add_table",
  93. Table(
  94. "bat",
  95. MetaData(),
  96. Column("info", String(), table=<bat>),
  97. schema=None,
  98. ),
  99. ),
  100. (
  101. "remove_table",
  102. Table(
  103. "bar",
  104. MetaData(),
  105. Column("data", VARCHAR(), table=<bar>),
  106. schema=None,
  107. ),
  108. ),
  109. (
  110. "add_column",
  111. None,
  112. "foo",
  113. Column("data", Integer(), table=<foo>),
  114. ),
  115. [
  116. (
  117. "modify_nullable",
  118. None,
  119. "foo",
  120. "x",
  121. {
  122. "existing_comment": None,
  123. "existing_server_default": False,
  124. "existing_type": INTEGER(),
  125. },
  126. True,
  127. False,
  128. )
  129. ],
  130. (
  131. "remove_column",
  132. None,
  133. "foo",
  134. Column("old_data", VARCHAR(), table=<foo>),
  135. ),
  136. ]
  137. :param context: a :class:`.MigrationContext`
  138. instance.
  139. :param metadata: a :class:`~sqlalchemy.schema.MetaData`
  140. instance.
  141. .. seealso::
  142. :func:`.produce_migrations` - produces a :class:`.MigrationScript`
  143. structure based on metadata comparison.
  144. """
  145. migration_script = produce_migrations(context, metadata)
  146. assert migration_script.upgrade_ops is not None
  147. return migration_script.upgrade_ops.as_diffs()
  148. def produce_migrations(
  149. context: MigrationContext, metadata: MetaData
  150. ) -> MigrationScript:
  151. """Produce a :class:`.MigrationScript` structure based on schema
  152. comparison.
  153. This function does essentially what :func:`.compare_metadata` does,
  154. but then runs the resulting list of diffs to produce the full
  155. :class:`.MigrationScript` object. For an example of what this looks like,
  156. see the example in :ref:`customizing_revision`.
  157. .. seealso::
  158. :func:`.compare_metadata` - returns more fundamental "diff"
  159. data from comparing a schema.
  160. """
  161. autogen_context = AutogenContext(context, metadata=metadata)
  162. migration_script = ops.MigrationScript(
  163. rev_id=None,
  164. upgrade_ops=ops.UpgradeOps([]),
  165. downgrade_ops=ops.DowngradeOps([]),
  166. )
  167. compare._populate_migration_script(autogen_context, migration_script)
  168. return migration_script
  169. def render_python_code(
  170. up_or_down_op: Union[UpgradeOps, DowngradeOps],
  171. sqlalchemy_module_prefix: str = "sa.",
  172. alembic_module_prefix: str = "op.",
  173. render_as_batch: bool = False,
  174. imports: Sequence[str] = (),
  175. render_item: Optional[RenderItemFn] = None,
  176. migration_context: Optional[MigrationContext] = None,
  177. user_module_prefix: Optional[str] = None,
  178. ) -> str:
  179. """Render Python code given an :class:`.UpgradeOps` or
  180. :class:`.DowngradeOps` object.
  181. This is a convenience function that can be used to test the
  182. autogenerate output of a user-defined :class:`.MigrationScript` structure.
  183. :param up_or_down_op: :class:`.UpgradeOps` or :class:`.DowngradeOps` object
  184. :param sqlalchemy_module_prefix: module prefix for SQLAlchemy objects
  185. :param alembic_module_prefix: module prefix for Alembic constructs
  186. :param render_as_batch: use "batch operations" style for rendering
  187. :param imports: sequence of import symbols to add
  188. :param render_item: callable to render items
  189. :param migration_context: optional :class:`.MigrationContext`
  190. :param user_module_prefix: optional string prefix for user-defined types
  191. .. versionadded:: 1.11.0
  192. """
  193. opts = {
  194. "sqlalchemy_module_prefix": sqlalchemy_module_prefix,
  195. "alembic_module_prefix": alembic_module_prefix,
  196. "render_item": render_item,
  197. "render_as_batch": render_as_batch,
  198. "user_module_prefix": user_module_prefix,
  199. }
  200. if migration_context is None:
  201. from ..runtime.migration import MigrationContext
  202. from sqlalchemy.engine.default import DefaultDialect
  203. migration_context = MigrationContext.configure(
  204. dialect=DefaultDialect()
  205. )
  206. autogen_context = AutogenContext(migration_context, opts=opts)
  207. autogen_context.imports = set(imports)
  208. return render._indent(
  209. render._render_cmd_body(up_or_down_op, autogen_context)
  210. )
  211. def _render_migration_diffs(
  212. context: MigrationContext, template_args: Dict[Any, Any]
  213. ) -> None:
  214. """legacy, used by test_autogen_composition at the moment"""
  215. autogen_context = AutogenContext(context)
  216. upgrade_ops = ops.UpgradeOps([])
  217. compare._produce_net_changes(autogen_context, upgrade_ops)
  218. migration_script = ops.MigrationScript(
  219. rev_id=None,
  220. upgrade_ops=upgrade_ops,
  221. downgrade_ops=upgrade_ops.reverse(),
  222. )
  223. render._render_python_into_templatevars(
  224. autogen_context, migration_script, template_args
  225. )
  226. class AutogenContext:
  227. """Maintains configuration and state that's specific to an
  228. autogenerate operation."""
  229. metadata: Union[MetaData, Sequence[MetaData], None] = None
  230. """The :class:`~sqlalchemy.schema.MetaData` object
  231. representing the destination.
  232. This object is the one that is passed within ``env.py``
  233. to the :paramref:`.EnvironmentContext.configure.target_metadata`
  234. parameter. It represents the structure of :class:`.Table` and other
  235. objects as stated in the current database model, and represents the
  236. destination structure for the database being examined.
  237. While the :class:`~sqlalchemy.schema.MetaData` object is primarily
  238. known as a collection of :class:`~sqlalchemy.schema.Table` objects,
  239. it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
  240. that may be used by end-user schemes to store additional schema-level
  241. objects that are to be compared in custom autogeneration schemes.
  242. """
  243. connection: Optional[Connection] = None
  244. """The :class:`~sqlalchemy.engine.base.Connection` object currently
  245. connected to the database backend being compared.
  246. This is obtained from the :attr:`.MigrationContext.bind` and is
  247. ultimately set up in the ``env.py`` script.
  248. """
  249. dialect: Optional[Dialect] = None
  250. """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
  251. This is normally obtained from the
  252. :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
  253. """
  254. imports: Set[str] = None # type: ignore[assignment]
  255. """A ``set()`` which contains string Python import directives.
  256. The directives are to be rendered into the ``${imports}`` section
  257. of a script template. The set is normally empty and can be modified
  258. within hooks such as the
  259. :paramref:`.EnvironmentContext.configure.render_item` hook.
  260. .. seealso::
  261. :ref:`autogen_render_types`
  262. """
  263. migration_context: MigrationContext = None # type: ignore[assignment]
  264. """The :class:`.MigrationContext` established by the ``env.py`` script."""
  265. def __init__(
  266. self,
  267. migration_context: MigrationContext,
  268. metadata: Union[MetaData, Sequence[MetaData], None] = None,
  269. opts: Optional[Dict[str, Any]] = None,
  270. autogenerate: bool = True,
  271. ) -> None:
  272. if (
  273. autogenerate
  274. and migration_context is not None
  275. and migration_context.as_sql
  276. ):
  277. raise util.CommandError(
  278. "autogenerate can't use as_sql=True as it prevents querying "
  279. "the database for schema information"
  280. )
  281. if opts is None:
  282. opts = migration_context.opts
  283. self.metadata = metadata = (
  284. opts.get("target_metadata", None) if metadata is None else metadata
  285. )
  286. if (
  287. autogenerate
  288. and metadata is None
  289. and migration_context is not None
  290. and migration_context.script is not None
  291. ):
  292. raise util.CommandError(
  293. "Can't proceed with --autogenerate option; environment "
  294. "script %s does not provide "
  295. "a MetaData object or sequence of objects to the context."
  296. % (migration_context.script.env_py_location)
  297. )
  298. include_object = opts.get("include_object", None)
  299. include_name = opts.get("include_name", None)
  300. object_filters = []
  301. name_filters = []
  302. if include_object:
  303. object_filters.append(include_object)
  304. if include_name:
  305. name_filters.append(include_name)
  306. self._object_filters = object_filters
  307. self._name_filters = name_filters
  308. self.migration_context = migration_context
  309. if self.migration_context is not None:
  310. self.connection = self.migration_context.bind
  311. self.dialect = self.migration_context.dialect
  312. self.imports = set()
  313. self.opts: Dict[str, Any] = opts
  314. self._has_batch: bool = False
  315. @util.memoized_property
  316. def inspector(self) -> Inspector:
  317. if self.connection is None:
  318. raise TypeError(
  319. "can't return inspector as this "
  320. "AutogenContext has no database connection"
  321. )
  322. return inspect(self.connection)
  323. @contextlib.contextmanager
  324. def _within_batch(self) -> Iterator[None]:
  325. self._has_batch = True
  326. yield
  327. self._has_batch = False
  328. def run_name_filters(
  329. self,
  330. name: Optional[str],
  331. type_: NameFilterType,
  332. parent_names: NameFilterParentNames,
  333. ) -> bool:
  334. """Run the context's name filters and return True if the targets
  335. should be part of the autogenerate operation.
  336. This method should be run for every kind of name encountered within the
  337. reflection side of an autogenerate operation, giving the environment
  338. the chance to filter what names should be reflected as database
  339. objects. The filters here are produced directly via the
  340. :paramref:`.EnvironmentContext.configure.include_name` parameter.
  341. """
  342. if "schema_name" in parent_names:
  343. if type_ == "table":
  344. table_name = name
  345. else:
  346. table_name = parent_names.get("table_name", None)
  347. if table_name:
  348. schema_name = parent_names["schema_name"]
  349. if schema_name:
  350. parent_names["schema_qualified_table_name"] = "%s.%s" % (
  351. schema_name,
  352. table_name,
  353. )
  354. else:
  355. parent_names["schema_qualified_table_name"] = table_name
  356. for fn in self._name_filters:
  357. if not fn(name, type_, parent_names):
  358. return False
  359. else:
  360. return True
  361. def run_object_filters(
  362. self,
  363. object_: SchemaItem,
  364. name: sqla_compat._ConstraintName,
  365. type_: NameFilterType,
  366. reflected: bool,
  367. compare_to: Optional[SchemaItem],
  368. ) -> bool:
  369. """Run the context's object filters and return True if the targets
  370. should be part of the autogenerate operation.
  371. This method should be run for every kind of object encountered within
  372. an autogenerate operation, giving the environment the chance
  373. to filter what objects should be included in the comparison.
  374. The filters here are produced directly via the
  375. :paramref:`.EnvironmentContext.configure.include_object` parameter.
  376. """
  377. for fn in self._object_filters:
  378. if not fn(object_, name, type_, reflected, compare_to):
  379. return False
  380. else:
  381. return True
  382. run_filters = run_object_filters
  383. @util.memoized_property
  384. def sorted_tables(self) -> List[Table]:
  385. """Return an aggregate of the :attr:`.MetaData.sorted_tables`
  386. collection(s).
  387. For a sequence of :class:`.MetaData` objects, this
  388. concatenates the :attr:`.MetaData.sorted_tables` collection
  389. for each individual :class:`.MetaData` in the order of the
  390. sequence. It does **not** collate the sorted tables collections.
  391. """
  392. result = []
  393. for m in util.to_list(self.metadata):
  394. result.extend(m.sorted_tables)
  395. return result
  396. @util.memoized_property
  397. def table_key_to_table(self) -> Dict[str, Table]:
  398. """Return an aggregate of the :attr:`.MetaData.tables` dictionaries.
  399. The :attr:`.MetaData.tables` collection is a dictionary of table key
  400. to :class:`.Table`; this method aggregates the dictionary across
  401. multiple :class:`.MetaData` objects into one dictionary.
  402. Duplicate table keys are **not** supported; if two :class:`.MetaData`
  403. objects contain the same table key, an exception is raised.
  404. """
  405. result: Dict[str, Table] = {}
  406. for m in util.to_list(self.metadata):
  407. intersect = set(result).intersection(set(m.tables))
  408. if intersect:
  409. raise ValueError(
  410. "Duplicate table keys across multiple "
  411. "MetaData objects: %s"
  412. % (", ".join('"%s"' % key for key in sorted(intersect)))
  413. )
  414. result.update(m.tables)
  415. return result
  416. class RevisionContext:
  417. """Maintains configuration and state that's specific to a revision
  418. file generation operation."""
  419. generated_revisions: List[MigrationScript]
  420. process_revision_directives: Optional[ProcessRevisionDirectiveFn]
  421. def __init__(
  422. self,
  423. config: Config,
  424. script_directory: ScriptDirectory,
  425. command_args: Dict[str, Any],
  426. process_revision_directives: Optional[
  427. ProcessRevisionDirectiveFn
  428. ] = None,
  429. ) -> None:
  430. self.config = config
  431. self.script_directory = script_directory
  432. self.command_args = command_args
  433. self.process_revision_directives = process_revision_directives
  434. self.template_args = {
  435. "config": config # Let templates use config for
  436. # e.g. multiple databases
  437. }
  438. self.generated_revisions = [self._default_revision()]
  439. def _to_script(
  440. self, migration_script: MigrationScript
  441. ) -> Optional[Script]:
  442. template_args: Dict[str, Any] = self.template_args.copy()
  443. if getattr(migration_script, "_needs_render", False):
  444. autogen_context = self._last_autogen_context
  445. # clear out existing imports if we are doing multiple
  446. # renders
  447. autogen_context.imports = set()
  448. if migration_script.imports:
  449. autogen_context.imports.update(migration_script.imports)
  450. render._render_python_into_templatevars(
  451. autogen_context, migration_script, template_args
  452. )
  453. assert migration_script.rev_id is not None
  454. return self.script_directory.generate_revision(
  455. migration_script.rev_id,
  456. migration_script.message,
  457. refresh=True,
  458. head=migration_script.head,
  459. splice=migration_script.splice,
  460. branch_labels=migration_script.branch_label,
  461. version_path=migration_script.version_path,
  462. depends_on=migration_script.depends_on,
  463. **template_args,
  464. )
  465. def run_autogenerate(
  466. self, rev: _GetRevArg, migration_context: MigrationContext
  467. ) -> None:
  468. self._run_environment(rev, migration_context, True)
  469. def run_no_autogenerate(
  470. self, rev: _GetRevArg, migration_context: MigrationContext
  471. ) -> None:
  472. self._run_environment(rev, migration_context, False)
  473. def _run_environment(
  474. self,
  475. rev: _GetRevArg,
  476. migration_context: MigrationContext,
  477. autogenerate: bool,
  478. ) -> None:
  479. if autogenerate:
  480. if self.command_args["sql"]:
  481. raise util.CommandError(
  482. "Using --sql with --autogenerate does not make any sense"
  483. )
  484. if set(self.script_directory.get_revisions(rev)) != set(
  485. self.script_directory.get_revisions("heads")
  486. ):
  487. raise util.CommandError("Target database is not up to date.")
  488. upgrade_token = migration_context.opts["upgrade_token"]
  489. downgrade_token = migration_context.opts["downgrade_token"]
  490. migration_script = self.generated_revisions[-1]
  491. if not getattr(migration_script, "_needs_render", False):
  492. migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
  493. migration_script.downgrade_ops_list[-1].downgrade_token = (
  494. downgrade_token
  495. )
  496. migration_script._needs_render = True
  497. else:
  498. migration_script._upgrade_ops.append(
  499. ops.UpgradeOps([], upgrade_token=upgrade_token)
  500. )
  501. migration_script._downgrade_ops.append(
  502. ops.DowngradeOps([], downgrade_token=downgrade_token)
  503. )
  504. autogen_context = AutogenContext(
  505. migration_context, autogenerate=autogenerate
  506. )
  507. self._last_autogen_context: AutogenContext = autogen_context
  508. if autogenerate:
  509. compare._populate_migration_script(
  510. autogen_context, migration_script
  511. )
  512. if self.process_revision_directives:
  513. self.process_revision_directives(
  514. migration_context, rev, self.generated_revisions
  515. )
  516. hook = migration_context.opts["process_revision_directives"]
  517. if hook:
  518. hook(migration_context, rev, self.generated_revisions)
  519. for migration_script in self.generated_revisions:
  520. migration_script._needs_render = True
  521. def _default_revision(self) -> MigrationScript:
  522. command_args: Dict[str, Any] = self.command_args
  523. op = ops.MigrationScript(
  524. rev_id=command_args["rev_id"] or util.rev_id(),
  525. message=command_args["message"],
  526. upgrade_ops=ops.UpgradeOps([]),
  527. downgrade_ops=ops.DowngradeOps([]),
  528. head=command_args["head"],
  529. splice=command_args["splice"],
  530. branch_label=command_args["branch_label"],
  531. version_path=command_args["version_path"],
  532. depends_on=command_args["depends_on"],
  533. )
  534. return op
  535. def generate_scripts(self) -> Iterator[Optional[Script]]:
  536. for generated_revision in self.generated_revisions:
  537. yield self._to_script(generated_revision)