aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py
blob: 74715b18a8bfd8b727ee14e8ed3d290de7169d7b (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard

from .. import util
from ..util import sqla_compat

if TYPE_CHECKING:
    from typing import Literal

    from alembic.autogenerate.api import AutogenContext
    from alembic.ddl.impl import DefaultImpl

# Either kind of schema construct that the signature classes below wrap.
CompareConstraintType = Union[Constraint, Index]

# The concrete construct type wrapped by a given _constraint_sig subclass.
_C = TypeVar("_C", bound=CompareConstraintType)

# Registry mapping a construct's ``__visit_name__`` to the _constraint_sig
# subclass that handles it.  Each subclass adds itself from its ``_register``
# classmethod (triggered by ``_constraint_sig.__init_subclass__``), and
# ``_constraint_sig.from_constraint`` looks subclasses up here.
_clsreg: Dict[str, Type[_constraint_sig]] = {}


class ComparisonResult(NamedTuple):
    """Outcome of comparing a metadata construct with its reflected twin.

    ``status`` is one of ``"equal"``, ``"different"`` or ``"skip"``;
    ``message`` carries a human-readable explanation.  Instances are meant
    to be built through the :meth:`Equal`, :meth:`Different` and
    :meth:`Skip` factory classmethods.
    """

    status: Literal["equal", "different", "skip"]
    message: str

    @property
    def is_equal(self) -> bool:
        """True when the status is ``"equal"``."""
        return "equal" == self.status

    @property
    def is_different(self) -> bool:
        """True when the status is ``"different"``."""
        return "different" == self.status

    @property
    def is_skip(self) -> bool:
        """True when the status is ``"skip"``."""
        return "skip" == self.status

    @classmethod
    def Equal(cls) -> ComparisonResult:
        """Build a result stating that the two constraints are equal."""
        return cls(status="equal", message="The two constraints are equal")

    @classmethod
    def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
        """Build a "different" result.

        ``reason`` may be a single string or a sequence of strings; several
        reasons are joined with ``", "`` into the message.
        """
        reasons = util.to_list(reason)
        return cls(status="different", message=", ".join(reasons))

    @classmethod
    def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
        """Build a "skip" result: the constraints could not be compared.

        ``reason`` may be a single string or a sequence of strings, joined
        with ``", "``.  The message is logged, but the constraints are
        otherwise treated as equal, so no migration command is generated.
        """
        reasons = util.to_list(reason)
        return cls(status="skip", message=", ".join(reasons))


class _constraint_sig(Generic[_C]):
    """Hashable comparison signature wrapping a constraint or an index.

    Concrete subclasses register themselves in ``_clsreg`` under the
    ``__visit_name__`` of the construct they handle.  Registration happens
    through :meth:`_register`, which is called automatically when the
    subclass is defined.  Instances are normally obtained via
    :meth:`from_constraint`.

    Equality and hashing use :attr:`_full_sig`, i.e. ``name`` followed by
    :attr:`unnamed`.
    """

    # the wrapped Constraint / Index
    const: _C

    # comparable values describing the construct, excluding its name
    _sig: Tuple[Any, ...]
    name: Optional[sqla_compat._ConstraintNameDefined]

    impl: DefaultImpl

    # per-subclass discriminators, read by is_index_sig() / is_fk_sig() /
    # is_uq_sig()
    _is_index: ClassVar[bool] = False
    _is_fk: ClassVar[bool] = False
    _is_uq: ClassVar[bool] = False

    # the is_metadata flag given at construction; compare_to_reflected()
    # requires it to be True on self and False on the other signature
    _is_metadata: bool

    def __init_subclass__(cls, **kw: Any) -> None:
        # Chain to Generic.__init_subclass__ first: it computes the new
        # subclass's ``__parameters__``, which a bare override would skip.
        super().__init_subclass__(**kw)
        cls._register()

    @classmethod
    def _register(cls):
        """Add ``cls`` to ``_clsreg``; subclasses must override this."""
        raise NotImplementedError()

    def __init__(
        self, is_metadata: bool, impl: DefaultImpl, const: _C
    ) -> None:
        raise NotImplementedError()

    def compare_to_reflected(
        self, other: _constraint_sig[Any]
    ) -> ComparisonResult:
        """Compare this metadata-side signature with a reflected one.

        Both signatures must share the same ``impl``.  ``self`` must be the
        metadata side and ``other`` the reflected side.  The comparison
        itself is delegated to :meth:`_compare_to_reflected`.
        """
        assert self.impl is other.impl
        assert self._is_metadata
        assert not other._is_metadata

        return self._compare_to_reflected(other)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        """Subclass hook performing the actual comparison."""
        raise NotImplementedError()

    @classmethod
    def from_constraint(
        cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
    ) -> _constraint_sig[_C]:
        """Build the registered signature subclass for ``constraint``.

        The subclass is looked up in ``_clsreg`` by the construct's
        ``__visit_name__``.  A ``KeyError`` propagates when no subclass is
        registered for it.
        """
        # these could be cached by constraint/impl, however, if the
        # constraint is modified in place, then the sig is wrong.  the mysql
        # impl currently does this, and if we fixed that we can't be sure
        # someone else might do it too, so play it safe.
        sig_cls = _clsreg[constraint.__visit_name__]
        return sig_cls(is_metadata, impl, constraint)

    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
        """Return ``sqla_compat._get_constraint_final_name`` for ``const``
        under the dialect of ``context``."""
        return sqla_compat._get_constraint_final_name(
            self.const, context.dialect
        )

    @util.memoized_property
    def is_named(self):
        """Whether ``const`` counts as named for ``impl.dialect``."""
        return sqla_compat._constraint_is_named(self.const, self.impl.dialect)

    @util.memoized_property
    def unnamed(self) -> Tuple[Any, ...]:
        """The signature without the name; defaults to ``_sig``."""
        return self._sig

    @util.memoized_property
    def unnamed_no_options(self) -> Tuple[Any, ...]:
        """Signature without name or options; subclass-specific."""
        raise NotImplementedError()

    @util.memoized_property
    def _full_sig(self) -> Tuple[Any, ...]:
        """``name`` followed by :attr:`unnamed`; basis of eq/hash."""
        return (self.name,) + self.unnamed

    def __eq__(self, other: object) -> bool:
        # Return NotImplemented for foreign objects (rather than raising
        # AttributeError) so ``sig == None`` and mixed-type container
        # lookups behave like ordinary Python comparisons.
        if not isinstance(other, _constraint_sig):
            return NotImplemented
        return self._full_sig == other._full_sig

    def __ne__(self, other: object) -> bool:
        if not isinstance(other, _constraint_sig):
            return NotImplemented
        return self._full_sig != other._full_sig

    def __hash__(self) -> int:
        return hash(self._full_sig)


class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
    """Signature for a :class:`.UniqueConstraint`.

    The comparable ``_sig`` is the sorted tuple of the constraint's column
    names, so column order does not influence equality.
    """

    _is_uq = True
    is_unique = True

    @classmethod
    def _register(cls) -> None:
        # called by _constraint_sig.__init_subclass__ when this class is made
        _clsreg["unique_constraint"] = cls

    def __init__(
        self,
        is_metadata: bool,
        impl: DefaultImpl,
        const: UniqueConstraint,
    ) -> None:
        self.impl = impl
        self.const = const
        self.name = sqla_compat.constraint_name_or_none(const.name)
        member_names = [column.name for column in const.columns]
        member_names.sort()
        self._sig = tuple(member_names)
        self._is_metadata = is_metadata

    @property
    def column_names(self) -> Tuple[str, ...]:
        """Column names in the order ``const.columns`` yields them."""
        return tuple(column.name for column in self.const.columns)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        """Delegate to ``impl.compare_unique_constraint``."""
        assert self._is_metadata
        reflected = other
        assert is_uq_sig(reflected)
        return self.impl.compare_unique_constraint(
            self.const, reflected.const
        )


class _ix_constraint_sig(_constraint_sig[Index]):
    """Signature for an :class:`.Index`.

    No ``_sig`` is stored.  :attr:`unnamed` is built from ``is_unique``
    followed by the ``name`` of each entry in ``const.expressions``, or
    ``None`` for entries without one.  Index signatures always report
    themselves as named.
    """

    _is_index = True

    name: sqla_compat._ConstraintName

    @classmethod
    def _register(cls) -> None:
        # called by _constraint_sig.__init_subclass__ when this class is made
        _clsreg["index"] = cls

    def __init__(
        self, is_metadata: bool, impl: DefaultImpl, const: Index
    ) -> None:
        self._is_metadata = is_metadata
        self.impl = impl
        self.const = const
        self.name = const.name
        self.is_unique = bool(const.unique)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        """Delegate to ``impl.compare_indexes``."""
        assert self._is_metadata
        reflected = other
        assert is_index_sig(reflected)
        return self.impl.compare_indexes(self.const, reflected.const)

    @util.memoized_property
    def has_expressions(self):
        """Result of ``sqla_compat.is_expression_index`` for ``const``."""
        return sqla_compat.is_expression_index(self.const)

    @util.memoized_property
    def column_names(self) -> Tuple[str, ...]:
        """Names of the entries in ``const.columns``."""
        return tuple(column.name for column in self.const.columns)

    @util.memoized_property
    def column_names_optional(self) -> Tuple[Optional[str], ...]:
        """``name`` of each expression in ``const.expressions``, else None."""
        return tuple(
            getattr(expr, "name", None) for expr in self.const.expressions
        )

    @util.memoized_property
    def is_named(self):
        return True

    @util.memoized_property
    def unnamed(self):
        return (self.is_unique, *self.column_names_optional)


class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
    """Signature for a :class:`.ForeignKeyConstraint`.

    ``_sig`` holds, in order:

    * source schema, table and columns;
    * target schema, table and columns;
    * the ON UPDATE and ON DELETE actions, normalised to lower case, with
      ``"no action"`` and empty values mapped to ``None``;
    * a three-state deferrability marker.

    :attr:`unnamed_no_options` holds only the source/target portion.
    """

    _is_fk = True

    @classmethod
    def _register(cls) -> None:
        # called by _constraint_sig.__init_subclass__ when this class is made
        _clsreg["foreign_key_constraint"] = cls

    def __init__(
        self,
        is_metadata: bool,
        impl: DefaultImpl,
        const: ForeignKeyConstraint,
    ) -> None:
        def _normalize_action(action):
            # lower-case an ON UPDATE / ON DELETE action; a missing or empty
            # value and "no action" both become None
            if not action:
                return None
            lowered = action.lower()
            return None if lowered == "no action" else lowered

        self._is_metadata = is_metadata

        self.impl = impl
        self.const = const

        self.name = sqla_compat.constraint_name_or_none(const.name)

        (
            self.source_schema,
            self.source_table,
            self.source_columns,
            self.target_schema,
            self.target_table,
            self.target_columns,
            onupdate,
            ondelete,
            deferrable,
            initially,
        ) = sqla_compat._fk_spec(const)

        endpoints = (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
        )
        update_action = _normalize_action(onupdate)
        delete_action = _normalize_action(ondelete)

        # fold INITIALLY + DEFERRABLE into a single three-state marker
        if initially and initially.lower() == "deferred":
            deferral = "initially_deferrable"
        elif deferrable:
            deferral = "deferrable"
        else:
            deferral = "not deferrable"

        self._sig = endpoints + (update_action, delete_action, deferral)

    @util.memoized_property
    def unnamed_no_options(self):
        """Source and target schema/table/columns, without the options."""
        source_part = (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
        )
        target_part = (
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
        )
        return source_part + target_part


def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
    """Narrow ``sig`` to an index signature using its ``_is_index`` flag."""
    flag = sig._is_index
    return flag


def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
    """Narrow ``sig`` to a unique-constraint signature via ``_is_uq``."""
    flag = sig._is_uq
    return flag


def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
    """Narrow ``sig`` to a foreign-key signature via ``_is_fk``."""
    flag = sig._is_fk
    return flag