about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/alembic/runtime/migration.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/alembic/runtime/migration.py')
-rw-r--r--.venv/lib/python3.12/site-packages/alembic/runtime/migration.py1391
1 files changed, 1391 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/alembic/runtime/migration.py b/.venv/lib/python3.12/site-packages/alembic/runtime/migration.py
new file mode 100644
index 00000000..ac431a62
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/alembic/runtime/migration.py
@@ -0,0 +1,1391 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from contextlib import nullcontext
+import logging
+import sys
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy import Column
+from sqlalchemy import literal_column
+from sqlalchemy import select
+from sqlalchemy.engine import Engine
+from sqlalchemy.engine import url as sqla_url
+from sqlalchemy.engine.strategies import MockEngineStrategy
+from typing_extensions import ContextManager
+
+from .. import ddl
+from .. import util
+from ..util import sqla_compat
+from ..util.compat import EncodedIO
+
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.engine import URL
+    from sqlalchemy.engine.base import Connection
+    from sqlalchemy.engine.base import Transaction
+    from sqlalchemy.engine.mock import MockConnection
+    from sqlalchemy.sql import Executable
+
+    from .environment import EnvironmentContext
+    from ..config import Config
+    from ..script.base import Script
+    from ..script.base import ScriptDirectory
+    from ..script.revision import _RevisionOrBase
+    from ..script.revision import Revision
+    from ..script.revision import RevisionMap
+
+log = logging.getLogger(__name__)
+
+
+class _ProxyTransaction:
+    def __init__(self, migration_context: MigrationContext) -> None:
+        self.migration_context = migration_context
+
+    @property
+    def _proxied_transaction(self) -> Optional[Transaction]:
+        return self.migration_context._transaction
+
+    def rollback(self) -> None:
+        t = self._proxied_transaction
+        assert t is not None
+        t.rollback()
+        self.migration_context._transaction = None
+
+    def commit(self) -> None:
+        t = self._proxied_transaction
+        assert t is not None
+        t.commit()
+        self.migration_context._transaction = None
+
+    def __enter__(self) -> _ProxyTransaction:
+        return self
+
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
+        if self._proxied_transaction is not None:
+            self._proxied_transaction.__exit__(type_, value, traceback)
+            self.migration_context._transaction = None
+
+
+class MigrationContext:
+    """Represent the database state made available to a migration
+    script.
+
+    :class:`.MigrationContext` is the front end to an actual
+    database connection, or alternatively a string output
+    stream given a particular database dialect,
+    from an Alembic perspective.
+
+    When inside the ``env.py`` script, the :class:`.MigrationContext`
+    is available via the
+    :meth:`.EnvironmentContext.get_context` method,
+    which is available at ``alembic.context``::
+
+        # from within env.py script
+        from alembic import context
+
+        migration_context = context.get_context()
+
+    For usage outside of an ``env.py`` script, such as for
+    utility routines that want to check the current version
+    in the database, the :meth:`.MigrationContext.configure`
+    method to create new :class:`.MigrationContext` objects.
+    For example, to get at the current revision in the
+    database using :meth:`.MigrationContext.get_current_revision`::
+
+        # in any application, outside of an env.py script
+        from alembic.migration import MigrationContext
+        from sqlalchemy import create_engine
+
+        engine = create_engine("postgresql://mydatabase")
+        conn = engine.connect()
+
+        context = MigrationContext.configure(conn)
+        current_rev = context.get_current_revision()
+
+    The above context can also be used to produce
+    Alembic migration operations with an :class:`.Operations`
+    instance::
+
+        # in any application, outside of the normal Alembic environment
+        from alembic.operations import Operations
+
+        op = Operations(context)
+        op.alter_column("mytable", "somecolumn", nullable=True)
+
+    """
+
+    def __init__(
+        self,
+        dialect: Dialect,
+        connection: Optional[Connection],
+        opts: Dict[str, Any],
+        environment_context: Optional[EnvironmentContext] = None,
+    ) -> None:
+        self.environment_context = environment_context
+        self.opts = opts
+        self.dialect = dialect
+        self.script: Optional[ScriptDirectory] = opts.get("script")
+        as_sql: bool = opts.get("as_sql", False)
+        transactional_ddl = opts.get("transactional_ddl")
+        self._transaction_per_migration = opts.get(
+            "transaction_per_migration", False
+        )
+        self.on_version_apply_callbacks = opts.get("on_version_apply", ())
+        self._transaction: Optional[Transaction] = None
+
+        if as_sql:
+            self.connection = cast(
+                Optional["Connection"], self._stdout_connection(connection)
+            )
+            assert self.connection is not None
+            self._in_external_transaction = False
+        else:
+            self.connection = connection
+            self._in_external_transaction = (
+                sqla_compat._get_connection_in_transaction(connection)
+            )
+
+        self._migrations_fn: Optional[
+            Callable[..., Iterable[RevisionStep]]
+        ] = opts.get("fn")
+        self.as_sql = as_sql
+
+        self.purge = opts.get("purge", False)
+
+        if "output_encoding" in opts:
+            self.output_buffer = EncodedIO(
+                opts.get("output_buffer")
+                or sys.stdout,  # type:ignore[arg-type]
+                opts["output_encoding"],
+            )
+        else:
+            self.output_buffer = opts.get("output_buffer", sys.stdout)
+
+        self._user_compare_type = opts.get("compare_type", True)
+        self._user_compare_server_default = opts.get(
+            "compare_server_default", False
+        )
+        self.version_table = version_table = opts.get(
+            "version_table", "alembic_version"
+        )
+        self.version_table_schema = version_table_schema = opts.get(
+            "version_table_schema", None
+        )
+
+        self._start_from_rev: Optional[str] = opts.get("starting_rev")
+        self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
+            dialect,
+            self.connection,
+            self.as_sql,
+            transactional_ddl,
+            self.output_buffer,
+            opts,
+        )
+
+        self._version = self.impl.version_table_impl(
+            version_table=version_table,
+            version_table_schema=version_table_schema,
+            version_table_pk=opts.get("version_table_pk", True),
+        )
+
+        log.info("Context impl %s.", self.impl.__class__.__name__)
+        if self.as_sql:
+            log.info("Generating static SQL")
+        log.info(
+            "Will assume %s DDL.",
+            (
+                "transactional"
+                if self.impl.transactional_ddl
+                else "non-transactional"
+            ),
+        )
+
+    @classmethod
+    def configure(
+        cls,
+        connection: Optional[Connection] = None,
+        url: Optional[Union[str, URL]] = None,
+        dialect_name: Optional[str] = None,
+        dialect: Optional[Dialect] = None,
+        environment_context: Optional[EnvironmentContext] = None,
+        dialect_opts: Optional[Dict[str, str]] = None,
+        opts: Optional[Any] = None,
+    ) -> MigrationContext:
+        """Create a new :class:`.MigrationContext`.
+
+        This is a factory method usually called
+        by :meth:`.EnvironmentContext.configure`.
+
+        :param connection: a :class:`~sqlalchemy.engine.Connection`
+         to use for SQL execution in "online" mode.  When present,
+         is also used to determine the type of dialect in use.
+        :param url: a string database url, or a
+         :class:`sqlalchemy.engine.url.URL` object.
+         The type of dialect to be used will be derived from this if
+         ``connection`` is not passed.
+        :param dialect_name: string name of a dialect, such as
+         "postgresql", "mssql", etc.  The type of dialect to be used will be
+         derived from this if ``connection`` and ``url`` are not passed.
+        :param opts: dictionary of options.  Most other options
+         accepted by :meth:`.EnvironmentContext.configure` are passed via
+         this dictionary.
+
+        """
+        if opts is None:
+            opts = {}
+        if dialect_opts is None:
+            dialect_opts = {}
+
+        if connection:
+            if isinstance(connection, Engine):
+                raise util.CommandError(
+                    "'connection' argument to configure() is expected "
+                    "to be a sqlalchemy.engine.Connection instance, "
+                    "got %r" % connection,
+                )
+
+            dialect = connection.dialect
+        elif url:
+            url_obj = sqla_url.make_url(url)
+            dialect = url_obj.get_dialect()(**dialect_opts)
+        elif dialect_name:
+            url_obj = sqla_url.make_url("%s://" % dialect_name)
+            dialect = url_obj.get_dialect()(**dialect_opts)
+        elif not dialect:
+            raise Exception("Connection, url, or dialect_name is required.")
+        assert dialect is not None
+        return MigrationContext(dialect, connection, opts, environment_context)
+
+    @contextmanager
+    def autocommit_block(self) -> Iterator[None]:
+        """Enter an "autocommit" block, for databases that support AUTOCOMMIT
+        isolation levels.
+
+        This special directive is intended to support the occasional database
+        DDL or system operation that specifically has to be run outside of
+        any kind of transaction block.   The PostgreSQL database platform
+        is the most common target for this style of operation, as many
+        of its DDL operations must be run outside of transaction blocks, even
+        though the database overall supports transactional DDL.
+
+        The method is used as a context manager within a migration script, by
+        calling on :meth:`.Operations.get_context` to retrieve the
+        :class:`.MigrationContext`, then invoking
+        :meth:`.MigrationContext.autocommit_block` using the ``with:``
+        statement::
+
+            def upgrade():
+                with op.get_context().autocommit_block():
+                    op.execute("ALTER TYPE mood ADD VALUE 'soso'")
+
+        Above, a PostgreSQL "ALTER TYPE..ADD VALUE" directive is emitted,
+        which must be run outside of a transaction block at the database level.
+        The :meth:`.MigrationContext.autocommit_block` method makes use of the
+        SQLAlchemy ``AUTOCOMMIT`` isolation level setting, which against the
+        psycogp2 DBAPI corresponds to the ``connection.autocommit`` setting,
+        to ensure that the database driver is not inside of a DBAPI level
+        transaction block.
+
+        .. warning::
+
+            As is necessary, **the database transaction preceding the block is
+            unconditionally committed**.  This means that the run of migrations
+            preceding the operation will be committed, before the overall
+            migration operation is complete.
+
+            It is recommended that when an application includes migrations with
+            "autocommit" blocks, that
+            :paramref:`.EnvironmentContext.transaction_per_migration` be used
+            so that the calling environment is tuned to expect short per-file
+            migrations whether or not one of them has an autocommit block.
+
+
+        """
+        _in_connection_transaction = self._in_connection_transaction()
+
+        if self.impl.transactional_ddl and self.as_sql:
+            self.impl.emit_commit()
+
+        elif _in_connection_transaction:
+            assert self._transaction is not None
+
+            self._transaction.commit()
+            self._transaction = None
+
+        if not self.as_sql:
+            assert self.connection is not None
+            current_level = self.connection.get_isolation_level()
+            base_connection = self.connection
+
+            # in 1.3 and 1.4 non-future mode, the connection gets switched
+            # out.  we can use the base connection with the new mode
+            # except that it will not know it's in "autocommit" and will
+            # emit deprecation warnings when an autocommit action takes
+            # place.
+            self.connection = self.impl.connection = (
+                base_connection.execution_options(isolation_level="AUTOCOMMIT")
+            )
+
+            # sqlalchemy future mode will "autobegin" in any case, so take
+            # control of that "transaction" here
+            fake_trans: Optional[Transaction] = self.connection.begin()
+        else:
+            fake_trans = None
+        try:
+            yield
+        finally:
+            if not self.as_sql:
+                assert self.connection is not None
+                if fake_trans is not None:
+                    fake_trans.commit()
+                self.connection.execution_options(
+                    isolation_level=current_level
+                )
+                self.connection = self.impl.connection = base_connection
+
+            if self.impl.transactional_ddl and self.as_sql:
+                self.impl.emit_begin()
+
+            elif _in_connection_transaction:
+                assert self.connection is not None
+                self._transaction = self.connection.begin()
+
+    def begin_transaction(
+        self, _per_migration: bool = False
+    ) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]:
+        """Begin a logical transaction for migration operations.
+
+        This method is used within an ``env.py`` script to demarcate where
+        the outer "transaction" for a series of migrations begins.  Example::
+
+            def run_migrations_online():
+                connectable = create_engine(...)
+
+                with connectable.connect() as connection:
+                    context.configure(
+                        connection=connection, target_metadata=target_metadata
+                    )
+
+                    with context.begin_transaction():
+                        context.run_migrations()
+
+        Above, :meth:`.MigrationContext.begin_transaction` is used to demarcate
+        where the outer logical transaction occurs around the
+        :meth:`.MigrationContext.run_migrations` operation.
+
+        A "Logical" transaction means that the operation may or may not
+        correspond to a real database transaction.   If the target database
+        supports transactional DDL (or
+        :paramref:`.EnvironmentContext.configure.transactional_ddl` is true),
+        the :paramref:`.EnvironmentContext.configure.transaction_per_migration`
+        flag is not set, and the migration is against a real database
+        connection (as opposed to using "offline" ``--sql`` mode), a real
+        transaction will be started.   If ``--sql`` mode is in effect, the
+        operation would instead correspond to a string such as "BEGIN" being
+        emitted to the string output.
+
+        The returned object is a Python context manager that should only be
+        used in the context of a ``with:`` statement as indicated above.
+        The object has no other guaranteed API features present.
+
+        .. seealso::
+
+            :meth:`.MigrationContext.autocommit_block`
+
+        """
+
+        if self._in_external_transaction:
+            return nullcontext()
+
+        if self.impl.transactional_ddl:
+            transaction_now = _per_migration == self._transaction_per_migration
+        else:
+            transaction_now = _per_migration is True
+
+        if not transaction_now:
+            return nullcontext()
+
+        elif not self.impl.transactional_ddl:
+            assert _per_migration
+
+            if self.as_sql:
+                return nullcontext()
+            else:
+                # track our own notion of a "transaction block", which must be
+                # committed when complete.   Don't rely upon whether or not the
+                # SQLAlchemy connection reports as "in transaction"; this
+                # because SQLAlchemy future connection features autobegin
+                # behavior, so it may already be in a transaction from our
+                # emitting of queries like "has_version_table", etc. While we
+                # could track these operations as well, that leaves open the
+                # possibility of new operations or other things happening in
+                # the user environment that still may be triggering
+                # "autobegin".
+
+                in_transaction = self._transaction is not None
+
+                if in_transaction:
+                    return nullcontext()
+                else:
+                    assert self.connection is not None
+                    self._transaction = (
+                        sqla_compat._safe_begin_connection_transaction(
+                            self.connection
+                        )
+                    )
+                    return _ProxyTransaction(self)
+        elif self.as_sql:
+
+            @contextmanager
+            def begin_commit():
+                self.impl.emit_begin()
+                yield
+                self.impl.emit_commit()
+
+            return begin_commit()
+        else:
+            assert self.connection is not None
+            self._transaction = sqla_compat._safe_begin_connection_transaction(
+                self.connection
+            )
+            return _ProxyTransaction(self)
+
+    def get_current_revision(self) -> Optional[str]:
+        """Return the current revision, usually that which is present
+        in the ``alembic_version`` table in the database.
+
+        This method intends to be used only for a migration stream that
+        does not contain unmerged branches in the target database;
+        if there are multiple branches present, an exception is raised.
+        The :meth:`.MigrationContext.get_current_heads` should be preferred
+        over this method going forward in order to be compatible with
+        branch migration support.
+
+        If this :class:`.MigrationContext` was configured in "offline"
+        mode, that is with ``as_sql=True``, the ``starting_rev``
+        parameter is returned instead, if any.
+
+        """
+        heads = self.get_current_heads()
+        if len(heads) == 0:
+            return None
+        elif len(heads) > 1:
+            raise util.CommandError(
+                "Version table '%s' has more than one head present; "
+                "please use get_current_heads()" % self.version_table
+            )
+        else:
+            return heads[0]
+
+    def get_current_heads(self) -> Tuple[str, ...]:
+        """Return a tuple of the current 'head versions' that are represented
+        in the target database.
+
+        For a migration stream without branches, this will be a single
+        value, synonymous with that of
+        :meth:`.MigrationContext.get_current_revision`.   However when multiple
+        unmerged branches exist within the target database, the returned tuple
+        will contain a value for each head.
+
+        If this :class:`.MigrationContext` was configured in "offline"
+        mode, that is with ``as_sql=True``, the ``starting_rev``
+        parameter is returned in a one-length tuple.
+
+        If no version table is present, or if there are no revisions
+        present, an empty tuple is returned.
+
+        """
+        if self.as_sql:
+            start_from_rev: Any = self._start_from_rev
+            if start_from_rev == "base":
+                start_from_rev = None
+            elif start_from_rev is not None and self.script:
+                start_from_rev = [
+                    self.script.get_revision(sfr).revision
+                    for sfr in util.to_list(start_from_rev)
+                    if sfr not in (None, "base")
+                ]
+            return util.to_tuple(start_from_rev, default=())
+        else:
+            if self._start_from_rev:
+                raise util.CommandError(
+                    "Can't specify current_rev to context "
+                    "when using a database connection"
+                )
+            if not self._has_version_table():
+                return ()
+        assert self.connection is not None
+        return tuple(
+            row[0]
+            for row in self.connection.execute(
+                select(self._version.c.version_num)
+            )
+        )
+
+    def _ensure_version_table(self, purge: bool = False) -> None:
+        with sqla_compat._ensure_scope_for_ddl(self.connection):
+            assert self.connection is not None
+            self._version.create(self.connection, checkfirst=True)
+            if purge:
+                assert self.connection is not None
+                self.connection.execute(self._version.delete())
+
+    def _has_version_table(self) -> bool:
+        assert self.connection is not None
+        return sqla_compat._connectable_has_table(
+            self.connection, self.version_table, self.version_table_schema
+        )
+
+    def stamp(self, script_directory: ScriptDirectory, revision: str) -> None:
+        """Stamp the version table with a specific revision.
+
+        This method calculates those branches to which the given revision
+        can apply, and updates those branches as though they were migrated
+        towards that revision (either up or down).  If no current branches
+        include the revision, it is added as a new branch head.
+
+        """
+        heads = self.get_current_heads()
+        if not self.as_sql and not heads:
+            self._ensure_version_table()
+        head_maintainer = HeadMaintainer(self, heads)
+        for step in script_directory._stamp_revs(revision, heads):
+            head_maintainer.update_to_step(step)
+
+    def run_migrations(self, **kw: Any) -> None:
+        r"""Run the migration scripts established for this
+        :class:`.MigrationContext`, if any.
+
+        The commands in :mod:`alembic.command` will set up a function
+        that is ultimately passed to the :class:`.MigrationContext`
+        as the ``fn`` argument.  This function represents the "work"
+        that will be done when :meth:`.MigrationContext.run_migrations`
+        is called, typically from within the ``env.py`` script of the
+        migration environment.  The "work function" then provides an iterable
+        of version callables and other version information which
+        in the case of the ``upgrade`` or ``downgrade`` commands are the
+        list of version scripts to invoke.  Other commands yield nothing,
+        in the case that a command wants to run some other operation
+        against the database such as the ``current`` or ``stamp`` commands.
+
+        :param \**kw: keyword arguments here will be passed to each
+         migration callable, that is the ``upgrade()`` or ``downgrade()``
+         method within revision scripts.
+
+        """
+        self.impl.start_migrations()
+
+        heads: Tuple[str, ...]
+        if self.purge:
+            if self.as_sql:
+                raise util.CommandError("Can't use --purge with --sql mode")
+            self._ensure_version_table(purge=True)
+            heads = ()
+        else:
+            heads = self.get_current_heads()
+
+            dont_mutate = self.opts.get("dont_mutate", False)
+
+            if not self.as_sql and not heads and not dont_mutate:
+                self._ensure_version_table()
+
+        head_maintainer = HeadMaintainer(self, heads)
+
+        assert self._migrations_fn is not None
+        for step in self._migrations_fn(heads, self):
+            with self.begin_transaction(_per_migration=True):
+                if self.as_sql and not head_maintainer.heads:
+                    # for offline mode, include a CREATE TABLE from
+                    # the base
+                    assert self.connection is not None
+                    self._version.create(self.connection)
+                log.info("Running %s", step)
+                if self.as_sql:
+                    self.impl.static_output(
+                        "-- Running %s" % (step.short_log,)
+                    )
+                step.migration_fn(**kw)
+
+                # previously, we wouldn't stamp per migration
+                # if we were in a transaction, however given the more
+                # complex model that involves any number of inserts
+                # and row-targeted updates and deletes, it's simpler for now
+                # just to run the operations on every version
+                head_maintainer.update_to_step(step)
+                for callback in self.on_version_apply_callbacks:
+                    callback(
+                        ctx=self,
+                        step=step.info,
+                        heads=set(head_maintainer.heads),
+                        run_args=kw,
+                    )
+
+        if self.as_sql and not head_maintainer.heads:
+            assert self.connection is not None
+            self._version.drop(self.connection)
+
+    def _in_connection_transaction(self) -> bool:
+        try:
+            meth = self.connection.in_transaction  # type:ignore[union-attr]
+        except AttributeError:
+            return False
+        else:
+            return meth()
+
+    def execute(
+        self,
+        sql: Union[Executable, str],
+        execution_options: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Execute a SQL construct or string statement.
+
+        The underlying execution mechanics are used, that is
+        if this is "offline mode" the SQL is written to the
+        output buffer, otherwise the SQL is emitted on
+        the current SQLAlchemy connection.
+
+        """
+        self.impl._exec(sql, execution_options)
+
+    def _stdout_connection(
+        self, connection: Optional[Connection]
+    ) -> MockConnection:
+        def dump(construct, *multiparams, **params):
+            self.impl._exec(construct)
+
+        return MockEngineStrategy.MockConnection(self.dialect, dump)
+
+    @property
+    def bind(self) -> Optional[Connection]:
+        """Return the current "bind".
+
+        In online mode, this is an instance of
+        :class:`sqlalchemy.engine.Connection`, and is suitable
+        for ad-hoc execution of any kind of usage described
+        in SQLAlchemy Core documentation as well as
+        for usage with the :meth:`sqlalchemy.schema.Table.create`
+        and :meth:`sqlalchemy.schema.MetaData.create_all` methods
+        of :class:`~sqlalchemy.schema.Table`,
+        :class:`~sqlalchemy.schema.MetaData`.
+
+        Note that when "standard output" mode is enabled,
+        this bind will be a "mock" connection handler that cannot
+        return results and is only appropriate for a very limited
+        subset of commands.
+
+        """
+        return self.connection
+
+    @property
+    def config(self) -> Optional[Config]:
+        """Return the :class:`.Config` used by the current environment,
+        if any."""
+
+        if self.environment_context:
+            return self.environment_context.config
+        else:
+            return None
+
+    def _compare_type(
+        self, inspector_column: Column[Any], metadata_column: Column
+    ) -> bool:
+        if self._user_compare_type is False:
+            return False
+
+        if callable(self._user_compare_type):
+            user_value = self._user_compare_type(
+                self,
+                inspector_column,
+                metadata_column,
+                inspector_column.type,
+                metadata_column.type,
+            )
+            if user_value is not None:
+                return user_value
+
+        return self.impl.compare_type(inspector_column, metadata_column)
+
+    def _compare_server_default(
+        self,
+        inspector_column: Column[Any],
+        metadata_column: Column[Any],
+        rendered_metadata_default: Optional[str],
+        rendered_column_default: Optional[str],
+    ) -> bool:
+        if self._user_compare_server_default is False:
+            return False
+
+        if callable(self._user_compare_server_default):
+            user_value = self._user_compare_server_default(
+                self,
+                inspector_column,
+                metadata_column,
+                rendered_column_default,
+                metadata_column.server_default,
+                rendered_metadata_default,
+            )
+            if user_value is not None:
+                return user_value
+
+        return self.impl.compare_server_default(
+            inspector_column,
+            metadata_column,
+            rendered_metadata_default,
+            rendered_column_default,
+        )
+
+
+class HeadMaintainer:
+    def __init__(self, context: MigrationContext, heads: Any) -> None:
+        self.context = context
+        self.heads = set(heads)
+
+    def _insert_version(self, version: str) -> None:
+        assert version not in self.heads
+        self.heads.add(version)
+
+        self.context.impl._exec(
+            self.context._version.insert().values(
+                version_num=literal_column("'%s'" % version)
+            )
+        )
+
+    def _delete_version(self, version: str) -> None:
+        self.heads.remove(version)
+
+        ret = self.context.impl._exec(
+            self.context._version.delete().where(
+                self.context._version.c.version_num
+                == literal_column("'%s'" % version)
+            )
+        )
+
+        if (
+            not self.context.as_sql
+            and self.context.dialect.supports_sane_rowcount
+            and ret is not None
+            and ret.rowcount != 1
+        ):
+            raise util.CommandError(
+                "Online migration expected to match one "
+                "row when deleting '%s' in '%s'; "
+                "%d found"
+                % (version, self.context.version_table, ret.rowcount)
+            )
+
+    def _update_version(self, from_: str, to_: str) -> None:
+        assert to_ not in self.heads
+        self.heads.remove(from_)
+        self.heads.add(to_)
+
+        ret = self.context.impl._exec(
+            self.context._version.update()
+            .values(version_num=literal_column("'%s'" % to_))
+            .where(
+                self.context._version.c.version_num
+                == literal_column("'%s'" % from_)
+            )
+        )
+
+        if (
+            not self.context.as_sql
+            and self.context.dialect.supports_sane_rowcount
+            and ret is not None
+            and ret.rowcount != 1
+        ):
+            raise util.CommandError(
+                "Online migration expected to match one "
+                "row when updating '%s' to '%s' in '%s'; "
+                "%d found"
+                % (from_, to_, self.context.version_table, ret.rowcount)
+            )
+
+    def update_to_step(self, step: Union[RevisionStep, StampStep]) -> None:
+        if step.should_delete_branch(self.heads):
+            vers = step.delete_version_num
+            log.debug("branch delete %s", vers)
+            self._delete_version(vers)
+        elif step.should_create_branch(self.heads):
+            vers = step.insert_version_num
+            log.debug("new branch insert %s", vers)
+            self._insert_version(vers)
+        elif step.should_merge_branches(self.heads):
+            # delete revs, update from rev, update to rev
+            (
+                delete_revs,
+                update_from_rev,
+                update_to_rev,
+            ) = step.merge_branch_idents(self.heads)
+            log.debug(
+                "merge, delete %s, update %s to %s",
+                delete_revs,
+                update_from_rev,
+                update_to_rev,
+            )
+            for delrev in delete_revs:
+                self._delete_version(delrev)
+            self._update_version(update_from_rev, update_to_rev)
+        elif step.should_unmerge_branches(self.heads):
+            (
+                update_from_rev,
+                update_to_rev,
+                insert_revs,
+            ) = step.unmerge_branch_idents(self.heads)
+            log.debug(
+                "unmerge, insert %s, update %s to %s",
+                insert_revs,
+                update_from_rev,
+                update_to_rev,
+            )
+            for insrev in insert_revs:
+                self._insert_version(insrev)
+            self._update_version(update_from_rev, update_to_rev)
+        else:
+            from_, to_ = step.update_version_num(self.heads)
+            log.debug("update %s to %s", from_, to_)
+            self._update_version(from_, to_)
+
+
+class MigrationInfo:
+    """Exposes information about a migration step to a callback listener.
+
+    The :class:`.MigrationInfo` object is available exclusively for the
+    benefit of the :paramref:`.EnvironmentContext.on_version_apply`
+    callback hook.
+
+    """
+
+    is_upgrade: bool
+    """True/False: indicates whether this operation ascends or descends the
+    version tree."""
+
+    is_stamp: bool
+    """True/False: indicates whether this operation is a stamp (i.e. whether
+    it results in any actual database operations)."""
+
+    up_revision_id: Optional[str]
+    """Version string corresponding to :attr:`.Revision.revision`.
+
+    In the case of a stamp operation, it is advised to use the
+    :attr:`.MigrationInfo.up_revision_ids` tuple as a stamp operation can
+    make a single movement from one or more branches down to a single
+    branchpoint, in which case there will be multiple "up" revisions.
+
+    .. seealso::
+
+        :attr:`.MigrationInfo.up_revision_ids`
+
+    """
+
+    up_revision_ids: Tuple[str, ...]
+    """Tuple of version strings corresponding to :attr:`.Revision.revision`.
+
+    In the majority of cases, this tuple will be a single value, synonymous
+    with the scalar value of :attr:`.MigrationInfo.up_revision_id`.
+    It can be multiple revision identifiers only in the case of an
+    ``alembic stamp`` operation which is moving downwards from multiple
+    branches down to their common branch point.
+
+    """
+
+    down_revision_ids: Tuple[str, ...]
+    """Tuple of strings representing the base revisions of this migration step.
+
+    If empty, this represents a root revision; otherwise, the first item
+    corresponds to :attr:`.Revision.down_revision`, and the rest are inferred
+    from dependencies.
+    """
+
+    revision_map: RevisionMap
+    """The revision map inside of which this operation occurs."""
+
+    def __init__(
+        self,
+        revision_map: RevisionMap,
+        is_upgrade: bool,
+        is_stamp: bool,
+        up_revisions: Union[str, Tuple[str, ...]],
+        down_revisions: Union[str, Tuple[str, ...]],
+    ) -> None:
+        self.revision_map = revision_map
+        self.is_upgrade = is_upgrade
+        self.is_stamp = is_stamp
+        self.up_revision_ids = util.to_tuple(up_revisions, default=())
+        if self.up_revision_ids:
+            self.up_revision_id = self.up_revision_ids[0]
+        else:
+            # this should never be the case with
+            # "upgrade", "downgrade", or "stamp" as we are always
+            # measuring movement in terms of at least one upgrade version
+            self.up_revision_id = None
+        self.down_revision_ids = util.to_tuple(down_revisions, default=())
+
+    @property
+    def is_migration(self) -> bool:
+        """True/False: indicates whether this operation is a migration.
+
+        At present this is true if and only the migration is not a stamp.
+        If other operation types are added in the future, both this attribute
+        and :attr:`~.MigrationInfo.is_stamp` will be false.
+        """
+        return not self.is_stamp
+
+    @property
+    def source_revision_ids(self) -> Tuple[str, ...]:
+        """Active revisions before this migration step is applied."""
+        return (
+            self.down_revision_ids if self.is_upgrade else self.up_revision_ids
+        )
+
+    @property
+    def destination_revision_ids(self) -> Tuple[str, ...]:
+        """Active revisions after this migration step is applied."""
+        return (
+            self.up_revision_ids if self.is_upgrade else self.down_revision_ids
+        )
+
+    @property
+    def up_revision(self) -> Optional[Revision]:
+        """Get :attr:`~.MigrationInfo.up_revision_id` as
+        a :class:`.Revision`.
+
+        """
+        return self.revision_map.get_revision(self.up_revision_id)
+
+    @property
+    def up_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
+        """Get :attr:`~.MigrationInfo.up_revision_ids` as a
+        :class:`.Revision`."""
+        return self.revision_map.get_revisions(self.up_revision_ids)
+
+    @property
+    def down_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
+        """Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
+        :class:`Revisions <.Revision>`."""
+        return self.revision_map.get_revisions(self.down_revision_ids)
+
+    @property
+    def source_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
+        """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
+        :class:`Revisions <.Revision>`."""
+        return self.revision_map.get_revisions(self.source_revision_ids)
+
+    @property
+    def destination_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
+        """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
+        :class:`Revisions <.Revision>`."""
+        return self.revision_map.get_revisions(self.destination_revision_ids)
+
+
+class MigrationStep:
+    from_revisions_no_deps: Tuple[str, ...]
+    to_revisions_no_deps: Tuple[str, ...]
+    is_upgrade: bool
+    migration_fn: Any
+
+    if TYPE_CHECKING:
+
+        @property
+        def doc(self) -> Optional[str]: ...
+
+    @property
+    def name(self) -> str:
+        return self.migration_fn.__name__
+
+    @classmethod
+    def upgrade_from_script(
+        cls, revision_map: RevisionMap, script: Script
+    ) -> RevisionStep:
+        return RevisionStep(revision_map, script, True)
+
+    @classmethod
+    def downgrade_from_script(
+        cls, revision_map: RevisionMap, script: Script
+    ) -> RevisionStep:
+        return RevisionStep(revision_map, script, False)
+
+    @property
+    def is_downgrade(self) -> bool:
+        return not self.is_upgrade
+
+    @property
+    def short_log(self) -> str:
+        return "%s %s -> %s" % (
+            self.name,
+            util.format_as_comma(self.from_revisions_no_deps),
+            util.format_as_comma(self.to_revisions_no_deps),
+        )
+
+    def __str__(self):
+        if self.doc:
+            return "%s %s -> %s, %s" % (
+                self.name,
+                util.format_as_comma(self.from_revisions_no_deps),
+                util.format_as_comma(self.to_revisions_no_deps),
+                self.doc,
+            )
+        else:
+            return self.short_log
+
+
+class RevisionStep(MigrationStep):
+    def __init__(
+        self, revision_map: RevisionMap, revision: Script, is_upgrade: bool
+    ) -> None:
+        self.revision_map = revision_map
+        self.revision = revision
+        self.is_upgrade = is_upgrade
+        if is_upgrade:
+            self.migration_fn = revision.module.upgrade
+        else:
+            self.migration_fn = revision.module.downgrade
+
+    def __repr__(self):
+        return "RevisionStep(%r, is_upgrade=%r)" % (
+            self.revision.revision,
+            self.is_upgrade,
+        )
+
+    def __eq__(self, other: object) -> bool:
+        return (
+            isinstance(other, RevisionStep)
+            and other.revision == self.revision
+            and self.is_upgrade == other.is_upgrade
+        )
+
+    @property
+    def doc(self) -> Optional[str]:
+        return self.revision.doc
+
+    @property
+    def from_revisions(self) -> Tuple[str, ...]:
+        if self.is_upgrade:
+            return self.revision._normalized_down_revisions
+        else:
+            return (self.revision.revision,)
+
+    @property
+    def from_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
+        if self.is_upgrade:
+            return self.revision._versioned_down_revisions
+        else:
+            return (self.revision.revision,)
+
+    @property
+    def to_revisions(self) -> Tuple[str, ...]:
+        if self.is_upgrade:
+            return (self.revision.revision,)
+        else:
+            return self.revision._normalized_down_revisions
+
+    @property
+    def to_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
+        if self.is_upgrade:
+            return (self.revision.revision,)
+        else:
+            return self.revision._versioned_down_revisions
+
+    @property
+    def _has_scalar_down_revision(self) -> bool:
+        return len(self.revision._normalized_down_revisions) == 1
+
+    def should_delete_branch(self, heads: Set[str]) -> bool:
+        """A delete is when we are a. in a downgrade and b.
+        we are going to the "base" or we are going to a version that
+        is implied as a dependency on another version that is remaining.
+
+        """
+        if not self.is_downgrade:
+            return False
+
+        if self.revision.revision not in heads:
+            return False
+
+        downrevs = self.revision._normalized_down_revisions
+
+        if not downrevs:
+            # is a base
+            return True
+        else:
+            # determine what the ultimate "to_revisions" for an
+            # unmerge would be.  If there are none, then we're a delete.
+            to_revisions = self._unmerge_to_revisions(heads)
+            return not to_revisions
+
+    def merge_branch_idents(
+        self, heads: Set[str]
+    ) -> Tuple[List[str], str, str]:
+        other_heads = set(heads).difference(self.from_revisions)
+
+        if other_heads:
+            ancestors = {
+                r.revision
+                for r in self.revision_map._get_ancestor_nodes(
+                    self.revision_map.get_revisions(other_heads), check=False
+                )
+            }
+            from_revisions = list(
+                set(self.from_revisions).difference(ancestors)
+            )
+        else:
+            from_revisions = list(self.from_revisions)
+
+        return (
+            # delete revs, update from rev, update to rev
+            list(from_revisions[0:-1]),
+            from_revisions[-1],
+            self.to_revisions[0],
+        )
+
+    def _unmerge_to_revisions(self, heads: Set[str]) -> Tuple[str, ...]:
+        other_heads = set(heads).difference([self.revision.revision])
+        if other_heads:
+            ancestors = {
+                r.revision
+                for r in self.revision_map._get_ancestor_nodes(
+                    self.revision_map.get_revisions(other_heads), check=False
+                )
+            }
+            return tuple(set(self.to_revisions).difference(ancestors))
+        else:
+            # for each revision we plan to return, compute its ancestors
+            # (excluding self), and remove those from the final output since
+            # they are already accounted for.
+            ancestors = {
+                r.revision
+                for to_revision in self.to_revisions
+                for r in self.revision_map._get_ancestor_nodes(
+                    self.revision_map.get_revisions(to_revision), check=False
+                )
+                if r.revision != to_revision
+            }
+            return tuple(set(self.to_revisions).difference(ancestors))
+
+    def unmerge_branch_idents(
+        self, heads: Set[str]
+    ) -> Tuple[str, str, Tuple[str, ...]]:
+        to_revisions = self._unmerge_to_revisions(heads)
+
+        return (
+            # update from rev, update to rev, insert revs
+            self.from_revisions[0],
+            to_revisions[-1],
+            to_revisions[0:-1],
+        )
+
+    def should_create_branch(self, heads: Set[str]) -> bool:
+        if not self.is_upgrade:
+            return False
+
+        downrevs = self.revision._normalized_down_revisions
+
+        if not downrevs:
+            # is a base
+            return True
+        else:
+            # none of our downrevs are present, so...
+            # we have to insert our version.   This is true whether
+            # or not there is only one downrev, or multiple (in the latter
+            # case, we're a merge point.)
+            if not heads.intersection(downrevs):
+                return True
+            else:
+                return False
+
+    def should_merge_branches(self, heads: Set[str]) -> bool:
+        if not self.is_upgrade:
+            return False
+
+        downrevs = self.revision._normalized_down_revisions
+
+        if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1:
+            return True
+
+        return False
+
+    def should_unmerge_branches(self, heads: Set[str]) -> bool:
+        if not self.is_downgrade:
+            return False
+
+        downrevs = self.revision._normalized_down_revisions
+
+        if self.revision.revision in heads and len(downrevs) > 1:
+            return True
+
+        return False
+
+    def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
+        if not self._has_scalar_down_revision:
+            downrev = heads.intersection(
+                self.revision._normalized_down_revisions
+            )
+            assert (
+                len(downrev) == 1
+            ), "Can't do an UPDATE because downrevision is ambiguous"
+            down_revision = list(downrev)[0]
+        else:
+            down_revision = self.revision._normalized_down_revisions[0]
+
+        if self.is_upgrade:
+            return down_revision, self.revision.revision
+        else:
+            return self.revision.revision, down_revision
+
+    @property
+    def delete_version_num(self) -> str:
+        return self.revision.revision
+
+    @property
+    def insert_version_num(self) -> str:
+        return self.revision.revision
+
+    @property
+    def info(self) -> MigrationInfo:
+        return MigrationInfo(
+            revision_map=self.revision_map,
+            up_revisions=self.revision.revision,
+            down_revisions=self.revision._normalized_down_revisions,
+            is_upgrade=self.is_upgrade,
+            is_stamp=False,
+        )
+
+
+class StampStep(MigrationStep):
+    def __init__(
+        self,
+        from_: Optional[Union[str, Collection[str]]],
+        to_: Optional[Union[str, Collection[str]]],
+        is_upgrade: bool,
+        branch_move: bool,
+        revision_map: Optional[RevisionMap] = None,
+    ) -> None:
+        self.from_: Tuple[str, ...] = util.to_tuple(from_, default=())
+        self.to_: Tuple[str, ...] = util.to_tuple(to_, default=())
+        self.is_upgrade = is_upgrade
+        self.branch_move = branch_move
+        self.migration_fn = self.stamp_revision
+        self.revision_map = revision_map
+
+    doc: Optional[str] = None
+
+    def stamp_revision(self, **kw: Any) -> None:
+        return None
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, StampStep)
+            and other.from_revisions == self.from_revisions
+            and other.to_revisions == self.to_revisions
+            and other.branch_move == self.branch_move
+            and self.is_upgrade == other.is_upgrade
+        )
+
+    @property
+    def from_revisions(self):
+        return self.from_
+
+    @property
+    def to_revisions(self) -> Tuple[str, ...]:
+        return self.to_
+
+    @property
+    def from_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
+        return self.from_
+
+    @property
+    def to_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
+        return self.to_
+
+    @property
+    def delete_version_num(self) -> str:
+        assert len(self.from_) == 1
+        return self.from_[0]
+
+    @property
+    def insert_version_num(self) -> str:
+        assert len(self.to_) == 1
+        return self.to_[0]
+
+    def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
+        assert len(self.from_) == 1
+        assert len(self.to_) == 1
+        return self.from_[0], self.to_[0]
+
+    def merge_branch_idents(
+        self, heads: Union[Set[str], List[str]]
+    ) -> Union[Tuple[List[Any], str, str], Tuple[List[str], str, str]]:
+        return (
+            # delete revs, update from rev, update to rev
+            list(self.from_[0:-1]),
+            self.from_[-1],
+            self.to_[0],
+        )
+
+    def unmerge_branch_idents(
+        self, heads: Set[str]
+    ) -> Tuple[str, str, List[str]]:
+        return (
+            # update from rev, update to rev, insert revs
+            self.from_[0],
+            self.to_[-1],
+            list(self.to_[0:-1]),
+        )
+
+    def should_delete_branch(self, heads: Set[str]) -> bool:
+        # TODO: we probably need to look for self.to_ inside of heads,
+        # in a similar manner as should_create_branch, however we have
+        # no tests for this yet (stamp downgrades w/ branches)
+        return self.is_downgrade and self.branch_move
+
+    def should_create_branch(self, heads: Set[str]) -> Union[Set[str], bool]:
+        return (
+            self.is_upgrade
+            and (self.branch_move or set(self.from_).difference(heads))
+            and set(self.to_).difference(heads)
+        )
+
+    def should_merge_branches(self, heads: Set[str]) -> bool:
+        return len(self.from_) > 1
+
+    def should_unmerge_branches(self, heads: Set[str]) -> bool:
+        return len(self.to_) > 1
+
+    @property
+    def info(self) -> MigrationInfo:
+        up, down = (
+            (self.to_, self.from_)
+            if self.is_upgrade
+            else (self.from_, self.to_)
+        )
+        assert self.revision_map is not None
+        return MigrationInfo(
+            revision_map=self.revision_map,
+            up_revisions=up,
+            down_revisions=down,
+            is_upgrade=self.is_upgrade,
+            is_stamp=True,
+        )