diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/alembic/ddl')
9 files changed, 3755 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/__init__.py b/.venv/lib/python3.12/site-packages/alembic/ddl/__init__.py new file mode 100644 index 00000000..f2f72b3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/__init__.py @@ -0,0 +1,6 @@ +from . import mssql +from . import mysql +from . import oracle +from . import postgresql +from . import sqlite +from .impl import DefaultImpl as DefaultImpl diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py b/.venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py new file mode 100644 index 00000000..74715b18 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py @@ -0,0 +1,329 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy.sql.schema import Constraint +from sqlalchemy.sql.schema import ForeignKeyConstraint +from sqlalchemy.sql.schema import Index +from sqlalchemy.sql.schema import UniqueConstraint +from typing_extensions import TypeGuard + +from .. import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + + from alembic.autogenerate.api import AutogenContext + from alembic.ddl.impl import DefaultImpl + +CompareConstraintType = Union[Constraint, Index] + +_C = TypeVar("_C", bound=CompareConstraintType) + +_clsreg: Dict[str, Type[_constraint_sig]] = {} + + +class ComparisonResult(NamedTuple): + status: Literal["equal", "different", "skip"] + message: str + + @property + def is_equal(self) -> bool: + return self.status == "equal" + + @property + def is_different(self) -> bool: + return self.status == "different" + + @property + def is_skip(self) -> bool: + return self.status == "skip" + + @classmethod + def Equal(cls) -> ComparisonResult: + """the constraints are equal.""" + return cls("equal", "The two constraints are equal") + + @classmethod + def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult: + """the constraints are different for the provided reason(s).""" + return cls("different", ", ".join(util.to_list(reason))) + + @classmethod + def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult: + """the constraint cannot be compared for the provided reason(s). + + The message is logged, but the constraints will be otherwise + considered equal, meaning that no migration command will be + generated. + """ + return cls("skip", ", ".join(util.to_list(reason))) + + +class _constraint_sig(Generic[_C]): + const: _C + + _sig: Tuple[Any, ...] + name: Optional[sqla_compat._ConstraintNameDefined] + + impl: DefaultImpl + + _is_index: ClassVar[bool] = False + _is_fk: ClassVar[bool] = False + _is_uq: ClassVar[bool] = False + + _is_metadata: bool + + def __init_subclass__(cls) -> None: + cls._register() + + @classmethod + def _register(cls): + raise NotImplementedError() + + def __init__( + self, is_metadata: bool, impl: DefaultImpl, const: _C + ) -> None: + raise NotImplementedError() + + def compare_to_reflected( + self, other: _constraint_sig[Any] + ) -> ComparisonResult: + assert self.impl is other.impl + assert self._is_metadata + assert not other._is_metadata + + return self._compare_to_reflected(other) + + def _compare_to_reflected( + self, other: _constraint_sig[_C] + ) -> ComparisonResult: + raise NotImplementedError() + + @classmethod + def from_constraint( + cls, is_metadata: bool, impl: DefaultImpl, constraint: _C + ) -> _constraint_sig[_C]: + # these could be cached by constraint/impl, however, if the + # constraint is modified in place, then the sig is wrong. the mysql + # impl currently does this, and if we fixed that we can't be sure + # someone else might do it too, so play it safe. + sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint) + return sig + + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: + return sqla_compat._get_constraint_final_name( + self.const, context.dialect + ) + + @util.memoized_property + def is_named(self): + return sqla_compat._constraint_is_named(self.const, self.impl.dialect) + + @util.memoized_property + def unnamed(self) -> Tuple[Any, ...]: + return self._sig + + @util.memoized_property + def unnamed_no_options(self) -> Tuple[Any, ...]: + raise NotImplementedError() + + @util.memoized_property + def _full_sig(self) -> Tuple[Any, ...]: + return (self.name,) + self.unnamed + + def __eq__(self, other) -> bool: + return self._full_sig == other._full_sig + + def __ne__(self, other) -> bool: + return self._full_sig != other._full_sig + + def __hash__(self) -> int: + return hash(self._full_sig) + + +class _uq_constraint_sig(_constraint_sig[UniqueConstraint]): + _is_uq = True + + @classmethod + def _register(cls) -> None: + _clsreg["unique_constraint"] = cls + + is_unique = True + + def __init__( + self, + is_metadata: bool, + impl: DefaultImpl, + const: UniqueConstraint, + ) -> None: + self.impl = impl + self.const = const + self.name = sqla_compat.constraint_name_or_none(const.name) + self._sig = tuple(sorted([col.name for col in const.columns])) + self._is_metadata = is_metadata + + @property + def column_names(self) -> Tuple[str, ...]: + return tuple([col.name for col in self.const.columns]) + + def _compare_to_reflected( + self, other: _constraint_sig[_C] + ) -> ComparisonResult: + assert self._is_metadata + metadata_obj = self + conn_obj = other + + assert is_uq_sig(conn_obj) + return self.impl.compare_unique_constraint( + metadata_obj.const, conn_obj.const + ) + + +class _ix_constraint_sig(_constraint_sig[Index]): + _is_index = True + + name: sqla_compat._ConstraintName + + @classmethod + def _register(cls) -> None: + _clsreg["index"] = cls + + def __init__( + self, is_metadata: bool, impl: DefaultImpl, const: Index + ) -> None: + self.impl = impl + self.const = const + self.name = const.name + self.is_unique = bool(const.unique) + self._is_metadata = is_metadata + + def _compare_to_reflected( + self, other: _constraint_sig[_C] + ) -> ComparisonResult: + assert self._is_metadata + metadata_obj = self + conn_obj = other + + assert is_index_sig(conn_obj) + return self.impl.compare_indexes(metadata_obj.const, conn_obj.const) + + @util.memoized_property + def has_expressions(self): + return sqla_compat.is_expression_index(self.const) + + @util.memoized_property + def column_names(self) -> Tuple[str, ...]: + return tuple([col.name for col in self.const.columns]) + + @util.memoized_property + def column_names_optional(self) -> Tuple[Optional[str], ...]: + return tuple( + [getattr(col, "name", None) for col in self.const.expressions] + ) + + @util.memoized_property + def is_named(self): + return True + + @util.memoized_property + def unnamed(self): + return (self.is_unique,) + self.column_names_optional + + +class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]): + _is_fk = True + + @classmethod + def _register(cls) -> None: + _clsreg["foreign_key_constraint"] = cls + + def __init__( + self, + is_metadata: bool, + impl: DefaultImpl, + const: ForeignKeyConstraint, + ) -> None: + self._is_metadata = is_metadata + + self.impl = impl + self.const = const + + self.name = sqla_compat.constraint_name_or_none(const.name) + + ( + self.source_schema, + self.source_table, + self.source_columns, + self.target_schema, + self.target_table, + self.target_columns, + onupdate, + ondelete, + deferrable, + initially, + ) = sqla_compat._fk_spec(const) + + self._sig: Tuple[Any, ...] = ( + self.source_schema, + self.source_table, + tuple(self.source_columns), + self.target_schema, + self.target_table, + tuple(self.target_columns), + ) + ( + ( + (None if onupdate.lower() == "no action" else onupdate.lower()) + if onupdate + else None + ), + ( + (None if ondelete.lower() == "no action" else ondelete.lower()) + if ondelete + else None + ), + # convert initially + deferrable into one three-state value + ( + "initially_deferrable" + if initially and initially.lower() == "deferred" + else "deferrable" if deferrable else "not deferrable" + ), + ) + + @util.memoized_property + def unnamed_no_options(self): + return ( + self.source_schema, + self.source_table, + tuple(self.source_columns), + self.target_schema, + self.target_table, + tuple(self.target_columns), + ) + + +def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]: + return sig._is_index + + +def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]: + return sig._is_uq + + +def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]: + return sig._is_fk diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/base.py b/.venv/lib/python3.12/site-packages/alembic/ddl/base.py new file mode 100644 index 00000000..bd55c56d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/base.py @@ -0,0 +1,336 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import functools +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import exc +from sqlalchemy import Integer +from sqlalchemy import types as sqltypes +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import Column +from sqlalchemy.schema import DDLElement +from sqlalchemy.sql.elements import quoted_name + +from ..util.sqla_compat import _columns_for_constraint # noqa +from ..util.sqla_compat import _find_columns # noqa +from ..util.sqla_compat import _fk_spec # noqa +from ..util.sqla_compat import _is_type_bound # noqa +from ..util.sqla_compat import _table_for_constraint # noqa + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy import Computed + from sqlalchemy import Identity + from sqlalchemy.sql.compiler import Compiled + from sqlalchemy.sql.compiler import DDLCompiler + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.functions import Function + from sqlalchemy.sql.schema import FetchedValue + from sqlalchemy.sql.type_api import TypeEngine + + from .impl import DefaultImpl + +_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str] + + +class AlterTable(DDLElement): + """Represent an ALTER TABLE statement. + + Only the string name and optional schema name of the table + is required, not a full Table object. + + """ + + def __init__( + self, + table_name: str, + schema: Optional[Union[quoted_name, str]] = None, + ) -> None: + self.table_name = table_name + self.schema = schema + + +class RenameTable(AlterTable): + def __init__( + self, + old_table_name: str, + new_table_name: Union[quoted_name, str], + schema: Optional[Union[quoted_name, str]] = None, + ) -> None: + super().__init__(old_table_name, schema=schema) + self.new_table_name = new_table_name + + +class AlterColumn(AlterTable): + def __init__( + self, + name: str, + column_name: str, + schema: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_nullable: Optional[bool] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_comment: Optional[str] = None, + ) -> None: + super().__init__(name, schema=schema) + self.column_name = column_name + self.existing_type = ( + sqltypes.to_instance(existing_type) + if existing_type is not None + else None + ) + self.existing_nullable = existing_nullable + self.existing_server_default = existing_server_default + self.existing_comment = existing_comment + + +class ColumnNullable(AlterColumn): + def __init__( + self, name: str, column_name: str, nullable: bool, **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.nullable = nullable + + +class ColumnType(AlterColumn): + def __init__( + self, name: str, column_name: str, type_: TypeEngine, **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.type_ = sqltypes.to_instance(type_) + + +class ColumnName(AlterColumn): + def __init__( + self, name: str, column_name: str, newname: str, **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.newname = newname + + +class ColumnDefault(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + default: Optional[_ServerDefault], + **kw, + ) -> None: + super().__init__(name, column_name, **kw) + self.default = default + + +class ComputedColumnDefault(AlterColumn): + def __init__( + self, name: str, column_name: str, default: Optional[Computed], **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.default = default + + +class IdentityColumnDefault(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + default: Optional[Identity], + impl: DefaultImpl, + **kw, + ) -> None: + super().__init__(name, column_name, **kw) + self.default = default + self.impl = impl + + +class AddColumn(AlterTable): + def __init__( + self, + name: str, + column: Column[Any], + schema: Optional[Union[quoted_name, str]] = None, + ) -> None: + super().__init__(name, schema=schema) + self.column = column + + +class DropColumn(AlterTable): + def __init__( + self, name: str, column: Column[Any], schema: Optional[str] = None + ) -> None: + super().__init__(name, schema=schema) + self.column = column + + +class ColumnComment(AlterColumn): + def __init__( + self, name: str, column_name: str, comment: Optional[str], **kw + ) -> None: + super().__init__(name, column_name, **kw) + self.comment = comment + + +@compiles(RenameTable) +def visit_rename_table( + element: RenameTable, compiler: DDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, element.schema), + ) + + +@compiles(AddColumn) +def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + add_column(compiler, element.column, **kw), + ) + + +@compiles(DropColumn) +def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + drop_column(compiler, element.column.name, **kw), + ) + + +@compiles(ColumnNullable) +def visit_column_nullable( + element: ColumnNullable, compiler: DDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "DROP NOT NULL" if element.nullable else "SET NOT NULL", + ) + + +@compiles(ColumnType) +def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "TYPE %s" % format_type(compiler, element.type_), + ) + + +@compiles(ColumnName) +def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: + return "%s RENAME %s TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +@compiles(ColumnDefault) +def visit_column_default( + element: ColumnDefault, compiler: DDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + ( + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT" + ), + ) + + +@compiles(ComputedColumnDefault) +def visit_computed_column( + element: ComputedColumnDefault, compiler: DDLCompiler, **kw +): + raise exc.CompileError( + 'Adding or removing a "computed" construct, e.g. GENERATED ' + "ALWAYS AS, to or from an existing column is not supported." + ) + + +@compiles(IdentityColumnDefault) +def visit_identity_column( + element: IdentityColumnDefault, compiler: DDLCompiler, **kw +): + raise exc.CompileError( + 'Adding, removing or modifying an "identity" construct, ' + "e.g. GENERATED AS IDENTITY, to or from an existing " + "column is not supported in this dialect." + ) + + +def quote_dotted( + name: Union[quoted_name, str], quote: functools.partial +) -> Union[quoted_name, str]: + """quote the elements of a dotted name""" + + if isinstance(name, quoted_name): + return quote(name) + result = ".".join([quote(x) for x in name.split(".")]) + return result + + +def format_table_name( + compiler: Compiled, + name: Union[quoted_name, str], + schema: Optional[Union[quoted_name, str]], +) -> Union[quoted_name, str]: + quote = functools.partial(compiler.preparer.quote) + if schema: + return quote_dotted(schema, quote) + "." + quote(name) + else: + return quote(name) + + +def format_column_name( + compiler: DDLCompiler, name: Optional[Union[quoted_name, str]] +) -> Union[quoted_name, str]: + return compiler.preparer.quote(name) # type: ignore[arg-type] + + +def format_server_default( + compiler: DDLCompiler, + default: Optional[_ServerDefault], +) -> str: + return compiler.get_column_default_string( + Column("x", Integer, server_default=default) + ) + + +def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str: + return compiler.dialect.type_compiler.process(type_) + + +def alter_table( + compiler: DDLCompiler, + name: str, + schema: Optional[str], +) -> str: + return "ALTER TABLE %s" % format_table_name(compiler, name, schema) + + +def drop_column(compiler: DDLCompiler, name: str, **kw) -> str: + return "DROP COLUMN %s" % format_column_name(compiler, name) + + +def alter_column(compiler: DDLCompiler, name: str) -> str: + return "ALTER COLUMN %s" % format_column_name(compiler, name) + + +def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str: + text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) + + const = " ".join( + compiler.process(constraint) for constraint in column.constraints + ) + if const: + text += " " + const + + return text diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/impl.py b/.venv/lib/python3.12/site-packages/alembic/ddl/impl.py new file mode 100644 index 00000000..c116fcfa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/impl.py @@ -0,0 +1,885 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import logging +import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import cast +from sqlalchemy import Column +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import schema +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import text + +from . import _autogen +from . import base +from ._autogen import _constraint_sig as _constraint_sig +from ._autogen import ComparisonResult as ComparisonResult +from .. import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from typing import Literal + from typing import TextIO + + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql import ClauseElement + from sqlalchemy.sql import Executable + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.selectable import TableClause + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + from ..autogenerate.api import AutogenContext + from ..operations.batch import ApplyBatchImpl + from ..operations.batch import BatchOperationsImpl + +log = logging.getLogger(__name__) + + +class ImplMeta(type): + def __init__( + cls, + classname: str, + bases: Tuple[Type[DefaultImpl]], + dict_: Dict[str, Any], + ): + newtype = type.__init__(cls, classname, bases, dict_) + if "__dialect__" in dict_: + _impls[dict_["__dialect__"]] = cls # type: ignore[assignment] + return newtype + + +_impls: Dict[str, Type[DefaultImpl]] = {} + + +class DefaultImpl(metaclass=ImplMeta): + """Provide the entrypoint for major migration operations, + including database-specific behavioral variances. + + While individual SQL/DDL constructs already provide + for database-specific implementations, variances here + allow for entirely different sequences of operations + to take place for a particular migration, such as + SQL Server's special 'IDENTITY INSERT' step for + bulk inserts. + + """ + + __dialect__ = "default" + + transactional_ddl = False + command_terminator = ";" + type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},) + type_arg_extract: Sequence[str] = () + # These attributes are deprecated in SQLAlchemy via #10247. They need to + # be ignored to support older version that did not use dialect kwargs. + # They only apply to Oracle and are replaced by oracle_order, + # oracle_on_null + identity_attrs_ignore: Tuple[str, ...] = ("order", "on_null") + + def __init__( + self, + dialect: Dialect, + connection: Optional[Connection], + as_sql: bool, + transactional_ddl: Optional[bool], + output_buffer: Optional[TextIO], + context_opts: Dict[str, Any], + ) -> None: + self.dialect = dialect + self.connection = connection + self.as_sql = as_sql + self.literal_binds = context_opts.get("literal_binds", False) + + self.output_buffer = output_buffer + self.memo: dict = {} + self.context_opts = context_opts + if transactional_ddl is not None: + self.transactional_ddl = transactional_ddl + + if self.literal_binds: + if not self.as_sql: + raise util.CommandError( + "Can't use literal_binds setting without as_sql mode" + ) + + @classmethod + def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]: + return _impls[dialect.name] + + def static_output(self, text: str) -> None: + assert self.output_buffer is not None + self.output_buffer.write(text + "\n\n") + self.output_buffer.flush() + + def version_table_impl( + self, + *, + version_table: str, + version_table_schema: Optional[str], + version_table_pk: bool, + **kw: Any, + ) -> Table: + """Generate a :class:`.Table` object which will be used as the + structure for the Alembic version table. + + Third party dialects may override this hook to provide an alternate + structure for this :class:`.Table`; requirements are only that it + be named based on the ``version_table`` parameter and contains + at least a single string-holding column named ``version_num``. + + .. versionadded:: 1.14 + + """ + vt = Table( + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint( + "version_num", name=f"{version_table}_pkc" + ) + ) + + return vt + + def requires_recreate_in_batch( + self, batch_op: BatchOperationsImpl + ) -> bool: + """Return True if the given :class:`.BatchOperationsImpl` + would need the table to be recreated and copied in order to + proceed. + + Normally, only returns True on SQLite when operations other + than add_column are present. + + """ + return False + + def prep_table_for_batch( + self, batch_impl: ApplyBatchImpl, table: Table + ) -> None: + """perform any operations needed on a table before a new + one is created to replace it in batch mode. + + the PG dialect uses this to drop constraints on the table + before the new one uses those same names. + + """ + + @property + def bind(self) -> Optional[Connection]: + return self.connection + + def _exec( + self, + construct: Union[Executable, str], + execution_options: Optional[Mapping[str, Any]] = None, + multiparams: Optional[Sequence[Mapping[str, Any]]] = None, + params: Mapping[str, Any] = util.immutabledict(), + ) -> Optional[CursorResult]: + if isinstance(construct, str): + construct = text(construct) + if self.as_sql: + if multiparams is not None or params: + raise TypeError("SQL parameters not allowed with as_sql") + + compile_kw: dict[str, Any] + if self.literal_binds and not isinstance( + construct, schema.DDLElement + ): + compile_kw = dict(compile_kwargs={"literal_binds": True}) + else: + compile_kw = {} + + if TYPE_CHECKING: + assert isinstance(construct, ClauseElement) + compiled = construct.compile(dialect=self.dialect, **compile_kw) + self.static_output( + str(compiled).replace("\t", " ").strip() + + self.command_terminator + ) + return None + else: + conn = self.connection + assert conn is not None + if execution_options: + conn = conn.execution_options(**execution_options) + + if params and multiparams is not None: + raise TypeError( + "Can't send params and multiparams at the same time" + ) + + if multiparams: + return conn.execute(construct, multiparams) + else: + return conn.execute(construct, params) + + def execute( + self, + sql: Union[Executable, str], + execution_options: Optional[dict[str, Any]] = None, + ) -> None: + self._exec(sql, execution_options) + + def alter_column( + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union[_ServerDefault, Literal[False]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + autoincrement: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + existing_comment: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + **kw: Any, + ) -> None: + if autoincrement is not None or existing_autoincrement is not None: + util.warn( + "autoincrement and existing_autoincrement " + "only make sense for MySQL", + stacklevel=3, + ) + if nullable is not None: + self._exec( + base.ColumnNullable( + table_name, + column_name, + nullable, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + ) + ) + if server_default is not False: + kw = {} + cls_: Type[ + Union[ + base.ComputedColumnDefault, + base.IdentityColumnDefault, + base.ColumnDefault, + ] + ] + if sqla_compat._server_default_is_computed( + server_default, existing_server_default + ): + cls_ = base.ComputedColumnDefault + elif sqla_compat._server_default_is_identity( + server_default, existing_server_default + ): + cls_ = base.IdentityColumnDefault + kw["impl"] = self + else: + cls_ = base.ColumnDefault + self._exec( + cls_( + table_name, + column_name, + server_default, # type:ignore[arg-type] + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + **kw, + ) + ) + if type_ is not None: + self._exec( + base.ColumnType( + table_name, + column_name, + type_, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + ) + ) + + if comment is not False: + self._exec( + base.ColumnComment( + table_name, + column_name, + comment, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_comment=existing_comment, + ) + ) + + # do the new name last ;) + if name is not None: + self._exec( + base.ColumnName( + table_name, + column_name, + name, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) + + def add_column( + self, + table_name: str, + column: Column[Any], + schema: Optional[Union[str, quoted_name]] = None, + ) -> None: + self._exec(base.AddColumn(table_name, column, schema=schema)) + + def drop_column( + self, + table_name: str, + column: Column[Any], + schema: Optional[str] = None, + **kw, + ) -> None: + self._exec(base.DropColumn(table_name, column, schema=schema)) + + def add_constraint(self, const: Any) -> None: + if const._create_rule is None or const._create_rule(self): + self._exec(schema.AddConstraint(const)) + + def drop_constraint(self, const: Constraint) -> None: + self._exec(schema.DropConstraint(const)) + + def rename_table( + self, + old_table_name: str, + new_table_name: Union[str, quoted_name], + schema: Optional[Union[str, quoted_name]] = None, + ) -> None: + self._exec( + base.RenameTable(old_table_name, new_table_name, schema=schema) + ) + + def create_table(self, table: Table, **kw: Any) -> None: + table.dispatch.before_create( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + self._exec(schema.CreateTable(table, **kw)) + table.dispatch.after_create( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + for index in table.indexes: + self._exec(schema.CreateIndex(index)) + + with_comment = ( + self.dialect.supports_comments and not self.dialect.inline_comments + ) + comment = table.comment + if comment and with_comment: + self.create_table_comment(table) + + for column in table.columns: + comment = column.comment + if comment and with_comment: + self.create_column_comment(column) + + def drop_table(self, table: Table, **kw: Any) -> None: + table.dispatch.before_drop( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + self._exec(schema.DropTable(table, **kw)) + table.dispatch.after_drop( + table, self.connection, checkfirst=False, _ddl_runner=self + ) + + def create_index(self, index: Index, **kw: Any) -> None: + self._exec(schema.CreateIndex(index, **kw)) + + def create_table_comment(self, table: Table) -> None: + self._exec(schema.SetTableComment(table)) + + def drop_table_comment(self, table: Table) -> None: + self._exec(schema.DropTableComment(table)) + + def create_column_comment(self, column: ColumnElement[Any]) -> None: + self._exec(schema.SetColumnComment(column)) + + def drop_index(self, index: Index, **kw: Any) -> None: + self._exec(schema.DropIndex(index, **kw)) + + def bulk_insert( + self, + table: Union[TableClause, Table], + rows: List[dict], + multiinsert: bool = True, + ) -> None: + if not isinstance(rows, list): + raise TypeError("List expected") + elif rows and not isinstance(rows[0], dict): + raise TypeError("List of dictionaries expected") + if self.as_sql: + for row in rows: + self._exec( + table.insert() + .inline() + .values( + **{ + k: ( + sqla_compat._literal_bindparam( + k, v, type_=table.c[k].type + ) + if not isinstance( + v, sqla_compat._literal_bindparam + ) + else v + ) + for k, v in row.items() + } + ) + ) + else: + if rows: + if multiinsert: + self._exec(table.insert().inline(), multiparams=rows) + else: + for row in rows: + self._exec(table.insert().inline().values(**row)) + + def _tokenize_column_type(self, column: Column) -> Params: + definition: str + definition = self.dialect.type_compiler.process(column.type).lower() + + # tokenize the SQLAlchemy-generated version of a type, so that + # the two can be compared. + # + # examples: + # NUMERIC(10, 5) + # TIMESTAMP WITH TIMEZONE + # INTEGER UNSIGNED + # INTEGER (10) UNSIGNED + # INTEGER(10) UNSIGNED + # varchar character set utf8 + # + + tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition) + + term_tokens: List[str] = [] + paren_term = None + + for token in tokens: + if re.match(r"^\(.*\)$", token): + paren_term = token + else: + term_tokens.append(token) + + params = Params(term_tokens[0], term_tokens[1:], [], {}) + + if paren_term: + term: str + for term in re.findall("[^(),]+", paren_term): + if "=" in term: + key, val = term.split("=") + params.kwargs[key.strip()] = val.strip() + else: + params.args.append(term.strip()) + + return params + + def _column_types_match( + self, inspector_params: Params, metadata_params: Params + ) -> bool: + if inspector_params.token0 == metadata_params.token0: + return True + + synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms] + inspector_all_terms = " ".join( + [inspector_params.token0] + inspector_params.tokens + ) + metadata_all_terms = " ".join( + [metadata_params.token0] + metadata_params.tokens + ) + + for batch in synonyms: + if {inspector_all_terms, metadata_all_terms}.issubset(batch) or { + inspector_params.token0, + metadata_params.token0, + }.issubset(batch): + return True + return False + + def _column_args_match( + self, inspected_params: Params, meta_params: Params + ) -> bool: + """We want to compare column parameters. However, we only want + to compare parameters that are set. If they both have `collation`, + we want to make sure they are the same. However, if only one + specifies it, dont flag it for being less specific + """ + + if ( + len(meta_params.tokens) == len(inspected_params.tokens) + and meta_params.tokens != inspected_params.tokens + ): + return False + + if ( + len(meta_params.args) == len(inspected_params.args) + and meta_params.args != inspected_params.args + ): + return False + + insp = " ".join(inspected_params.tokens).lower() + meta = " ".join(meta_params.tokens).lower() + + for reg in self.type_arg_extract: + mi = re.search(reg, insp) + mm = re.search(reg, meta) + + if mi and mm and mi.group(1) != mm.group(1): + return False + + return True + + def compare_type( + self, inspector_column: Column[Any], metadata_column: Column + ) -> bool: + """Returns True if there ARE differences between the types of the two + columns. Takes impl.type_synonyms into account between retrospected + and metadata types + """ + inspector_params = self._tokenize_column_type(inspector_column) + metadata_params = self._tokenize_column_type(metadata_column) + + if not self._column_types_match(inspector_params, metadata_params): + return True + if not self._column_args_match(inspector_params, metadata_params): + return True + return False + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + return rendered_inspector_default != rendered_metadata_default + + def correct_for_autogen_constraints( + self, + conn_uniques: Set[UniqueConstraint], + conn_indexes: Set[Index], + metadata_unique_constraints: Set[UniqueConstraint], + metadata_indexes: Set[Index], + ) -> None: + pass + + def cast_for_batch_migrate(self, existing, existing_transfer, new_type): + if existing.type._type_affinity is not new_type._type_affinity: + existing_transfer["expr"] = cast( + existing_transfer["expr"], new_type + ) + + def render_ddl_sql_expr( + self, expr: ClauseElement, is_server_default: bool = False, **kw: Any + ) -> str: + """Render a SQL expression that is typically a server default, + index expression, etc. + + """ + + compile_kw = {"literal_binds": True, "include_table": False} + + return str( + expr.compile(dialect=self.dialect, compile_kwargs=compile_kw) + ) + + def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable: + return self.autogen_column_reflect + + def correct_for_autogen_foreignkeys( + self, + conn_fks: Set[ForeignKeyConstraint], + metadata_fks: Set[ForeignKeyConstraint], + ) -> None: + pass + + def autogen_column_reflect(self, inspector, table, column_info): + """A hook that is attached to the 'column_reflect' event for when + a Table is reflected from the database during the autogenerate + process. + + Dialects can elect to modify the information gathered here. + + """ + + def start_migrations(self) -> None: + """A hook called when :meth:`.EnvironmentContext.run_migrations` + is called. + + Implementations can set up per-migration-run state here. + + """ + + def emit_begin(self) -> None: + """Emit the string ``BEGIN``, or the backend-specific + equivalent, on the current connection context. + + This is used in offline mode and typically + via :meth:`.EnvironmentContext.begin_transaction`. + + """ + self.static_output("BEGIN" + self.command_terminator) + + def emit_commit(self) -> None: + """Emit the string ``COMMIT``, or the backend-specific + equivalent, on the current connection context. + + This is used in offline mode and typically + via :meth:`.EnvironmentContext.begin_transaction`. + + """ + self.static_output("COMMIT" + self.command_terminator) + + def render_type( + self, type_obj: TypeEngine, autogen_context: AutogenContext + ) -> Union[str, Literal[False]]: + return False + + def _compare_identity_default(self, metadata_identity, inspector_identity): + # ignored contains the attributes that were not considered + # because assumed to their default values in the db. + diff, ignored = _compare_identity_options( + metadata_identity, + inspector_identity, + schema.Identity(), + skip={"always"}, + ) + + meta_always = getattr(metadata_identity, "always", None) + inspector_always = getattr(inspector_identity, "always", None) + # None and False are the same in this comparison + if bool(meta_always) != bool(inspector_always): + diff.add("always") + + diff.difference_update(self.identity_attrs_ignore) + + # returns 3 values: + return ( + # different identity attributes + diff, + # ignored identity attributes + ignored, + # if the two identity should be considered different + bool(diff) or bool(metadata_identity) != bool(inspector_identity), + ) + + def _compare_index_unique( + self, metadata_index: Index, reflected_index: Index + ) -> Optional[str]: + conn_unique = bool(reflected_index.unique) + meta_unique = bool(metadata_index.unique) + if conn_unique != meta_unique: + return f"unique={conn_unique} to unique={meta_unique}" + else: + return None + + def _create_metadata_constraint_sig( + self, constraint: _autogen._C, **opts: Any + ) -> _constraint_sig[_autogen._C]: + return _constraint_sig.from_constraint(True, self, constraint, **opts) + + def _create_reflected_constraint_sig( + self, constraint: _autogen._C, **opts: Any + ) -> _constraint_sig[_autogen._C]: + return _constraint_sig.from_constraint(False, self, constraint, **opts) + + def compare_indexes( + self, + metadata_index: Index, + reflected_index: Index, + ) -> ComparisonResult: + """Compare two indexes by comparing the signature generated by + ``create_index_sig``. + + This method returns a ``ComparisonResult``. + """ + msg: List[str] = [] + unique_msg = self._compare_index_unique( + metadata_index, reflected_index + ) + if unique_msg: + msg.append(unique_msg) + m_sig = self._create_metadata_constraint_sig(metadata_index) + r_sig = self._create_reflected_constraint_sig(reflected_index) + + assert _autogen.is_index_sig(m_sig) + assert _autogen.is_index_sig(r_sig) + + # The assumption is that the index have no expression + for sig in m_sig, r_sig: + if sig.has_expressions: + log.warning( + "Generating approximate signature for index %s. " + "The dialect " + "implementation should either skip expression indexes " + "or provide a custom implementation.", + sig.const, + ) + + if m_sig.column_names != r_sig.column_names: + msg.append( + f"expression {r_sig.column_names} to {m_sig.column_names}" + ) + + if msg: + return ComparisonResult.Different(msg) + else: + return ComparisonResult.Equal() + + def compare_unique_constraint( + self, + metadata_constraint: UniqueConstraint, + reflected_constraint: UniqueConstraint, + ) -> ComparisonResult: + """Compare two unique constraints by comparing the two signatures. + + The arguments are two tuples that contain the unique constraint and + the signatures generated by ``create_unique_constraint_sig``. + + This method returns a ``ComparisonResult``. + """ + metadata_tup = self._create_metadata_constraint_sig( + metadata_constraint + ) + reflected_tup = self._create_reflected_constraint_sig( + reflected_constraint + ) + + meta_sig = metadata_tup.unnamed + conn_sig = reflected_tup.unnamed + if conn_sig != meta_sig: + return ComparisonResult.Different( + f"expression {conn_sig} to {meta_sig}" + ) + else: + return ComparisonResult.Equal() + + def _skip_functional_indexes(self, metadata_indexes, conn_indexes): + conn_indexes_by_name = {c.name: c for c in conn_indexes} + + for idx in list(metadata_indexes): + if idx.name in conn_indexes_by_name: + continue + iex = sqla_compat.is_expression_index(idx) + if iex: + util.warn( + "autogenerate skipping metadata-specified " + "expression-based index " + f"{idx.name!r}; dialect {self.__dialect__!r} under " + f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't " + "reflect these indexes so they can't be compared" + ) + metadata_indexes.discard(idx) + + def adjust_reflected_dialect_options( + self, reflected_object: Dict[str, Any], kind: str + ) -> Dict[str, Any]: + return reflected_object.get("dialect_options", {}) + + +class Params(NamedTuple): + token0: str + tokens: List[str] + args: List[str] + kwargs: Dict[str, str] + + +def _compare_identity_options( + metadata_io: Union[schema.Identity, schema.Sequence, None], + inspector_io: Union[schema.Identity, schema.Sequence, None], + default_io: Union[schema.Identity, schema.Sequence], + skip: Set[str], +): + # this can be used for identity or sequence compare. + # default_io is an instance of IdentityOption with all attributes to the + # default value. + meta_d = sqla_compat._get_identity_options_dict(metadata_io) + insp_d = sqla_compat._get_identity_options_dict(inspector_io) + + diff = set() + ignored_attr = set() + + def check_dicts( + meta_dict: Mapping[str, Any], + insp_dict: Mapping[str, Any], + default_dict: Mapping[str, Any], + attrs: Iterable[str], + ): + for attr in set(attrs).difference(skip): + meta_value = meta_dict.get(attr) + insp_value = insp_dict.get(attr) + if insp_value != meta_value: + default_value = default_dict.get(attr) + if meta_value == default_value: + ignored_attr.add(attr) + else: + diff.add(attr) + + check_dicts( + meta_d, + insp_d, + sqla_compat._get_identity_options_dict(default_io), + set(meta_d).union(insp_d), + ) + if sqla_compat.identity_has_dialect_kwargs: + assert hasattr(default_io, "dialect_kwargs") + # use only the dialect kwargs in inspector_io since metadata_io + # can have options for many backends + check_dicts( + getattr(metadata_io, "dialect_kwargs", {}), + getattr(inspector_io, "dialect_kwargs", {}), + default_io.dialect_kwargs, + getattr(inspector_io, "dialect_kwargs", {}), + ) + + return diff, ignored_attr diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/mssql.py b/.venv/lib/python3.12/site-packages/alembic/ddl/mssql.py new file mode 100644 index 00000000..baa43d5e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/mssql.py @@ -0,0 +1,419 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import re +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import types as sqltypes +from sqlalchemy.schema import Column +from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql.base import Executable +from sqlalchemy.sql.elements import ClauseElement + +from .base import AddColumn +from .base import alter_column +from .base import alter_table +from .base import ColumnDefault +from .base import ColumnName +from .base import ColumnNullable +from .base import ColumnType +from .base import format_column_name +from .base import format_server_default +from .base import format_table_name +from .base import format_type +from .base import RenameTable +from .impl import DefaultImpl +from .. import util +from ..util import sqla_compat +from ..util.sqla_compat import compiles + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.dialects.mssql.base import MSDDLCompiler + from sqlalchemy.dialects.mssql.base import MSSQLCompiler + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.selectable import TableClause + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + + +class MSSQLImpl(DefaultImpl): + __dialect__ = "mssql" + transactional_ddl = True + batch_separator = "GO" + + type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},) + identity_attrs_ignore = DefaultImpl.identity_attrs_ignore + ( + "minvalue", + "maxvalue", + "nominvalue", + "nomaxvalue", + "cycle", + "cache", + ) + + def __init__(self, *arg, **kw) -> None: + super().__init__(*arg, **kw) + self.batch_separator = self.context_opts.get( + "mssql_batch_separator", self.batch_separator + ) + + def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]: + result = super()._exec(construct, *args, **kw) + if self.as_sql and self.batch_separator: + self.static_output(self.batch_separator) + return result + + def emit_begin(self) -> None: + self.static_output("BEGIN TRANSACTION" + self.command_terminator) + + def emit_commit(self) -> None: + super().emit_commit() + if self.as_sql and self.batch_separator: + self.static_output(self.batch_separator) + + def alter_column( # type:ignore[override] + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Optional[ + Union[_ServerDefault, Literal[False]] + ] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + **kw: Any, + ) -> None: + if nullable is not None: + if type_ is not None: + # the NULL/NOT NULL alter will handle + # the type alteration + existing_type = type_ + type_ = None + elif existing_type is None: + raise util.CommandError( + "MS-SQL ALTER COLUMN operations " + "with NULL or NOT NULL require the " + "existing_type or a new type_ be passed." + ) + elif existing_nullable is not None and type_ is not None: + nullable = existing_nullable + + # the NULL/NOT NULL alter will handle + # the type alteration + existing_type = type_ + type_ = None + + elif type_ is not None: + util.warn( + "MS-SQL ALTER COLUMN operations that specify type_= " + "should also specify a nullable= or " + "existing_nullable= argument to avoid implicit conversion " + "of NOT NULL columns to NULL." + ) + + used_default = False + if sqla_compat._server_default_is_identity( + server_default, existing_server_default + ) or sqla_compat._server_default_is_computed( + server_default, existing_server_default + ): + used_default = True + kw["server_default"] = server_default + kw["existing_server_default"] = existing_server_default + + super().alter_column( + table_name, + column_name, + nullable=nullable, + type_=type_, + schema=schema, + existing_type=existing_type, + existing_nullable=existing_nullable, + **kw, + ) + + if server_default is not False and used_default is False: + if existing_server_default is not False or server_default is None: + self._exec( + _ExecDropConstraint( + table_name, + column_name, + "sys.default_constraints", + schema, + ) + ) + if server_default is not None: + super().alter_column( + table_name, + column_name, + schema=schema, + server_default=server_default, + ) + + if name is not None: + super().alter_column( + table_name, column_name, schema=schema, name=name + ) + + def create_index(self, index: Index, **kw: Any) -> None: + # this likely defaults to None if not present, so get() + # should normally not return the default value. being + # defensive in any case + mssql_include = index.kwargs.get("mssql_include", None) or () + assert index.table is not None + for col in mssql_include: + if col not in index.table.c: + index.table.append_column(Column(col, sqltypes.NullType)) + self._exec(CreateIndex(index, **kw)) + + def bulk_insert( # type:ignore[override] + self, table: Union[TableClause, Table], rows: List[dict], **kw: Any + ) -> None: + if self.as_sql: + self._exec( + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(table) + ) + super().bulk_insert(table, rows, **kw) + self._exec( + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table(table) + ) + else: + super().bulk_insert(table, rows, **kw) + + def drop_column( + self, + table_name: str, + column: Column[Any], + schema: Optional[str] = None, + **kw, + ) -> None: + drop_default = kw.pop("mssql_drop_default", False) + if drop_default: + self._exec( + _ExecDropConstraint( + table_name, column, "sys.default_constraints", schema + ) + ) + drop_check = kw.pop("mssql_drop_check", False) + if drop_check: + self._exec( + _ExecDropConstraint( + table_name, column, "sys.check_constraints", schema + ) + ) + drop_fks = kw.pop("mssql_drop_foreign_key", False) + if drop_fks: + self._exec(_ExecDropFKConstraint(table_name, column, schema)) + super().drop_column(table_name, column, schema=schema, **kw) + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + if rendered_metadata_default is not None: + rendered_metadata_default = re.sub( + r"[\(\) \"\']", "", rendered_metadata_default + ) + + if rendered_inspector_default is not None: + # SQL Server collapses whitespace and adds arbitrary parenthesis + # within expressions. our only option is collapse all of it + + rendered_inspector_default = re.sub( + r"[\(\) \"\']", "", rendered_inspector_default + ) + + return rendered_inspector_default != rendered_metadata_default + + def _compare_identity_default(self, metadata_identity, inspector_identity): + diff, ignored, is_alter = super()._compare_identity_default( + metadata_identity, inspector_identity + ) + + if ( + metadata_identity is None + and inspector_identity is not None + and not diff + and inspector_identity.column is not None + and inspector_identity.column.primary_key + ): + # mssql reflect primary keys with autoincrement as identity + # columns. if no different attributes are present ignore them + is_alter = False + + return diff, ignored, is_alter + + def adjust_reflected_dialect_options( + self, reflected_object: Dict[str, Any], kind: str + ) -> Dict[str, Any]: + options: Dict[str, Any] + options = reflected_object.get("dialect_options", {}).copy() + if not options.get("mssql_include"): + options.pop("mssql_include", None) + if not options.get("mssql_clustered"): + options.pop("mssql_clustered", None) + return options + + +class _ExecDropConstraint(Executable, ClauseElement): + inherit_cache = False + + def __init__( + self, + tname: str, + colname: Union[Column[Any], str], + type_: str, + schema: Optional[str], + ) -> None: + self.tname = tname + self.colname = colname + self.type_ = type_ + self.schema = schema + + +class _ExecDropFKConstraint(Executable, ClauseElement): + inherit_cache = False + + def __init__( + self, tname: str, colname: Column[Any], schema: Optional[str] + ) -> None: + self.tname = tname + self.colname = colname + self.schema = schema + + +@compiles(_ExecDropConstraint, "mssql") +def _exec_drop_col_constraint( + element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw +) -> str: + schema, tname, colname, type_ = ( + element.schema, + element.tname, + element.colname, + element.type_, + ) + # from http://www.mssqltips.com/sqlservertip/1425/\ + # working-with-default-constraints-in-sql-server/ + return """declare @const_name varchar(256) +select @const_name = QUOTENAME([name]) from %(type)s +where parent_object_id = object_id('%(schema_dot)s%(tname)s') +and col_name(parent_object_id, parent_column_id) = '%(colname)s' +exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { + "type": type_, + "tname": tname, + "colname": colname, + "tname_quoted": format_table_name(compiler, tname, schema), + "schema_dot": schema + "." if schema else "", + } + + +@compiles(_ExecDropFKConstraint, "mssql") +def _exec_drop_col_fk_constraint( + element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw +) -> str: + schema, tname, colname = element.schema, element.tname, element.colname + + return """declare @const_name varchar(256) +select @const_name = QUOTENAME([name]) from +sys.foreign_keys fk join sys.foreign_key_columns fkc +on fk.object_id=fkc.constraint_object_id +where fkc.parent_object_id = object_id('%(schema_dot)s%(tname)s') +and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s' +exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { + "tname": tname, + "colname": colname, + "tname_quoted": format_table_name(compiler, tname, schema), + "schema_dot": schema + "." if schema else "", + } + + +@compiles(AddColumn, "mssql") +def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + mssql_add_column(compiler, element.column, **kw), + ) + + +def mssql_add_column( + compiler: MSDDLCompiler, column: Column[Any], **kw +) -> str: + return "ADD %s" % compiler.get_column_specification(column, **kw) + + +@compiles(ColumnNullable, "mssql") +def visit_column_nullable( + element: ColumnNullable, compiler: MSDDLCompiler, **kw +) -> str: + return "%s %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + format_type(compiler, element.existing_type), # type: ignore[arg-type] + "NULL" if element.nullable else "NOT NULL", + ) + + +@compiles(ColumnDefault, "mssql") +def visit_column_default( + element: ColumnDefault, compiler: MSDDLCompiler, **kw +) -> str: + # TODO: there can also be a named constraint + # with ADD CONSTRAINT here + return "%s ADD DEFAULT %s FOR %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_server_default(compiler, element.default), + format_column_name(compiler, element.column_name), + ) + + +@compiles(ColumnName, "mssql") +def visit_rename_column( + element: ColumnName, compiler: MSDDLCompiler, **kw +) -> str: + return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % ( + format_table_name(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +@compiles(ColumnType, "mssql") +def visit_column_type( + element: ColumnType, compiler: MSDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + format_type(compiler, element.type_), + ) + + +@compiles(RenameTable, "mssql") +def visit_rename_table( + element: RenameTable, compiler: MSDDLCompiler, **kw +) -> str: + return "EXEC sp_rename '%s', %s" % ( + format_table_name(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/mysql.py b/.venv/lib/python3.12/site-packages/alembic/ddl/mysql.py new file mode 100644 index 00000000..c7b3905c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/mysql.py @@ -0,0 +1,491 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import re +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import schema +from sqlalchemy import types as sqltypes + +from .base import alter_table +from .base import AlterColumn +from .base import ColumnDefault +from .base import ColumnName +from .base import ColumnNullable +from .base import ColumnType +from .base import format_column_name +from .base import format_server_default +from .impl import DefaultImpl +from .. import util +from ..util import sqla_compat +from ..util.sqla_compat import _is_type_bound +from ..util.sqla_compat import compiles + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler + from sqlalchemy.sql.ddl import DropConstraint + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + + +class MySQLImpl(DefaultImpl): + __dialect__ = "mysql" + + transactional_ddl = False + type_synonyms = DefaultImpl.type_synonyms + ( + {"BOOL", "TINYINT"}, + {"JSON", "LONGTEXT"}, + ) + type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"] + + def alter_column( # type:ignore[override] + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union[_ServerDefault, Literal[False]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + autoincrement: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + existing_comment: Optional[str] = None, + **kw: Any, + ) -> None: + if sqla_compat._server_default_is_identity( + server_default, existing_server_default + ) or sqla_compat._server_default_is_computed( + server_default, existing_server_default + ): + # modifying computed or identity columns is not supported + # the default will raise + super().alter_column( + table_name, + column_name, + nullable=nullable, + type_=type_, + schema=schema, + existing_type=existing_type, + existing_nullable=existing_nullable, + server_default=server_default, + existing_server_default=existing_server_default, + **kw, + ) + if name is not None or self._is_mysql_allowed_functional_default( + type_ if type_ is not None else existing_type, server_default + ): + self._exec( + MySQLChangeColumn( + table_name, + column_name, + schema=schema, + newname=name if name is not None else column_name, + nullable=( + nullable + if nullable is not None + else ( + existing_nullable + if existing_nullable is not None + else True + ) + ), + type_=type_ if type_ is not None else existing_type, + default=( + server_default + if server_default is not False + else existing_server_default + ), + autoincrement=( + autoincrement + if autoincrement is not None + else existing_autoincrement + ), + comment=( + comment if comment is not False else existing_comment + ), + ) + ) + elif ( + nullable is not None + or type_ is not None + or autoincrement is not None + or comment is not False + ): + self._exec( + MySQLModifyColumn( + table_name, + column_name, + schema=schema, + newname=name if name is not None else column_name, + nullable=( + nullable + if nullable is not None + else ( + existing_nullable + if existing_nullable is not None + else True + ) + ), + type_=type_ if type_ is not None else existing_type, + default=( + server_default + if server_default is not False + else existing_server_default + ), + autoincrement=( + autoincrement + if autoincrement is not None + else existing_autoincrement + ), + comment=( + comment if comment is not False else existing_comment + ), + ) + ) + elif server_default is not False: + self._exec( + MySQLAlterDefault( + table_name, column_name, server_default, schema=schema + ) + ) + + def drop_constraint( + self, + const: Constraint, + ) -> None: + if isinstance(const, schema.CheckConstraint) and _is_type_bound(const): + return + + super().drop_constraint(const) + + def _is_mysql_allowed_functional_default( + self, + type_: Optional[TypeEngine], + server_default: Union[_ServerDefault, Literal[False]], + ) -> bool: + return ( + type_ is not None + and type_._type_affinity is sqltypes.DateTime + and server_default is not None + ) + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + # partially a workaround for SQLAlchemy issue #3023; if the + # column were created without "NOT NULL", MySQL may have added + # an implicit default of '0' which we need to skip + # TODO: this is not really covered anymore ? + if ( + metadata_column.type._type_affinity is sqltypes.Integer + and inspector_column.primary_key + and not inspector_column.autoincrement + and not rendered_metadata_default + and rendered_inspector_default == "'0'" + ): + return False + elif ( + rendered_inspector_default + and inspector_column.type._type_affinity is sqltypes.Integer + ): + rendered_inspector_default = ( + re.sub(r"^'|'$", "", rendered_inspector_default) + if rendered_inspector_default is not None + else None + ) + return rendered_inspector_default != rendered_metadata_default + elif ( + rendered_metadata_default + and metadata_column.type._type_affinity is sqltypes.String + ): + metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default) + return rendered_inspector_default != f"'{metadata_default}'" + elif rendered_inspector_default and rendered_metadata_default: + # adjust for "function()" vs. "FUNCTION" as can occur particularly + # for the CURRENT_TIMESTAMP function on newer MariaDB versions + + # SQLAlchemy MySQL dialect bundles ON UPDATE into the server + # default; adjust for this possibly being present. + onupdate_ins = re.match( + r"(.*) (on update.*?)(?:\(\))?$", + rendered_inspector_default.lower(), + ) + onupdate_met = re.match( + r"(.*) (on update.*?)(?:\(\))?$", + rendered_metadata_default.lower(), + ) + + if onupdate_ins: + if not onupdate_met: + return True + elif onupdate_ins.group(2) != onupdate_met.group(2): + return True + + rendered_inspector_default = onupdate_ins.group(1) + rendered_metadata_default = onupdate_met.group(1) + + return re.sub( + r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower() + ) != re.sub( + r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower() + ) + else: + return rendered_inspector_default != rendered_metadata_default + + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + # TODO: if SQLA 1.0, make use of "duplicates_index" + # metadata + removed = set() + for idx in list(conn_indexes): + if idx.unique: + continue + # MySQL puts implicit indexes on FK columns, even if + # composite and even if MyISAM, so can't check this too easily. + # the name of the index may be the column name or it may + # be the name of the FK constraint. + for col in idx.columns: + if idx.name == col.name: + conn_indexes.remove(idx) + removed.add(idx.name) + break + for fk in col.foreign_keys: + if fk.name == idx.name: + conn_indexes.remove(idx) + removed.add(idx.name) + break + if idx.name in removed: + break + + # then remove indexes from the "metadata_indexes" + # that we've removed from reflected, otherwise they come out + # as adds (see #202) + for idx in list(metadata_indexes): + if idx.name in removed: + metadata_indexes.remove(idx) + + def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks): + conn_fk_by_sig = { + self._create_reflected_constraint_sig(fk).unnamed_no_options: fk + for fk in conn_fks + } + metadata_fk_by_sig = { + self._create_metadata_constraint_sig(fk).unnamed_no_options: fk + for fk in metadata_fks + } + + for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig): + mdfk = metadata_fk_by_sig[sig] + cnfk = conn_fk_by_sig[sig] + # MySQL considers RESTRICT to be the default and doesn't + # report on it. if the model has explicit RESTRICT and + # the conn FK has None, set it to RESTRICT + if ( + mdfk.ondelete is not None + and mdfk.ondelete.lower() == "restrict" + and cnfk.ondelete is None + ): + cnfk.ondelete = "RESTRICT" + if ( + mdfk.onupdate is not None + and mdfk.onupdate.lower() == "restrict" + and cnfk.onupdate is None + ): + cnfk.onupdate = "RESTRICT" + + +class MariaDBImpl(MySQLImpl): + __dialect__ = "mariadb" + + +class MySQLAlterDefault(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + default: _ServerDefault, + schema: Optional[str] = None, + ) -> None: + super(AlterColumn, self).__init__(name, schema=schema) + self.column_name = column_name + self.default = default + + +class MySQLChangeColumn(AlterColumn): + def __init__( + self, + name: str, + column_name: str, + schema: Optional[str] = None, + newname: Optional[str] = None, + type_: Optional[TypeEngine] = None, + nullable: Optional[bool] = None, + default: Optional[Union[_ServerDefault, Literal[False]]] = False, + autoincrement: Optional[bool] = None, + comment: Optional[Union[str, Literal[False]]] = False, + ) -> None: + super(AlterColumn, self).__init__(name, schema=schema) + self.column_name = column_name + self.nullable = nullable + self.newname = newname + self.default = default + self.autoincrement = autoincrement + self.comment = comment + if type_ is None: + raise util.CommandError( + "All MySQL CHANGE/MODIFY COLUMN operations " + "require the existing type." + ) + + self.type_ = sqltypes.to_instance(type_) + + +class MySQLModifyColumn(MySQLChangeColumn): + pass + + +@compiles(ColumnNullable, "mysql", "mariadb") +@compiles(ColumnName, "mysql", "mariadb") +@compiles(ColumnDefault, "mysql", "mariadb") +@compiles(ColumnType, "mysql", "mariadb") +def _mysql_doesnt_support_individual(element, compiler, **kw): + raise NotImplementedError( + "Individual alter column constructs not supported by MySQL" + ) + + +@compiles(MySQLAlterDefault, "mysql", "mariadb") +def _mysql_alter_default( + element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw +) -> str: + return "%s ALTER COLUMN %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + ( + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT" + ), + ) + + +@compiles(MySQLModifyColumn, "mysql", "mariadb") +def _mysql_modify_column( + element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw +) -> str: + return "%s MODIFY %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + _mysql_colspec( + compiler, + nullable=element.nullable, + server_default=element.default, + type_=element.type_, + autoincrement=element.autoincrement, + comment=element.comment, + ), + ) + + +@compiles(MySQLChangeColumn, "mysql", "mariadb") +def _mysql_change_column( + element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw +) -> str: + return "%s CHANGE %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + _mysql_colspec( + compiler, + nullable=element.nullable, + server_default=element.default, + type_=element.type_, + autoincrement=element.autoincrement, + comment=element.comment, + ), + ) + + +def _mysql_colspec( + compiler: MySQLDDLCompiler, + nullable: Optional[bool], + server_default: Optional[Union[_ServerDefault, Literal[False]]], + type_: TypeEngine, + autoincrement: Optional[bool], + comment: Optional[Union[str, Literal[False]]], +) -> str: + spec = "%s %s" % ( + compiler.dialect.type_compiler.process(type_), + "NULL" if nullable else "NOT NULL", + ) + if autoincrement: + spec += " AUTO_INCREMENT" + if server_default is not False and server_default is not None: + spec += " DEFAULT %s" % format_server_default(compiler, server_default) + if comment: + spec += " COMMENT %s" % compiler.sql_compiler.render_literal_value( + comment, sqltypes.String() + ) + + return spec + + +@compiles(schema.DropConstraint, "mysql", "mariadb") +def _mysql_drop_constraint( + element: DropConstraint, compiler: MySQLDDLCompiler, **kw +) -> str: + """Redefine SQLAlchemy's drop constraint to + raise errors for invalid constraint type.""" + + constraint = element.element + if isinstance( + constraint, + ( + schema.ForeignKeyConstraint, + schema.PrimaryKeyConstraint, + schema.UniqueConstraint, + ), + ): + assert not kw + return compiler.visit_drop_constraint(element) + elif isinstance(constraint, schema.CheckConstraint): + # note that SQLAlchemy as of 1.2 does not yet support + # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully + # here. + if compiler.dialect.is_mariadb: # type: ignore[attr-defined] + return "ALTER TABLE %s DROP CONSTRAINT %s" % ( + compiler.preparer.format_table(constraint.table), + compiler.preparer.format_constraint(constraint), + ) + else: + return "ALTER TABLE %s DROP CHECK %s" % ( + compiler.preparer.format_table(constraint.table), + compiler.preparer.format_constraint(constraint), + ) + else: + raise NotImplementedError( + "No generic 'DROP CONSTRAINT' in MySQL - " + "please specify constraint type" + ) diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/oracle.py b/.venv/lib/python3.12/site-packages/alembic/ddl/oracle.py new file mode 100644 index 00000000..eac99124 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/oracle.py @@ -0,0 +1,202 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import re +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING + +from sqlalchemy.sql import sqltypes + +from .base import AddColumn +from .base import alter_table +from .base import ColumnComment +from .base import ColumnDefault +from .base import ColumnName +from .base import ColumnNullable +from .base import ColumnType +from .base import format_column_name +from .base import format_server_default +from .base import format_table_name +from .base import format_type +from .base import IdentityColumnDefault +from .base import RenameTable +from .impl import DefaultImpl +from ..util.sqla_compat import compiles + +if TYPE_CHECKING: + from sqlalchemy.dialects.oracle.base import OracleDDLCompiler + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.sql.schema import Column + + +class OracleImpl(DefaultImpl): + __dialect__ = "oracle" + transactional_ddl = False + batch_separator = "/" + command_terminator = "" + type_synonyms = DefaultImpl.type_synonyms + ( + {"VARCHAR", "VARCHAR2"}, + {"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"}, + {"DOUBLE", "FLOAT", "DOUBLE_PRECISION"}, + ) + identity_attrs_ignore = () + + def __init__(self, *arg, **kw) -> None: + super().__init__(*arg, **kw) + self.batch_separator = self.context_opts.get( + "oracle_batch_separator", self.batch_separator + ) + + def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]: + result = super()._exec(construct, *args, **kw) + if self.as_sql and self.batch_separator: + self.static_output(self.batch_separator) + return result + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + if rendered_metadata_default is not None: + rendered_metadata_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_metadata_default + ) + + rendered_metadata_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default + ) + + if rendered_inspector_default is not None: + rendered_inspector_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_inspector_default + ) + + rendered_inspector_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default + ) + + rendered_inspector_default = rendered_inspector_default.strip() + return rendered_inspector_default != rendered_metadata_default + + def emit_begin(self) -> None: + self._exec("SET TRANSACTION READ WRITE") + + def emit_commit(self) -> None: + self._exec("COMMIT") + + +@compiles(AddColumn, "oracle") +def visit_add_column( + element: AddColumn, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + add_column(compiler, element.column, **kw), + ) + + +@compiles(ColumnNullable, "oracle") +def visit_column_nullable( + element: ColumnNullable, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "NULL" if element.nullable else "NOT NULL", + ) + + +@compiles(ColumnType, "oracle") +def visit_column_type( + element: ColumnType, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "%s" % format_type(compiler, element.type_), + ) + + +@compiles(ColumnName, "oracle") +def visit_column_name( + element: ColumnName, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s RENAME COLUMN %s TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +@compiles(ColumnDefault, "oracle") +def visit_column_default( + element: ColumnDefault, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + ( + "DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DEFAULT NULL" + ), + ) + + +@compiles(ColumnComment, "oracle") +def visit_column_comment( + element: ColumnComment, compiler: OracleDDLCompiler, **kw +) -> str: + ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}" + + comment = compiler.sql_compiler.render_literal_value( + (element.comment if element.comment is not None else ""), + sqltypes.String(), + ) + + return ddl.format( + table_name=element.table_name, + column_name=element.column_name, + comment=comment, + ) + + +@compiles(RenameTable, "oracle") +def visit_rename_table( + element: RenameTable, compiler: OracleDDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) + + +def alter_column(compiler: OracleDDLCompiler, name: str) -> str: + return "MODIFY %s" % format_column_name(compiler, name) + + +def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str: + return "ADD %s" % compiler.get_column_specification(column, **kw) + + +@compiles(IdentityColumnDefault, "oracle") +def visit_identity_column( + element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw +): + text = "%s %s " % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + ) + if element.default is None: + # drop identity + text += "DROP IDENTITY" + return text + else: + text += compiler.visit_identity_column(element.default) + return text diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py b/.venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py new file mode 100644 index 00000000..7cd8d35b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py @@ -0,0 +1,850 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import logging +import re +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import Column +from sqlalchemy import Float +from sqlalchemy import Identity +from sqlalchemy import literal_column +from sqlalchemy import Numeric +from sqlalchemy import select +from sqlalchemy import text +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql import BIGINT +from sqlalchemy.dialects.postgresql import ExcludeConstraint +from sqlalchemy.dialects.postgresql import INTEGER +from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.functions import FunctionElement +from sqlalchemy.types import NULLTYPE + +from .base import alter_column +from .base import alter_table +from .base import AlterColumn +from .base import ColumnComment +from .base import format_column_name +from .base import format_table_name +from .base import format_type +from .base import IdentityColumnDefault +from .base import RenameTable +from .impl import ComparisonResult +from .impl import DefaultImpl +from .. import util +from ..autogenerate import render +from ..operations import ops +from ..operations import schemaobj +from ..operations.base import BatchOperations +from ..operations.base import Operations +from ..util import sqla_compat +from ..util.sqla_compat import compiles + +if TYPE_CHECKING: + from typing import Literal + + from sqlalchemy import Index + from sqlalchemy import UniqueConstraint + from sqlalchemy.dialects.postgresql.array import ARRAY + from sqlalchemy.dialects.postgresql.base import PGDDLCompiler + from sqlalchemy.dialects.postgresql.hstore import HSTORE + from sqlalchemy.dialects.postgresql.json import JSON + from sqlalchemy.dialects.postgresql.json import JSONB + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.type_api import TypeEngine + + from .base import _ServerDefault + from ..autogenerate.api import AutogenContext + from ..autogenerate.render import _f_name + from ..runtime.migration import MigrationContext + + +log = logging.getLogger(__name__) + + +class PostgresqlImpl(DefaultImpl): + __dialect__ = "postgresql" + transactional_ddl = True + type_synonyms = DefaultImpl.type_synonyms + ( + {"FLOAT", "DOUBLE PRECISION"}, + ) + + def create_index(self, index: Index, **kw: Any) -> None: + # this likely defaults to None if not present, so get() + # should normally not return the default value. being + # defensive in any case + postgresql_include = index.kwargs.get("postgresql_include", None) or () + for col in postgresql_include: + if col not in index.table.c: # type: ignore[union-attr] + index.table.append_column( # type: ignore[union-attr] + Column(col, sqltypes.NullType) + ) + self._exec(CreateIndex(index, **kw)) + + def prep_table_for_batch(self, batch_impl, table): + for constraint in table.constraints: + if ( + constraint.name is not None + and constraint.name in batch_impl.named_constraints + ): + self.drop_constraint(constraint) + + def compare_server_default( + self, + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_inspector_default, + ): + # don't do defaults for SERIAL columns + if ( + metadata_column.primary_key + and metadata_column is metadata_column.table._autoincrement_column + ): + return False + + conn_col_default = rendered_inspector_default + + defaults_equal = conn_col_default == rendered_metadata_default + if defaults_equal: + return False + + if None in ( + conn_col_default, + rendered_metadata_default, + metadata_column.server_default, + ): + return not defaults_equal + + metadata_default = metadata_column.server_default.arg + + if isinstance(metadata_default, str): + if not isinstance(inspector_column.type, (Numeric, Float)): + metadata_default = re.sub(r"^'|'$", "", metadata_default) + metadata_default = f"'{metadata_default}'" + + metadata_default = literal_column(metadata_default) + + # run a real compare against the server + conn = self.connection + assert conn is not None + return not conn.scalar( + select(literal_column(conn_col_default) == metadata_default) + ) + + def alter_column( # type:ignore[override] + self, + table_name: str, + column_name: str, + nullable: Optional[bool] = None, + server_default: Union[_ServerDefault, Literal[False]] = False, + name: Optional[str] = None, + type_: Optional[TypeEngine] = None, + schema: Optional[str] = None, + autoincrement: Optional[bool] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, + existing_nullable: Optional[bool] = None, + existing_autoincrement: Optional[bool] = None, + **kw: Any, + ) -> None: + using = kw.pop("postgresql_using", None) + + if using is not None and type_ is None: + raise util.CommandError( + "postgresql_using must be used with the type_ parameter" + ) + + if type_ is not None: + self._exec( + PostgresqlColumnType( + table_name, + column_name, + type_, + schema=schema, + using=using, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + ) + ) + + super().alter_column( + table_name, + column_name, + nullable=nullable, + server_default=server_default, + name=name, + schema=schema, + autoincrement=autoincrement, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_autoincrement=existing_autoincrement, + **kw, + ) + + def autogen_column_reflect(self, inspector, table, column_info): + if column_info.get("default") and isinstance( + column_info["type"], (INTEGER, BIGINT) + ): + seq_match = re.match( + r"nextval\('(.+?)'::regclass\)", column_info["default"] + ) + if seq_match: + info = sqla_compat._exec_on_inspector( + inspector, + text( + "select c.relname, a.attname " + "from pg_class as c join " + "pg_depend d on d.objid=c.oid and " + "d.classid='pg_class'::regclass and " + "d.refclassid='pg_class'::regclass " + "join pg_class t on t.oid=d.refobjid " + "join pg_attribute a on a.attrelid=t.oid and " + "a.attnum=d.refobjsubid " + "where c.relkind='S' and " + "c.oid=cast(:seqname as regclass)" + ), + seqname=seq_match.group(1), + ).first() + if info: + seqname, colname = info + if colname == column_info["name"]: + log.info( + "Detected sequence named '%s' as " + "owned by integer column '%s(%s)', " + "assuming SERIAL and omitting", + seqname, + table.name, + colname, + ) + # sequence, and the owner is this column, + # its a SERIAL - whack it! + del column_info["default"] + + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + doubled_constraints = { + index + for index in conn_indexes + if index.info.get("duplicates_constraint") + } + + for ix in doubled_constraints: + conn_indexes.remove(ix) + + if not sqla_compat.sqla_2: + self._skip_functional_indexes(metadata_indexes, conn_indexes) + + # pg behavior regarding modifiers + # | # | compiled sql | returned sql | regexp. group is removed | + # | - | ---------------- | -----------------| ------------------------ | + # | 1 | nulls first | nulls first | - | + # | 2 | nulls last | | (?<! desc)( nulls last)$ | + # | 3 | asc | | ( asc)$ | + # | 4 | asc nulls first | nulls first | ( asc) nulls first$ | + # | 5 | asc nulls last | | ( asc nulls last)$ | + # | 6 | desc | desc | - | + # | 7 | desc nulls first | desc | desc( nulls first)$ | + # | 8 | desc nulls last | desc nulls last | - | + _default_modifiers_re = ( # order of case 2 and 5 matters + re.compile("( asc nulls last)$"), # case 5 + re.compile("(?<! desc)( nulls last)$"), # case 2 + re.compile("( asc)$"), # case 3 + re.compile("( asc) nulls first$"), # case 4 + re.compile(" desc( nulls first)$"), # case 7 + ) + + def _cleanup_index_expr(self, index: Index, expr: str) -> str: + expr = expr.lower().replace('"', "").replace("'", "") + if index.table is not None: + # should not be needed, since include_table=False is in compile + expr = expr.replace(f"{index.table.name.lower()}.", "") + + if "::" in expr: + # strip :: cast. types can have spaces in them + expr = re.sub(r"(::[\w ]+\w)", "", expr) + + while expr and expr[0] == "(" and expr[-1] == ")": + expr = expr[1:-1] + + # NOTE: when parsing the connection expression this cleanup could + # be skipped + for rs in self._default_modifiers_re: + if match := rs.search(expr): + start, end = match.span(1) + expr = expr[:start] + expr[end:] + break + + while expr and expr[0] == "(" and expr[-1] == ")": + expr = expr[1:-1] + + # strip casts + cast_re = re.compile(r"cast\s*\(") + if cast_re.match(expr): + expr = cast_re.sub("", expr) + # remove the as type + expr = re.sub(r"as\s+[^)]+\)", "", expr) + # remove spaces + expr = expr.replace(" ", "") + return expr + + def _dialect_options( + self, item: Union[Index, UniqueConstraint] + ) -> Tuple[Any, ...]: + # only the positive case is returned by sqlalchemy reflection so + # None and False are threated the same + if item.dialect_kwargs.get("postgresql_nulls_not_distinct"): + return ("nulls_not_distinct",) + return () + + def compare_indexes( + self, + metadata_index: Index, + reflected_index: Index, + ) -> ComparisonResult: + msg = [] + unique_msg = self._compare_index_unique( + metadata_index, reflected_index + ) + if unique_msg: + msg.append(unique_msg) + m_exprs = metadata_index.expressions + r_exprs = reflected_index.expressions + if len(m_exprs) != len(r_exprs): + msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}") + if msg: + # no point going further, return early + return ComparisonResult.Different(msg) + skip = [] + for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1): + m_compile = self._compile_element(m_e) + m_text = self._cleanup_index_expr(metadata_index, m_compile) + # print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}") + r_compile = self._compile_element(r_e) + r_text = self._cleanup_index_expr(metadata_index, r_compile) + # print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}") + if m_text == r_text: + continue # expressions these are equal + elif m_compile.strip().endswith("_ops") and ( + " " in m_compile or ")" in m_compile # is an expression + ): + skip.append( + f"expression #{pos} {m_compile!r} detected " + "as including operator clause." + ) + util.warn( + f"Expression #{pos} {m_compile!r} in index " + f"{reflected_index.name!r} detected to include " + "an operator clause. Expression compare cannot proceed. " + "Please move the operator clause to the " + "``postgresql_ops`` dict to enable proper compare " + "of the index expressions: " + "https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501 + ) + else: + msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}") + + m_options = self._dialect_options(metadata_index) + r_options = self._dialect_options(reflected_index) + if m_options != r_options: + msg.extend(f"options {r_options} to {m_options}") + + if msg: + return ComparisonResult.Different(msg) + elif skip: + # if there are other changes detected don't skip the index + return ComparisonResult.Skip(skip) + else: + return ComparisonResult.Equal() + + def compare_unique_constraint( + self, + metadata_constraint: UniqueConstraint, + reflected_constraint: UniqueConstraint, + ) -> ComparisonResult: + metadata_tup = self._create_metadata_constraint_sig( + metadata_constraint + ) + reflected_tup = self._create_reflected_constraint_sig( + reflected_constraint + ) + + meta_sig = metadata_tup.unnamed + conn_sig = reflected_tup.unnamed + if conn_sig != meta_sig: + return ComparisonResult.Different( + f"expression {conn_sig} to {meta_sig}" + ) + + metadata_do = self._dialect_options(metadata_tup.const) + conn_do = self._dialect_options(reflected_tup.const) + if metadata_do != conn_do: + return ComparisonResult.Different( + f"expression {conn_do} to {metadata_do}" + ) + + return ComparisonResult.Equal() + + def adjust_reflected_dialect_options( + self, reflected_options: Dict[str, Any], kind: str + ) -> Dict[str, Any]: + options: Dict[str, Any] + options = reflected_options.get("dialect_options", {}).copy() + if not options.get("postgresql_include"): + options.pop("postgresql_include", None) + return options + + def _compile_element(self, element: Union[ClauseElement, str]) -> str: + if isinstance(element, str): + return element + return element.compile( + dialect=self.dialect, + compile_kwargs={"literal_binds": True, "include_table": False}, + ).string + + def render_ddl_sql_expr( + self, + expr: ClauseElement, + is_server_default: bool = False, + is_index: bool = False, + **kw: Any, + ) -> str: + """Render a SQL expression that is typically a server default, + index expression, etc. + + """ + + # apply self_group to index expressions; + # see https://github.com/sqlalchemy/sqlalchemy/blob/ + # 82fa95cfce070fab401d020c6e6e4a6a96cc2578/ + # lib/sqlalchemy/dialects/postgresql/base.py#L2261 + if is_index and not isinstance(expr, ColumnClause): + expr = expr.self_group() + + return super().render_ddl_sql_expr( + expr, is_server_default=is_server_default, is_index=is_index, **kw + ) + + def render_type( + self, type_: TypeEngine, autogen_context: AutogenContext + ) -> Union[str, Literal[False]]: + mod = type(type_).__module__ + if not mod.startswith("sqlalchemy.dialects.postgresql"): + return False + + if hasattr(self, "_render_%s_type" % type_.__visit_name__): + meth = getattr(self, "_render_%s_type" % type_.__visit_name__) + return meth(type_, autogen_context) + + return False + + def _render_HSTORE_type( + self, type_: HSTORE, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "text_type", r"(.+?\(.*text_type=)" + ), + ) + + def _render_ARRAY_type( + self, type_: ARRAY, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "item_type", r"(.+?\()" + ), + ) + + def _render_JSON_type( + self, type_: JSON, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)" + ), + ) + + def _render_JSONB_type( + self, type_: JSONB, autogen_context: AutogenContext + ) -> str: + return cast( + str, + render._render_type_w_subtype( + type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)" + ), + ) + + +class PostgresqlColumnType(AlterColumn): + def __init__( + self, name: str, column_name: str, type_: TypeEngine, **kw + ) -> None: + using = kw.pop("using", None) + super().__init__(name, column_name, **kw) + self.type_ = sqltypes.to_instance(type_) + self.using = using + + +@compiles(RenameTable, "postgresql") +def visit_rename_table( + element: RenameTable, compiler: PGDDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) + + +@compiles(PostgresqlColumnType, "postgresql") +def visit_column_type( + element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw +) -> str: + return "%s %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "TYPE %s" % format_type(compiler, element.type_), + "USING %s" % element.using if element.using else "", + ) + + +@compiles(ColumnComment, "postgresql") +def visit_column_comment( + element: ColumnComment, compiler: PGDDLCompiler, **kw +) -> str: + ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}" + comment = ( + compiler.sql_compiler.render_literal_value( + element.comment, sqltypes.String() + ) + if element.comment is not None + else "NULL" + ) + + return ddl.format( + table_name=format_table_name( + compiler, element.table_name, element.schema + ), + column_name=format_column_name(compiler, element.column_name), + comment=comment, + ) + + +@compiles(IdentityColumnDefault, "postgresql") +def visit_identity_column( + element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw +): + text = "%s %s " % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + ) + if element.default is None: + # drop identity + text += "DROP IDENTITY" + return text + elif element.existing_server_default is None: + # add identity options + text += "ADD " + text += compiler.visit_identity_column(element.default) + return text + else: + # alter identity + diff, _, _ = element.impl._compare_identity_default( + element.default, element.existing_server_default + ) + identity = element.default + for attr in sorted(diff): + if attr == "always": + text += "SET GENERATED %s " % ( + "ALWAYS" if identity.always else "BY DEFAULT" + ) + else: + text += "SET %s " % compiler.get_identity_options( + Identity(**{attr: getattr(identity, attr)}) + ) + return text + + +@Operations.register_operation("create_exclude_constraint") +@BatchOperations.register_operation( + "create_exclude_constraint", "batch_create_exclude_constraint" +) +@ops.AddConstraintOp.register_add_constraint("exclude_constraint") +class CreateExcludeConstraintOp(ops.AddConstraintOp): + """Represent a create exclude constraint operation.""" + + constraint_type = "exclude" + + def __init__( + self, + constraint_name: sqla_compat._ConstraintName, + table_name: Union[str, quoted_name], + elements: Union[ + Sequence[Tuple[str, str]], + Sequence[Tuple[ColumnClause[Any], str]], + ], + where: Optional[Union[ColumnElement[bool], str]] = None, + schema: Optional[str] = None, + _orig_constraint: Optional[ExcludeConstraint] = None, + **kw, + ) -> None: + self.constraint_name = constraint_name + self.table_name = table_name + self.elements = elements + self.where = where + self.schema = schema + self._orig_constraint = _orig_constraint + self.kw = kw + + @classmethod + def from_constraint( # type:ignore[override] + cls, constraint: ExcludeConstraint + ) -> CreateExcludeConstraintOp: + constraint_table = sqla_compat._table_for_constraint(constraint) + return cls( + constraint.name, + constraint_table.name, + [ # type: ignore + (expr, op) for expr, name, op in constraint._render_exprs + ], + where=cast("ColumnElement[bool] | None", constraint.where), + schema=constraint_table.schema, + _orig_constraint=constraint, + deferrable=constraint.deferrable, + initially=constraint.initially, + using=constraint.using, + ) + + def to_constraint( + self, migration_context: Optional[MigrationContext] = None + ) -> ExcludeConstraint: + if self._orig_constraint is not None: + return self._orig_constraint + schema_obj = schemaobj.SchemaObjects(migration_context) + t = schema_obj.table(self.table_name, schema=self.schema) + excl = ExcludeConstraint( + *self.elements, + name=self.constraint_name, + where=self.where, + **self.kw, + ) + for ( + expr, + name, + oper, + ) in excl._render_exprs: + t.append_column(Column(name, NULLTYPE)) + t.append_constraint(excl) + return excl + + @classmethod + def create_exclude_constraint( + cls, + operations: Operations, + constraint_name: str, + table_name: str, + *elements: Any, + **kw: Any, + ) -> Optional[Table]: + """Issue an alter to create an EXCLUDE constraint using the + current migration context. + + .. note:: This method is Postgresql specific, and additionally + requires at least SQLAlchemy 1.0. + + e.g.:: + + from alembic import op + + op.create_exclude_constraint( + "user_excl", + "user", + ("period", "&&"), + ("group", "="), + where=("group != 'some group'"), + ) + + Note that the expressions work the same way as that of + the ``ExcludeConstraint`` object itself; if plain strings are + passed, quoting rules must be applied manually. + + :param name: Name of the constraint. + :param table_name: String name of the source table. + :param elements: exclude conditions. + :param where: SQL expression or SQL string with optional WHERE + clause. + :param deferrable: optional bool. If set, emit DEFERRABLE or + NOT DEFERRABLE when issuing DDL for this constraint. + :param initially: optional string. If set, emit INITIALLY <value> + when issuing DDL for this constraint. + :param schema: Optional schema name to operate within. + + """ + op = cls(constraint_name, table_name, elements, **kw) + return operations.invoke(op) + + @classmethod + def batch_create_exclude_constraint( + cls, + operations: BatchOperations, + constraint_name: str, + *elements: Any, + **kw: Any, + ) -> Optional[Table]: + """Issue a "create exclude constraint" instruction using the + current batch migration context. + + .. note:: This method is Postgresql specific, and additionally + requires at least SQLAlchemy 1.0. + + .. seealso:: + + :meth:`.Operations.create_exclude_constraint` + + """ + kw["schema"] = operations.impl.schema + op = cls(constraint_name, operations.impl.table_name, elements, **kw) + return operations.invoke(op) + + +@render.renderers.dispatch_for(CreateExcludeConstraintOp) +def _add_exclude_constraint( + autogen_context: AutogenContext, op: CreateExcludeConstraintOp +) -> str: + return _exclude_constraint(op.to_constraint(), autogen_context, alter=True) + + +@render._constraint_renderers.dispatch_for(ExcludeConstraint) +def _render_inline_exclude_constraint( + constraint: ExcludeConstraint, + autogen_context: AutogenContext, + namespace_metadata: MetaData, +) -> str: + rendered = render._user_defined_render( + "exclude", constraint, autogen_context + ) + if rendered is not False: + return rendered + + return _exclude_constraint(constraint, autogen_context, False) + + +def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str: + imports = autogen_context.imports + if imports is not None: + imports.add("from sqlalchemy.dialects import postgresql") + return "postgresql." + + +def _exclude_constraint( + constraint: ExcludeConstraint, + autogen_context: AutogenContext, + alter: bool, +) -> str: + opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = [] + + has_batch = autogen_context._has_batch + + if constraint.deferrable: + opts.append(("deferrable", str(constraint.deferrable))) + if constraint.initially: + opts.append(("initially", str(constraint.initially))) + if constraint.using: + opts.append(("using", str(constraint.using))) + if not has_batch and alter and constraint.table.schema: + opts.append(("schema", render._ident(constraint.table.schema))) + if not alter and constraint.name: + opts.append( + ("name", render._render_gen_name(autogen_context, constraint.name)) + ) + + def do_expr_where_opts(): + args = [ + "(%s, %r)" + % ( + _render_potential_column( + sqltext, # type:ignore[arg-type] + autogen_context, + ), + opstring, + ) + for sqltext, name, opstring in constraint._render_exprs + ] + if constraint.where is not None: + args.append( + "where=%s" + % render._render_potential_expr( + constraint.where, autogen_context + ) + ) + args.extend(["%s=%r" % (k, v) for k, v in opts]) + return args + + if alter: + args = [ + repr(render._render_gen_name(autogen_context, constraint.name)) + ] + if not has_batch: + args += [repr(render._ident(constraint.table.name))] + args.extend(do_expr_where_opts()) + return "%(prefix)screate_exclude_constraint(%(args)s)" % { + "prefix": render._alembic_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + else: + args = do_expr_where_opts() + return "%(prefix)sExcludeConstraint(%(args)s)" % { + "prefix": _postgresql_autogenerate_prefix(autogen_context), + "args": ", ".join(args), + } + + +def _render_potential_column( + value: Union[ + ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any] + ], + autogen_context: AutogenContext, +) -> str: + if isinstance(value, ColumnClause): + if value.is_literal: + # like literal_column("int8range(from, to)") in ExcludeConstraint + template = "%(prefix)sliteral_column(%(name)r)" + else: + template = "%(prefix)scolumn(%(name)r)" + + return template % { + "prefix": render._sqlalchemy_autogenerate_prefix(autogen_context), + "name": value.name, + } + else: + return render._render_potential_expr( + value, + autogen_context, + wrap_in_element=isinstance(value, (TextClause, FunctionElement)), + ) diff --git a/.venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py b/.venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py new file mode 100644 index 00000000..7c6fb20c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py @@ -0,0 +1,237 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + +from __future__ import annotations + +import re +from typing import Any +from typing import Dict +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import cast +from sqlalchemy import Computed +from sqlalchemy import JSON +from sqlalchemy import schema +from sqlalchemy import sql + +from .base import alter_table +from .base import ColumnName +from .base import format_column_name +from .base import format_table_name +from .base import RenameTable +from .impl import DefaultImpl +from .. import util +from ..util.sqla_compat import compiles + +if TYPE_CHECKING: + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.compiler import DDLCompiler + from sqlalchemy.sql.elements import Cast + from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.type_api import TypeEngine + + from ..operations.batch import BatchOperationsImpl + + +class SQLiteImpl(DefaultImpl): + __dialect__ = "sqlite" + + transactional_ddl = False + """SQLite supports transactional DDL, but pysqlite does not: + see: http://bugs.python.org/issue10740 + """ + + def requires_recreate_in_batch( + self, batch_op: BatchOperationsImpl + ) -> bool: + """Return True if the given :class:`.BatchOperationsImpl` + would need the table to be recreated and copied in order to + proceed. + + Normally, only returns True on SQLite when operations other + than add_column are present. + + """ + for op in batch_op.batch: + if op[0] == "add_column": + col = op[1][1] + if isinstance( + col.server_default, schema.DefaultClause + ) and isinstance(col.server_default.arg, sql.ClauseElement): + return True + elif ( + isinstance(col.server_default, Computed) + and col.server_default.persisted + ): + return True + elif op[0] not in ("create_index", "drop_index"): + return True + else: + return False + + def add_constraint(self, const: Constraint): + # attempt to distinguish between an + # auto-gen constraint and an explicit one + if const._create_rule is None: + raise NotImplementedError( + "No support for ALTER of constraints in SQLite dialect. " + "Please refer to the batch mode feature which allows for " + "SQLite migrations using a copy-and-move strategy." + ) + elif const._create_rule(self): + util.warn( + "Skipping unsupported ALTER for " + "creation of implicit constraint. " + "Please refer to the batch mode feature which allows for " + "SQLite migrations using a copy-and-move strategy." + ) + + def drop_constraint(self, const: Constraint): + if const._create_rule is None: + raise NotImplementedError( + "No support for ALTER of constraints in SQLite dialect. " + "Please refer to the batch mode feature which allows for " + "SQLite migrations using a copy-and-move strategy." + ) + + def compare_server_default( + self, + inspector_column: Column[Any], + metadata_column: Column[Any], + rendered_metadata_default: Optional[str], + rendered_inspector_default: Optional[str], + ) -> bool: + if rendered_metadata_default is not None: + rendered_metadata_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_metadata_default + ) + + rendered_metadata_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default + ) + + if rendered_inspector_default is not None: + rendered_inspector_default = re.sub( + r"^\((.+)\)$", r"\1", rendered_inspector_default + ) + + rendered_inspector_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default + ) + + return rendered_inspector_default != rendered_metadata_default + + def _guess_if_default_is_unparenthesized_sql_expr( + self, expr: Optional[str] + ) -> bool: + """Determine if a server default is a SQL expression or a constant. + + There are too many assertions that expect server defaults to round-trip + identically without parenthesis added so we will add parens only in + very specific cases. + + """ + if not expr: + return False + elif re.match(r"^[0-9\.]$", expr): + return False + elif re.match(r"^'.+'$", expr): + return False + elif re.match(r"^\(.+\)$", expr): + return False + else: + return True + + def autogen_column_reflect( + self, + inspector: Inspector, + table: Table, + column_info: Dict[str, Any], + ) -> None: + # SQLite expression defaults require parenthesis when sent + # as DDL + if self._guess_if_default_is_unparenthesized_sql_expr( + column_info.get("default", None) + ): + column_info["default"] = "(%s)" % (column_info["default"],) + + def render_ddl_sql_expr( + self, expr: ClauseElement, is_server_default: bool = False, **kw + ) -> str: + # SQLite expression defaults require parenthesis when sent + # as DDL + str_expr = super().render_ddl_sql_expr( + expr, is_server_default=is_server_default, **kw + ) + + if ( + is_server_default + and self._guess_if_default_is_unparenthesized_sql_expr(str_expr) + ): + str_expr = "(%s)" % (str_expr,) + return str_expr + + def cast_for_batch_migrate( + self, + existing: Column[Any], + existing_transfer: Dict[str, Union[TypeEngine, Cast]], + new_type: TypeEngine, + ) -> None: + if ( + existing.type._type_affinity is not new_type._type_affinity + and not isinstance(new_type, JSON) + ): + existing_transfer["expr"] = cast( + existing_transfer["expr"], new_type + ) + + def correct_for_autogen_constraints( + self, + conn_unique_constraints, + conn_indexes, + metadata_unique_constraints, + metadata_indexes, + ): + self._skip_functional_indexes(metadata_indexes, conn_indexes) + + +@compiles(RenameTable, "sqlite") +def visit_rename_table( + element: RenameTable, compiler: DDLCompiler, **kw +) -> str: + return "%s RENAME TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_table_name(compiler, element.new_table_name, None), + ) + + +@compiles(ColumnName, "sqlite") +def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: + return "%s RENAME COLUMN %s TO %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_column_name(compiler, element.column_name), + format_column_name(compiler, element.newname), + ) + + +# @compiles(AddColumn, 'sqlite') +# def visit_add_column(element, compiler, **kw): +# return "%s %s" % ( +# alter_table(compiler, element.table_name, element.schema), +# add_column(compiler, element.column, **kw) +# ) + + +# def add_column(compiler, column, **kw): +# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) +# need to modify SQLAlchemy so that the CHECK associated with a Boolean +# or Enum gets placed as part of the column constraints, not the Table +# see ticket 98 +# for const in column.constraints: +# text += compiler.process(AddConstraint(const)) +# return text |