| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650 | from __future__ import annotationsimport contextlibfrom typing import Anyfrom typing import Dictfrom typing import Iteratorfrom typing import Listfrom typing import Optionalfrom typing import Sequencefrom typing import Setfrom typing import TYPE_CHECKINGfrom typing import Unionfrom sqlalchemy import inspectfrom . import comparefrom . import renderfrom .. import utilfrom ..operations import opsfrom ..util import sqla_compat"""Provide the 'autogenerate' feature which can produce migration operationsautomatically."""if TYPE_CHECKING:    from sqlalchemy.engine import Connection    from sqlalchemy.engine import Dialect    from sqlalchemy.engine import Inspector    from sqlalchemy.sql.schema import MetaData    from sqlalchemy.sql.schema import SchemaItem    from sqlalchemy.sql.schema import Table    from ..config import Config    from ..operations.ops import DowngradeOps    from ..operations.ops import MigrationScript    from ..operations.ops import UpgradeOps    from ..runtime.environment import NameFilterParentNames    from ..runtime.environment import NameFilterType    from ..runtime.environment import ProcessRevisionDirectiveFn    from ..runtime.environment import RenderItemFn    from ..runtime.migration import MigrationContext    from ..script.base import Script    from ..script.base import ScriptDirectory    from ..script.revision import _GetRevArgdef compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:    """Compare a database schema to that given in a    :class:`~sqlalchemy.schema.MetaData` instance.    The database connection is presented in the context    of a :class:`.MigrationContext` object, which    provides database connectivity as well as optional    comparison functions to use for datatypes and    server defaults - see the "autogenerate" arguments    at :meth:`.EnvironmentContext.configure`    for details on these.    The return format is a list of "diff" directives,    each representing individual differences::        from alembic.migration import MigrationContext        from alembic.autogenerate import compare_metadata        from sqlalchemy import (            create_engine,            MetaData,            Column,            Integer,            String,            Table,            text,        )        import pprint        engine = create_engine("sqlite://")        with engine.begin() as conn:            conn.execute(                text(                    '''                        create table foo (                            id integer not null primary key,                            old_data varchar,                            x integer                        )                    '''                )            )            conn.execute(text("create table bar (data varchar)"))        metadata = MetaData()        Table(            "foo",            metadata,            Column("id", Integer, primary_key=True),            Column("data", Integer),            Column("x", Integer, nullable=False),        )        Table("bat", metadata, Column("info", String))        mc = MigrationContext.configure(engine.connect())        diff = compare_metadata(mc, metadata)        pprint.pprint(diff, indent=2, width=20)    Output::        [            (                "add_table",                Table(                    "bat",                    MetaData(),                    Column("info", String(), table=<bat>),                    schema=None,                ),            ),            (                "remove_table",                Table(                    "bar",                    MetaData(),                    Column("data", VARCHAR(), table=<bar>),                    schema=None,                ),            ),            (                "add_column",                None,                "foo",                Column("data", Integer(), table=<foo>),            ),            [                (                    "modify_nullable",                    None,                    "foo",                    "x",                    {                        "existing_comment": None,                        "existing_server_default": False,                        "existing_type": INTEGER(),                    },                    True,                    False,                )            ],            (                "remove_column",                None,                "foo",                Column("old_data", VARCHAR(), table=<foo>),            ),        ]    :param context: a :class:`.MigrationContext`     instance.    :param metadata: a :class:`~sqlalchemy.schema.MetaData`     instance.    .. seealso::        :func:`.produce_migrations` - produces a :class:`.MigrationScript`        structure based on metadata comparison.    """    migration_script = produce_migrations(context, metadata)    assert migration_script.upgrade_ops is not None    return migration_script.upgrade_ops.as_diffs()def produce_migrations(    context: MigrationContext, metadata: MetaData) -> MigrationScript:    """Produce a :class:`.MigrationScript` structure based on schema    comparison.    This function does essentially what :func:`.compare_metadata` does,    but then runs the resulting list of diffs to produce the full    :class:`.MigrationScript` object.   For an example of what this looks like,    see the example in :ref:`customizing_revision`.    .. seealso::        :func:`.compare_metadata` - returns more fundamental "diff"        data from comparing a schema.    """    autogen_context = AutogenContext(context, metadata=metadata)    migration_script = ops.MigrationScript(        rev_id=None,        upgrade_ops=ops.UpgradeOps([]),        downgrade_ops=ops.DowngradeOps([]),    )    compare._populate_migration_script(autogen_context, migration_script)    return migration_scriptdef render_python_code(    up_or_down_op: Union[UpgradeOps, DowngradeOps],    sqlalchemy_module_prefix: str = "sa.",    alembic_module_prefix: str = "op.",    render_as_batch: bool = False,    imports: Sequence[str] = (),    render_item: Optional[RenderItemFn] = None,    migration_context: Optional[MigrationContext] = None,    user_module_prefix: Optional[str] = None,) -> str:    """Render Python code given an :class:`.UpgradeOps` or    :class:`.DowngradeOps` object.    This is a convenience function that can be used to test the    autogenerate output of a user-defined :class:`.MigrationScript` structure.    :param up_or_down_op: :class:`.UpgradeOps` or :class:`.DowngradeOps` object    :param sqlalchemy_module_prefix: module prefix for SQLAlchemy objects    :param alembic_module_prefix: module prefix for Alembic constructs    :param render_as_batch: use "batch operations" style for rendering    :param imports: sequence of import symbols to add    :param render_item: callable to render items    :param migration_context: optional :class:`.MigrationContext`    :param user_module_prefix: optional string prefix for user-defined types     .. versionadded:: 1.11.0    """    opts = {        "sqlalchemy_module_prefix": sqlalchemy_module_prefix,        "alembic_module_prefix": alembic_module_prefix,        "render_item": render_item,        "render_as_batch": render_as_batch,        "user_module_prefix": user_module_prefix,    }    if migration_context is None:        from ..runtime.migration import MigrationContext        from sqlalchemy.engine.default import DefaultDialect        migration_context = MigrationContext.configure(            dialect=DefaultDialect()        )    autogen_context = AutogenContext(migration_context, opts=opts)    autogen_context.imports = set(imports)    return render._indent(        render._render_cmd_body(up_or_down_op, autogen_context)    )def _render_migration_diffs(    context: MigrationContext, template_args: Dict[Any, Any]) -> None:    """legacy, used by test_autogen_composition at the moment"""    autogen_context = AutogenContext(context)    upgrade_ops = ops.UpgradeOps([])    compare._produce_net_changes(autogen_context, upgrade_ops)    migration_script = ops.MigrationScript(        rev_id=None,        upgrade_ops=upgrade_ops,        downgrade_ops=upgrade_ops.reverse(),    )    render._render_python_into_templatevars(        autogen_context, migration_script, template_args    )class AutogenContext:    """Maintains configuration and state that's specific to an    autogenerate operation."""    metadata: Union[MetaData, Sequence[MetaData], None] = None    """The :class:`~sqlalchemy.schema.MetaData` object    representing the destination.    This object is the one that is passed within ``env.py``    to the :paramref:`.EnvironmentContext.configure.target_metadata`    parameter.  It represents the structure of :class:`.Table` and other    objects as stated in the current database model, and represents the    destination structure for the database being examined.    While the :class:`~sqlalchemy.schema.MetaData` object is primarily    known as a collection of :class:`~sqlalchemy.schema.Table` objects,    it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary    that may be used by end-user schemes to store additional schema-level    objects that are to be compared in custom autogeneration schemes.    """    connection: Optional[Connection] = None    """The :class:`~sqlalchemy.engine.base.Connection` object currently    connected to the database backend being compared.    This is obtained from the :attr:`.MigrationContext.bind` and is    ultimately set up in the ``env.py`` script.    """    dialect: Optional[Dialect] = None    """The :class:`~sqlalchemy.engine.Dialect` object currently in use.    This is normally obtained from the    :attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.    """    imports: Set[str] = None  # type: ignore[assignment]    """A ``set()`` which contains string Python import directives.    The directives are to be rendered into the ``${imports}`` section    of a script template.  The set is normally empty and can be modified    within hooks such as the    :paramref:`.EnvironmentContext.configure.render_item` hook.    .. seealso::        :ref:`autogen_render_types`    """    migration_context: MigrationContext = None  # type: ignore[assignment]    """The :class:`.MigrationContext` established by the ``env.py`` script."""    def __init__(        self,        migration_context: MigrationContext,        metadata: Union[MetaData, Sequence[MetaData], None] = None,        opts: Optional[Dict[str, Any]] = None,        autogenerate: bool = True,    ) -> None:        if (            autogenerate            and migration_context is not None            and migration_context.as_sql        ):            raise util.CommandError(                "autogenerate can't use as_sql=True as it prevents querying "                "the database for schema information"            )        if opts is None:            opts = migration_context.opts        self.metadata = metadata = (            opts.get("target_metadata", None) if metadata is None else metadata        )        if (            autogenerate            and metadata is None            and migration_context is not None            and migration_context.script is not None        ):            raise util.CommandError(                "Can't proceed with --autogenerate option; environment "                "script %s does not provide "                "a MetaData object or sequence of objects to the context."                % (migration_context.script.env_py_location)            )        include_object = opts.get("include_object", None)        include_name = opts.get("include_name", None)        object_filters = []        name_filters = []        if include_object:            object_filters.append(include_object)        if include_name:            name_filters.append(include_name)        self._object_filters = object_filters        self._name_filters = name_filters        self.migration_context = migration_context        if self.migration_context is not None:            self.connection = self.migration_context.bind            self.dialect = self.migration_context.dialect        self.imports = set()        self.opts: Dict[str, Any] = opts        self._has_batch: bool = False    @util.memoized_property    def inspector(self) -> Inspector:        if self.connection is None:            raise TypeError(                "can't return inspector as this "                "AutogenContext has no database connection"            )        return inspect(self.connection)    @contextlib.contextmanager    def _within_batch(self) -> Iterator[None]:        self._has_batch = True        yield        self._has_batch = False    def run_name_filters(        self,        name: Optional[str],        type_: NameFilterType,        parent_names: NameFilterParentNames,    ) -> bool:        """Run the context's name filters and return True if the targets        should be part of the autogenerate operation.        This method should be run for every kind of name encountered within the        reflection side of an autogenerate operation, giving the environment        the chance to filter what names should be reflected as database        objects.  The filters here are produced directly via the        :paramref:`.EnvironmentContext.configure.include_name` parameter.        """        if "schema_name" in parent_names:            if type_ == "table":                table_name = name            else:                table_name = parent_names.get("table_name", None)            if table_name:                schema_name = parent_names["schema_name"]                if schema_name:                    parent_names["schema_qualified_table_name"] = "%s.%s" % (                        schema_name,                        table_name,                    )                else:                    parent_names["schema_qualified_table_name"] = table_name        for fn in self._name_filters:            if not fn(name, type_, parent_names):                return False        else:            return True    def run_object_filters(        self,        object_: SchemaItem,        name: sqla_compat._ConstraintName,        type_: NameFilterType,        reflected: bool,        compare_to: Optional[SchemaItem],    ) -> bool:        """Run the context's object filters and return True if the targets        should be part of the autogenerate operation.        This method should be run for every kind of object encountered within        an autogenerate operation, giving the environment the chance        to filter what objects should be included in the comparison.        The filters here are produced directly via the        :paramref:`.EnvironmentContext.configure.include_object` parameter.        """        for fn in self._object_filters:            if not fn(object_, name, type_, reflected, compare_to):                return False        else:            return True    run_filters = run_object_filters    @util.memoized_property    def sorted_tables(self) -> List[Table]:        """Return an aggregate of the :attr:`.MetaData.sorted_tables`        collection(s).        For a sequence of :class:`.MetaData` objects, this        concatenates the :attr:`.MetaData.sorted_tables` collection        for each individual :class:`.MetaData`  in the order of the        sequence.  It does **not** collate the sorted tables collections.        """        result = []        for m in util.to_list(self.metadata):            result.extend(m.sorted_tables)        return result    @util.memoized_property    def table_key_to_table(self) -> Dict[str, Table]:        """Return an aggregate  of the :attr:`.MetaData.tables` dictionaries.        The :attr:`.MetaData.tables` collection is a dictionary of table key        to :class:`.Table`; this method aggregates the dictionary across        multiple :class:`.MetaData` objects into one dictionary.        Duplicate table keys are **not** supported; if two :class:`.MetaData`        objects contain the same table key, an exception is raised.        """        result: Dict[str, Table] = {}        for m in util.to_list(self.metadata):            intersect = set(result).intersection(set(m.tables))            if intersect:                raise ValueError(                    "Duplicate table keys across multiple "                    "MetaData objects: %s"                    % (", ".join('"%s"' % key for key in sorted(intersect)))                )            result.update(m.tables)        return resultclass RevisionContext:    """Maintains configuration and state that's specific to a revision    file generation operation."""    generated_revisions: List[MigrationScript]    process_revision_directives: Optional[ProcessRevisionDirectiveFn]    def __init__(        self,        config: Config,        script_directory: ScriptDirectory,        command_args: Dict[str, Any],        process_revision_directives: Optional[            ProcessRevisionDirectiveFn        ] = None,    ) -> None:        self.config = config        self.script_directory = script_directory        self.command_args = command_args        self.process_revision_directives = process_revision_directives        self.template_args = {            "config": config  # Let templates use config for            # e.g. multiple databases        }        self.generated_revisions = [self._default_revision()]    def _to_script(        self, migration_script: MigrationScript    ) -> Optional[Script]:        template_args: Dict[str, Any] = self.template_args.copy()        if getattr(migration_script, "_needs_render", False):            autogen_context = self._last_autogen_context            # clear out existing imports if we are doing multiple            # renders            autogen_context.imports = set()            if migration_script.imports:                autogen_context.imports.update(migration_script.imports)            render._render_python_into_templatevars(                autogen_context, migration_script, template_args            )        assert migration_script.rev_id is not None        return self.script_directory.generate_revision(            migration_script.rev_id,            migration_script.message,            refresh=True,            head=migration_script.head,            splice=migration_script.splice,            branch_labels=migration_script.branch_label,            version_path=migration_script.version_path,            depends_on=migration_script.depends_on,            **template_args,        )    def run_autogenerate(        self, rev: _GetRevArg, migration_context: MigrationContext    ) -> None:        self._run_environment(rev, migration_context, True)    def run_no_autogenerate(        self, rev: _GetRevArg, migration_context: MigrationContext    ) -> None:        self._run_environment(rev, migration_context, False)    def _run_environment(        self,        rev: _GetRevArg,        migration_context: MigrationContext,        autogenerate: bool,    ) -> None:        if autogenerate:            if self.command_args["sql"]:                raise util.CommandError(                    "Using --sql with --autogenerate does not make any sense"                )            if set(self.script_directory.get_revisions(rev)) != set(                self.script_directory.get_revisions("heads")            ):                raise util.CommandError("Target database is not up to date.")        upgrade_token = migration_context.opts["upgrade_token"]        downgrade_token = migration_context.opts["downgrade_token"]        migration_script = self.generated_revisions[-1]        if not getattr(migration_script, "_needs_render", False):            migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token            migration_script.downgrade_ops_list[-1].downgrade_token = (                downgrade_token            )            migration_script._needs_render = True        else:            migration_script._upgrade_ops.append(                ops.UpgradeOps([], upgrade_token=upgrade_token)            )            migration_script._downgrade_ops.append(                ops.DowngradeOps([], downgrade_token=downgrade_token)            )        autogen_context = AutogenContext(            migration_context, autogenerate=autogenerate        )        self._last_autogen_context: AutogenContext = autogen_context        if autogenerate:            compare._populate_migration_script(                autogen_context, migration_script            )        if self.process_revision_directives:            self.process_revision_directives(                migration_context, rev, self.generated_revisions            )        hook = migration_context.opts["process_revision_directives"]        if hook:            hook(migration_context, rev, self.generated_revisions)        for migration_script in self.generated_revisions:            migration_script._needs_render = True    def _default_revision(self) -> MigrationScript:        command_args: Dict[str, Any] = self.command_args        op = ops.MigrationScript(            rev_id=command_args["rev_id"] or util.rev_id(),            message=command_args["message"],            upgrade_ops=ops.UpgradeOps([]),            downgrade_ops=ops.DowngradeOps([]),            head=command_args["head"],            splice=command_args["splice"],            branch_label=command_args["branch_label"],            version_path=command_args["version_path"],            depends_on=command_args["depends_on"],        )        return op    def generate_scripts(self) -> Iterator[Optional[Script]]:        for generated_revision in self.generated_revisions:            yield self._to_script(generated_revision)
 |