about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/alembic/operations/batch.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/alembic/operations/batch.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/alembic/operations/batch.py')
-rw-r--r--.venv/lib/python3.12/site-packages/alembic/operations/batch.py718
1 files changed, 718 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/alembic/operations/batch.py b/.venv/lib/python3.12/site-packages/alembic/operations/batch.py
new file mode 100644
index 00000000..fe183e9c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/alembic/operations/batch.py
@@ -0,0 +1,718 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import Index
+from sqlalchemy import MetaData
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy import schema as sql_schema
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy import types as sqltypes
+from sqlalchemy.sql.schema import SchemaEventTarget
+from sqlalchemy.util import OrderedDict
+from sqlalchemy.util import topological
+
+from ..util import exc
+from ..util.sqla_compat import _columns_for_constraint
+from ..util.sqla_compat import _copy
+from ..util.sqla_compat import _copy_expression
+from ..util.sqla_compat import _ensure_scope_for_ddl
+from ..util.sqla_compat import _fk_is_self_referential
+from ..util.sqla_compat import _idx_table_bound_expressions
+from ..util.sqla_compat import _is_type_bound
+from ..util.sqla_compat import _remove_column_from_collection
+from ..util.sqla_compat import _resolve_for_variant
+from ..util.sqla_compat import constraint_name_defined
+from ..util.sqla_compat import constraint_name_string
+
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.sql.elements import ColumnClause
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.functions import Function
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from ..ddl.impl import DefaultImpl
+
+
+class BatchOperationsImpl:
+    def __init__(
+        self,
+        operations,
+        table_name,
+        schema,
+        recreate,
+        copy_from,
+        table_args,
+        table_kwargs,
+        reflect_args,
+        reflect_kwargs,
+        naming_convention,
+        partial_reordering,
+    ):
+        self.operations = operations
+        self.table_name = table_name
+        self.schema = schema
+        if recreate not in ("auto", "always", "never"):
+            raise ValueError(
+                "recreate may be one of 'auto', 'always', or 'never'."
+            )
+        self.recreate = recreate
+        self.copy_from = copy_from
+        self.table_args = table_args
+        self.table_kwargs = dict(table_kwargs)
+        self.reflect_args = reflect_args
+        self.reflect_kwargs = dict(reflect_kwargs)
+        self.reflect_kwargs.setdefault(
+            "listeners", list(self.reflect_kwargs.get("listeners", ()))
+        )
+        self.reflect_kwargs["listeners"].append(
+            ("column_reflect", operations.impl.autogen_column_reflect)
+        )
+        self.naming_convention = naming_convention
+        self.partial_reordering = partial_reordering
+        self.batch = []
+
+    @property
+    def dialect(self) -> Dialect:
+        return self.operations.impl.dialect
+
+    @property
+    def impl(self) -> DefaultImpl:
+        return self.operations.impl
+
+    def _should_recreate(self) -> bool:
+        if self.recreate == "auto":
+            return self.operations.impl.requires_recreate_in_batch(self)
+        elif self.recreate == "always":
+            return True
+        else:
+            return False
+
+    def flush(self) -> None:
+        should_recreate = self._should_recreate()
+
+        with _ensure_scope_for_ddl(self.impl.connection):
+            if not should_recreate:
+                for opname, arg, kw in self.batch:
+                    fn = getattr(self.operations.impl, opname)
+                    fn(*arg, **kw)
+            else:
+                if self.naming_convention:
+                    m1 = MetaData(naming_convention=self.naming_convention)
+                else:
+                    m1 = MetaData()
+
+                if self.copy_from is not None:
+                    existing_table = self.copy_from
+                    reflected = False
+                else:
+                    if self.operations.migration_context.as_sql:
+                        raise exc.CommandError(
+                            f"This operation cannot proceed in --sql mode; "
+                            f"batch mode with dialect "
+                            f"{self.operations.migration_context.dialect.name} "  # noqa: E501
+                            f"requires a live database connection with which "
+                            f'to reflect the table "{self.table_name}". '
+                            f"To generate a batch SQL migration script using "
+                            "table "
+                            '"move and copy", a complete Table object '
+                            f'should be passed to the "copy_from" argument '
+                            "of the batch_alter_table() method so that table "
+                            "reflection can be skipped."
+                        )
+
+                    existing_table = Table(
+                        self.table_name,
+                        m1,
+                        schema=self.schema,
+                        autoload_with=self.operations.get_bind(),
+                        *self.reflect_args,
+                        **self.reflect_kwargs,
+                    )
+                    reflected = True
+
+                batch_impl = ApplyBatchImpl(
+                    self.impl,
+                    existing_table,
+                    self.table_args,
+                    self.table_kwargs,
+                    reflected,
+                    partial_reordering=self.partial_reordering,
+                )
+                for opname, arg, kw in self.batch:
+                    fn = getattr(batch_impl, opname)
+                    fn(*arg, **kw)
+
+                batch_impl._create(self.impl)
+
+    def alter_column(self, *arg, **kw) -> None:
+        self.batch.append(("alter_column", arg, kw))
+
+    def add_column(self, *arg, **kw) -> None:
+        if (
+            "insert_before" in kw or "insert_after" in kw
+        ) and not self._should_recreate():
+            raise exc.CommandError(
+                "Can't specify insert_before or insert_after when using "
+                "ALTER; please specify recreate='always'"
+            )
+        self.batch.append(("add_column", arg, kw))
+
+    def drop_column(self, *arg, **kw) -> None:
+        self.batch.append(("drop_column", arg, kw))
+
+    def add_constraint(self, const: Constraint) -> None:
+        self.batch.append(("add_constraint", (const,), {}))
+
+    def drop_constraint(self, const: Constraint) -> None:
+        self.batch.append(("drop_constraint", (const,), {}))
+
+    def rename_table(self, *arg, **kw):
+        self.batch.append(("rename_table", arg, kw))
+
+    def create_index(self, idx: Index, **kw: Any) -> None:
+        self.batch.append(("create_index", (idx,), kw))
+
+    def drop_index(self, idx: Index, **kw: Any) -> None:
+        self.batch.append(("drop_index", (idx,), kw))
+
+    def create_table_comment(self, table):
+        self.batch.append(("create_table_comment", (table,), {}))
+
+    def drop_table_comment(self, table):
+        self.batch.append(("drop_table_comment", (table,), {}))
+
+    def create_table(self, table):
+        raise NotImplementedError("Can't create table in batch mode")
+
+    def drop_table(self, table):
+        raise NotImplementedError("Can't drop table in batch mode")
+
+    def create_column_comment(self, column):
+        self.batch.append(("create_column_comment", (column,), {}))
+
+
+class ApplyBatchImpl:
+    def __init__(
+        self,
+        impl: DefaultImpl,
+        table: Table,
+        table_args: tuple,
+        table_kwargs: Dict[str, Any],
+        reflected: bool,
+        partial_reordering: tuple = (),
+    ) -> None:
+        self.impl = impl
+        self.table = table  # this is a Table object
+        self.table_args = table_args
+        self.table_kwargs = table_kwargs
+        self.temp_table_name = self._calc_temp_name(table.name)
+        self.new_table: Optional[Table] = None
+
+        self.partial_reordering = partial_reordering  # tuple of tuples
+        self.add_col_ordering: Tuple[
+            Tuple[str, str], ...
+        ] = ()  # tuple of tuples
+
+        self.column_transfers = OrderedDict(
+            (c.name, {"expr": c}) for c in self.table.c
+        )
+        self.existing_ordering = list(self.column_transfers)
+
+        self.reflected = reflected
+        self._grab_table_elements()
+
+    @classmethod
+    def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str:
+        return ("_alembic_tmp_%s" % tablename)[0:50]
+
+    def _grab_table_elements(self) -> None:
+        schema = self.table.schema
+        self.columns: Dict[str, Column[Any]] = OrderedDict()
+        for c in self.table.c:
+            c_copy = _copy(c, schema=schema)
+            c_copy.unique = c_copy.index = False
+            # ensure that the type object was copied,
+            # as we may need to modify it in-place
+            if isinstance(c.type, SchemaEventTarget):
+                assert c_copy.type is not c.type
+            self.columns[c.name] = c_copy
+        self.named_constraints: Dict[str, Constraint] = {}
+        self.unnamed_constraints = []
+        self.col_named_constraints = {}
+        self.indexes: Dict[str, Index] = {}
+        self.new_indexes: Dict[str, Index] = {}
+
+        for const in self.table.constraints:
+            if _is_type_bound(const):
+                continue
+            elif (
+                self.reflected
+                and isinstance(const, CheckConstraint)
+                and not const.name
+            ):
+                # TODO: we are skipping unnamed reflected CheckConstraint
+                # because
+                # we have no way to determine _is_type_bound() for these.
+                pass
+            elif constraint_name_string(const.name):
+                self.named_constraints[const.name] = const
+            else:
+                self.unnamed_constraints.append(const)
+
+        if not self.reflected:
+            for col in self.table.c:
+                for const in col.constraints:
+                    if const.name:
+                        self.col_named_constraints[const.name] = (col, const)
+
+        for idx in self.table.indexes:
+            self.indexes[idx.name] = idx  # type: ignore[index]
+
+        for k in self.table.kwargs:
+            self.table_kwargs.setdefault(k, self.table.kwargs[k])
+
+    def _adjust_self_columns_for_partial_reordering(self) -> None:
+        pairs = set()
+
+        col_by_idx = list(self.columns)
+
+        if self.partial_reordering:
+            for tuple_ in self.partial_reordering:
+                for index, elem in enumerate(tuple_):
+                    if index > 0:
+                        pairs.add((tuple_[index - 1], elem))
+        else:
+            for index, elem in enumerate(self.existing_ordering):
+                if index > 0:
+                    pairs.add((col_by_idx[index - 1], elem))
+
+        pairs.update(self.add_col_ordering)
+
+        # this can happen if some columns were dropped and not removed
+        # from existing_ordering.  this should be prevented already, but
+        # conservatively making sure this didn't happen
+        pairs_list = [p for p in pairs if p[0] != p[1]]
+
+        sorted_ = list(
+            topological.sort(pairs_list, col_by_idx, deterministic_order=True)
+        )
+        self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
+        self.column_transfers = OrderedDict(
+            (k, self.column_transfers[k]) for k in sorted_
+        )
+
+    def _transfer_elements_to_new_table(self) -> None:
+        assert self.new_table is None, "Can only create new table once"
+
+        m = MetaData()
+        schema = self.table.schema
+
+        if self.partial_reordering or self.add_col_ordering:
+            self._adjust_self_columns_for_partial_reordering()
+
+        self.new_table = new_table = Table(
+            self.temp_table_name,
+            m,
+            *(list(self.columns.values()) + list(self.table_args)),
+            schema=schema,
+            **self.table_kwargs,
+        )
+
+        for const in (
+            list(self.named_constraints.values()) + self.unnamed_constraints
+        ):
+            const_columns = {c.key for c in _columns_for_constraint(const)}
+
+            if not const_columns.issubset(self.column_transfers):
+                continue
+
+            const_copy: Constraint
+            if isinstance(const, ForeignKeyConstraint):
+                if _fk_is_self_referential(const):
+                    # for self-referential constraint, refer to the
+                    # *original* table name, and not _alembic_batch_temp.
+                    # This is consistent with how we're handling
+                    # FK constraints from other tables; we assume SQLite
+                    # no foreign keys just keeps the names unchanged, so
+                    # when we rename back, they match again.
+                    const_copy = _copy(
+                        const, schema=schema, target_table=self.table
+                    )
+                else:
+                    # "target_table" for ForeignKeyConstraint.copy() is
+                    # only used if the FK is detected as being
+                    # self-referential, which we are handling above.
+                    const_copy = _copy(const, schema=schema)
+            else:
+                const_copy = _copy(
+                    const, schema=schema, target_table=new_table
+                )
+            if isinstance(const, ForeignKeyConstraint):
+                self._setup_referent(m, const)
+            new_table.append_constraint(const_copy)
+
+    def _gather_indexes_from_both_tables(self) -> List[Index]:
+        assert self.new_table is not None
+        idx: List[Index] = []
+
+        for idx_existing in self.indexes.values():
+            # this is a lift-and-move from Table.to_metadata
+
+            if idx_existing._column_flag:
+                continue
+
+            idx_copy = Index(
+                idx_existing.name,
+                unique=idx_existing.unique,
+                *[
+                    _copy_expression(expr, self.new_table)
+                    for expr in _idx_table_bound_expressions(idx_existing)
+                ],
+                _table=self.new_table,
+                **idx_existing.kwargs,
+            )
+            idx.append(idx_copy)
+
+        for index in self.new_indexes.values():
+            idx.append(
+                Index(
+                    index.name,
+                    unique=index.unique,
+                    *[self.new_table.c[col] for col in index.columns.keys()],
+                    **index.kwargs,
+                )
+            )
+        return idx
+
+    def _setup_referent(
+        self, metadata: MetaData, constraint: ForeignKeyConstraint
+    ) -> None:
+        spec = constraint.elements[0]._get_colspec()
+        parts = spec.split(".")
+        tname = parts[-2]
+        if len(parts) == 3:
+            referent_schema = parts[0]
+        else:
+            referent_schema = None
+
+        if tname != self.temp_table_name:
+            key = sql_schema._get_table_key(tname, referent_schema)
+
+            def colspec(elem: Any):
+                return elem._get_colspec()
+
+            if key in metadata.tables:
+                t = metadata.tables[key]
+                for elem in constraint.elements:
+                    colname = colspec(elem).split(".")[-1]
+                    if colname not in t.c:
+                        t.append_column(Column(colname, sqltypes.NULLTYPE))
+            else:
+                Table(
+                    tname,
+                    metadata,
+                    *[
+                        Column(n, sqltypes.NULLTYPE)
+                        for n in [
+                            colspec(elem).split(".")[-1]
+                            for elem in constraint.elements
+                        ]
+                    ],
+                    schema=referent_schema,
+                )
+
+    def _create(self, op_impl: DefaultImpl) -> None:
+        self._transfer_elements_to_new_table()
+
+        op_impl.prep_table_for_batch(self, self.table)
+        assert self.new_table is not None
+        op_impl.create_table(self.new_table)
+
+        try:
+            op_impl._exec(
+                self.new_table.insert()
+                .inline()
+                .from_select(
+                    list(
+                        k
+                        for k, transfer in self.column_transfers.items()
+                        if "expr" in transfer
+                    ),
+                    select(
+                        *[
+                            transfer["expr"]
+                            for transfer in self.column_transfers.values()
+                            if "expr" in transfer
+                        ]
+                    ),
+                )
+            )
+            op_impl.drop_table(self.table)
+        except:
+            op_impl.drop_table(self.new_table)
+            raise
+        else:
+            op_impl.rename_table(
+                self.temp_table_name, self.table.name, schema=self.table.schema
+            )
+            self.new_table.name = self.table.name
+            try:
+                for idx in self._gather_indexes_from_both_tables():
+                    op_impl.create_index(idx)
+            finally:
+                self.new_table.name = self.temp_table_name
+
+    def alter_column(
+        self,
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        server_default: Optional[Union[Function[Any], str, bool]] = False,
+        name: Optional[str] = None,
+        type_: Optional[TypeEngine] = None,
+        autoincrement: Optional[Union[bool, Literal["auto"]]] = None,
+        comment: Union[str, Literal[False]] = False,
+        **kw,
+    ) -> None:
+        existing = self.columns[column_name]
+        existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
+        if name is not None and name != column_name:
+            # note that we don't change '.key' - we keep referring
+            # to the renamed column by its old key in _create().  neat!
+            existing.name = name
+            existing_transfer["name"] = name
+
+            existing_type = kw.get("existing_type", None)
+            if existing_type:
+                resolved_existing_type = _resolve_for_variant(
+                    kw["existing_type"], self.impl.dialect
+                )
+
+                # pop named constraints for Boolean/Enum for rename
+                if (
+                    isinstance(resolved_existing_type, SchemaEventTarget)
+                    and resolved_existing_type.name  # type:ignore[attr-defined]  # noqa E501
+                ):
+                    self.named_constraints.pop(
+                        resolved_existing_type.name,  # type:ignore[attr-defined]  # noqa E501
+                        None,
+                    )
+
+        if type_ is not None:
+            type_ = sqltypes.to_instance(type_)
+            # old type is being discarded so turn off eventing
+            # rules. Alternatively we can
+            # erase the events set up by this type, but this is simpler.
+            # we also ignore the drop_constraint that will come here from
+            # Operations.implementation_for(alter_column)
+
+            if isinstance(existing.type, SchemaEventTarget):
+                existing.type._create_events = (  # type:ignore[attr-defined]
+                    existing.type.create_constraint  # type:ignore[attr-defined] # noqa
+                ) = False
+
+            self.impl.cast_for_batch_migrate(
+                existing, existing_transfer, type_
+            )
+
+            existing.type = type_
+
+            # we *dont* however set events for the new type, because
+            # alter_column is invoked from
+            # Operations.implementation_for(alter_column) which already
+            # will emit an add_constraint()
+
+        if nullable is not None:
+            existing.nullable = nullable
+        if server_default is not False:
+            if server_default is None:
+                existing.server_default = None
+            else:
+                sql_schema.DefaultClause(
+                    server_default  # type: ignore[arg-type]
+                )._set_parent(existing)
+        if autoincrement is not None:
+            existing.autoincrement = bool(autoincrement)
+
+        if comment is not False:
+            existing.comment = comment
+
+    def _setup_dependencies_for_add_column(
+        self,
+        colname: str,
+        insert_before: Optional[str],
+        insert_after: Optional[str],
+    ) -> None:
+        index_cols = self.existing_ordering
+        col_indexes = {name: i for i, name in enumerate(index_cols)}
+
+        if not self.partial_reordering:
+            if insert_after:
+                if not insert_before:
+                    if insert_after in col_indexes:
+                        # insert after an existing column
+                        idx = col_indexes[insert_after] + 1
+                        if idx < len(index_cols):
+                            insert_before = index_cols[idx]
+                    else:
+                        # insert after a column that is also new
+                        insert_before = dict(self.add_col_ordering)[
+                            insert_after
+                        ]
+            if insert_before:
+                if not insert_after:
+                    if insert_before in col_indexes:
+                        # insert before an existing column
+                        idx = col_indexes[insert_before] - 1
+                        if idx >= 0:
+                            insert_after = index_cols[idx]
+                    else:
+                        # insert before a column that is also new
+                        insert_after = {
+                            b: a for a, b in self.add_col_ordering
+                        }[insert_before]
+
+        if insert_before:
+            self.add_col_ordering += ((colname, insert_before),)
+        if insert_after:
+            self.add_col_ordering += ((insert_after, colname),)
+
+        if (
+            not self.partial_reordering
+            and not insert_before
+            and not insert_after
+            and col_indexes
+        ):
+            self.add_col_ordering += ((index_cols[-1], colname),)
+
+    def add_column(
+        self,
+        table_name: str,
+        column: Column[Any],
+        insert_before: Optional[str] = None,
+        insert_after: Optional[str] = None,
+        **kw,
+    ) -> None:
+        self._setup_dependencies_for_add_column(
+            column.name, insert_before, insert_after
+        )
+        # we copy the column because operations.add_column()
+        # gives us a Column that is part of a Table already.
+        self.columns[column.name] = _copy(column, schema=self.table.schema)
+        self.column_transfers[column.name] = {}
+
+    def drop_column(
+        self,
+        table_name: str,
+        column: Union[ColumnClause[Any], Column[Any]],
+        **kw,
+    ) -> None:
+        if column.name in self.table.primary_key.columns:
+            _remove_column_from_collection(
+                self.table.primary_key.columns, column
+            )
+        del self.columns[column.name]
+        del self.column_transfers[column.name]
+        self.existing_ordering.remove(column.name)
+
+        # pop named constraints for Boolean/Enum for rename
+        if (
+            "existing_type" in kw
+            and isinstance(kw["existing_type"], SchemaEventTarget)
+            and kw["existing_type"].name  # type:ignore[attr-defined]
+        ):
+            self.named_constraints.pop(
+                kw["existing_type"].name, None  # type:ignore[attr-defined]
+            )
+
+    def create_column_comment(self, column):
+        """the batch table creation function will issue create_column_comment
+        on the real "impl" as part of the create table process.
+
+        That is, the Column object will have the comment on it already,
+        so when it is received by add_column() it will be a normal part of
+        the CREATE TABLE and doesn't need an extra step here.
+
+        """
+
+    def create_table_comment(self, table):
+        """the batch table creation function will issue create_table_comment
+        on the real "impl" as part of the create table process.
+
+        """
+
+    def drop_table_comment(self, table):
+        """the batch table creation function will issue drop_table_comment
+        on the real "impl" as part of the create table process.
+
+        """
+
+    def add_constraint(self, const: Constraint) -> None:
+        if not constraint_name_defined(const.name):
+            raise ValueError("Constraint must have a name")
+        if isinstance(const, sql_schema.PrimaryKeyConstraint):
+            if self.table.primary_key in self.unnamed_constraints:
+                self.unnamed_constraints.remove(self.table.primary_key)
+
+        if constraint_name_string(const.name):
+            self.named_constraints[const.name] = const
+        else:
+            self.unnamed_constraints.append(const)
+
+    def drop_constraint(self, const: Constraint) -> None:
+        if not const.name:
+            raise ValueError("Constraint must have a name")
+        try:
+            if const.name in self.col_named_constraints:
+                col, const = self.col_named_constraints.pop(const.name)
+
+                for col_const in list(self.columns[col.name].constraints):
+                    if col_const.name == const.name:
+                        self.columns[col.name].constraints.remove(col_const)
+            elif constraint_name_string(const.name):
+                const = self.named_constraints.pop(const.name)
+            elif const in self.unnamed_constraints:
+                self.unnamed_constraints.remove(const)
+
+        except KeyError:
+            if _is_type_bound(const):
+                # type-bound constraints are only included in the new
+                # table via their type object in any case, so ignore the
+                # drop_constraint() that comes here via the
+                # Operations.implementation_for(alter_column)
+                return
+            raise ValueError("No such constraint: '%s'" % const.name)
+        else:
+            if isinstance(const, PrimaryKeyConstraint):
+                for col in const.columns:
+                    self.columns[col.name].primary_key = False
+
+    def create_index(self, idx: Index) -> None:
+        self.new_indexes[idx.name] = idx  # type: ignore[index]
+
+    def drop_index(self, idx: Index) -> None:
+        try:
+            del self.indexes[idx.name]  # type: ignore[arg-type]
+        except KeyError:
+            raise ValueError("No such index: '%s'" % idx.name)
+
+    def rename_table(self, *arg, **kw):
+        raise NotImplementedError("TODO")