about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py')
-rw-r--r--.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py497
1 files changed, 497 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py b/.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py
new file mode 100644
index 00000000..a427d3c8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py
@@ -0,0 +1,497 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
+from __future__ import annotations
+
+import contextlib
+import re
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Protocol
+from typing import Set
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
+from sqlalchemy import __version__
+from sqlalchemy import schema
+from sqlalchemy import sql
+from sqlalchemy import types as sqltypes
+from sqlalchemy.schema import CheckConstraint
+from sqlalchemy.schema import Column
+from sqlalchemy.schema import ForeignKeyConstraint
+from sqlalchemy.sql import visitors
+from sqlalchemy.sql.base import DialectKWArgs
+from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
+from sqlalchemy.sql.visitors import traverse
+from typing_extensions import TypeGuard
+
+if True:
+    from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME  # type: ignore[attr-defined] # noqa: E501
+
+if TYPE_CHECKING:
+    from sqlalchemy import ClauseElement
+    from sqlalchemy import Identity
+    from sqlalchemy import Index
+    from sqlalchemy import Table
+    from sqlalchemy.engine import Connection
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.engine import Transaction
+    from sqlalchemy.sql.base import ColumnCollection
+    from sqlalchemy.sql.compiler import SQLCompiler
+    from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.schema import SchemaItem
+
+_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
+
+
+class _CompilerProtocol(Protocol):
+    def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ...
+
+
+def _safe_int(value: str) -> Union[int, str]:
+    try:
+        return int(value)
+    except:
+        return value
+
+
+_vers = tuple(
+    [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
+)
+# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
+sqla_14_18 = _vers >= (1, 4, 18)
+sqla_14_26 = _vers >= (1, 4, 26)
+sqla_2 = _vers >= (2,)
+sqlalchemy_version = __version__
+
+if TYPE_CHECKING:
+
+    def compiles(
+        element: Type[ClauseElement], *dialects: str
+    ) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ...
+
+else:
+    from sqlalchemy.ext.compiler import compiles
+
+
+identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs)
+
+
+def _get_identity_options_dict(
+    identity: Union[Identity, schema.Sequence, None],
+    dialect_kwargs: bool = False,
+) -> Dict[str, Any]:
+    if identity is None:
+        return {}
+    elif identity_has_dialect_kwargs:
+        assert hasattr(identity, "_as_dict")
+        as_dict = identity._as_dict()
+        if dialect_kwargs:
+            assert isinstance(identity, DialectKWArgs)
+            as_dict.update(identity.dialect_kwargs)
+    else:
+        as_dict = {}
+        if isinstance(identity, schema.Identity):
+            # always=None means something different than always=False
+            as_dict["always"] = identity.always
+            if identity.on_null is not None:
+                as_dict["on_null"] = identity.on_null
+        # attributes common to Identity and Sequence
+        attrs = (
+            "start",
+            "increment",
+            "minvalue",
+            "maxvalue",
+            "nominvalue",
+            "nomaxvalue",
+            "cycle",
+            "cache",
+            "order",
+        )
+        as_dict.update(
+            {
+                key: getattr(identity, key, None)
+                for key in attrs
+                if getattr(identity, key, None) is not None
+            }
+        )
+    return as_dict
+
+
+if sqla_2:
+    from sqlalchemy.sql.base import _NoneName
+else:
+    from sqlalchemy.util import symbol as _NoneName  # type: ignore[assignment]
+
+
+_ConstraintName = Union[None, str, _NoneName]
+_ConstraintNameDefined = Union[str, _NoneName]
+
+
+def constraint_name_defined(
+    name: _ConstraintName,
+) -> TypeGuard[_ConstraintNameDefined]:
+    return name is _NONE_NAME or isinstance(name, (str, _NoneName))
+
+
+def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]:
+    return isinstance(name, str)
+
+
+def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
+    return name if constraint_name_string(name) else None
+
+
+AUTOINCREMENT_DEFAULT = "auto"
+
+
+@contextlib.contextmanager
+def _ensure_scope_for_ddl(
+    connection: Optional[Connection],
+) -> Iterator[None]:
+    try:
+        in_transaction = connection.in_transaction  # type: ignore[union-attr]
+    except AttributeError:
+        # catch for MockConnection, None
+        in_transaction = None
+        pass
+
+    # yield outside the catch
+    if in_transaction is None:
+        yield
+    else:
+        if not in_transaction():
+            assert connection is not None
+            with connection.begin():
+                yield
+        else:
+            yield
+
+
+def _safe_begin_connection_transaction(
+    connection: Connection,
+) -> Transaction:
+    transaction = connection.get_transaction()
+    if transaction:
+        return transaction
+    else:
+        return connection.begin()
+
+
+def _safe_commit_connection_transaction(
+    connection: Connection,
+) -> None:
+    transaction = connection.get_transaction()
+    if transaction:
+        transaction.commit()
+
+
+def _safe_rollback_connection_transaction(
+    connection: Connection,
+) -> None:
+    transaction = connection.get_transaction()
+    if transaction:
+        transaction.rollback()
+
+
+def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
+    try:
+        in_transaction = connection.in_transaction  # type: ignore
+    except AttributeError:
+        # catch for MockConnection
+        return False
+    else:
+        return in_transaction()
+
+
+def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
+    return idx.expressions  # type: ignore
+
+
+def _copy(schema_item: _CE, **kw) -> _CE:
+    if hasattr(schema_item, "_copy"):
+        return schema_item._copy(**kw)
+    else:
+        return schema_item.copy(**kw)  # type: ignore[union-attr]
+
+
+def _connectable_has_table(
+    connectable: Connection, tablename: str, schemaname: Union[str, None]
+) -> bool:
+    return connectable.dialect.has_table(connectable, tablename, schemaname)
+
+
+def _exec_on_inspector(inspector, statement, **params):
+    with inspector._operation_context() as conn:
+        return conn.execute(statement, params)
+
+
+def _nullability_might_be_unset(metadata_column):
+    from sqlalchemy.sql import schema
+
+    return metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
+
+
+def _server_default_is_computed(*server_default) -> bool:
+    return any(isinstance(sd, schema.Computed) for sd in server_default)
+
+
+def _server_default_is_identity(*server_default) -> bool:
+    return any(isinstance(sd, schema.Identity) for sd in server_default)
+
+
+def _table_for_constraint(constraint: Constraint) -> Table:
+    if isinstance(constraint, ForeignKeyConstraint):
+        table = constraint.parent
+        assert table is not None
+        return table  # type: ignore[return-value]
+    else:
+        return constraint.table
+
+
+def _columns_for_constraint(constraint):
+    if isinstance(constraint, ForeignKeyConstraint):
+        return [fk.parent for fk in constraint.elements]
+    elif isinstance(constraint, CheckConstraint):
+        return _find_columns(constraint.sqltext)
+    else:
+        return list(constraint.columns)
+
+
+def _resolve_for_variant(type_, dialect):
+    if _type_has_variants(type_):
+        base_type, mapping = _get_variant_mapping(type_)
+        return mapping.get(dialect.name, base_type)
+    else:
+        return type_
+
+
+if hasattr(sqltypes.TypeEngine, "_variant_mapping"):  # 2.0
+
+    def _type_has_variants(type_):
+        return bool(type_._variant_mapping)
+
+    def _get_variant_mapping(type_):
+        return type_, type_._variant_mapping
+
+else:
+
+    def _type_has_variants(type_):
+        return type(type_) is sqltypes.Variant
+
+    def _get_variant_mapping(type_):
+        return type_.impl, type_.mapping
+
+
+def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
+    if TYPE_CHECKING:
+        assert constraint.columns is not None
+        assert constraint.elements is not None
+        assert isinstance(constraint.parent, Table)
+
+    source_columns = [
+        constraint.columns[key].name for key in constraint.column_keys
+    ]
+
+    source_table = constraint.parent.name
+    source_schema = constraint.parent.schema
+    target_schema = constraint.elements[0].column.table.schema
+    target_table = constraint.elements[0].column.table.name
+    target_columns = [element.column.name for element in constraint.elements]
+    ondelete = constraint.ondelete
+    onupdate = constraint.onupdate
+    deferrable = constraint.deferrable
+    initially = constraint.initially
+    return (
+        source_schema,
+        source_table,
+        source_columns,
+        target_schema,
+        target_table,
+        target_columns,
+        onupdate,
+        ondelete,
+        deferrable,
+        initially,
+    )
+
+
+def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
+    spec = constraint.elements[0]._get_colspec()
+    tokens = spec.split(".")
+    tokens.pop(-1)  # colname
+    tablekey = ".".join(tokens)
+    assert constraint.parent is not None
+    return tablekey == constraint.parent.key
+
+
+def _is_type_bound(constraint: Constraint) -> bool:
+    # this deals with SQLAlchemy #3260, don't copy CHECK constraints
+    # that will be generated by the type.
+    # new feature added for #3260
+    return constraint._type_bound
+
+
+def _find_columns(clause):
+    """locate Column objects within the given expression."""
+
+    cols: Set[ColumnElement[Any]] = set()
+    traverse(clause, {}, {"column": cols.add})
+    return cols
+
+
+def _remove_column_from_collection(
+    collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]]
+) -> None:
+    """remove a column from a ColumnCollection."""
+
+    # workaround for older SQLAlchemy, remove the
+    # same object that's present
+    assert column.key is not None
+    to_remove = collection[column.key]
+
+    # SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
+    # (renamed from ImmutableColumnCollection)
+    if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
+        collection._parent.remove(to_remove)
+    else:
+        collection.remove(to_remove)
+
+
+def _textual_index_column(
+    table: Table, text_: Union[str, TextClause, ColumnElement[Any]]
+) -> Union[ColumnElement[Any], Column[Any]]:
+    """a workaround for the Index construct's severe lack of flexibility"""
+    if isinstance(text_, str):
+        c = Column(text_, sqltypes.NULLTYPE)
+        table.append_column(c)
+        return c
+    elif isinstance(text_, TextClause):
+        return _textual_index_element(table, text_)
+    elif isinstance(text_, _textual_index_element):
+        return _textual_index_column(table, text_.text)
+    elif isinstance(text_, sql.ColumnElement):
+        return _copy_expression(text_, table)
+    else:
+        raise ValueError("String or text() construct expected")
+
+
+def _copy_expression(expression: _CE, target_table: Table) -> _CE:
+    def replace(col):
+        if (
+            isinstance(col, Column)
+            and col.table is not None
+            and col.table is not target_table
+        ):
+            if col.name in target_table.c:
+                return target_table.c[col.name]
+            else:
+                c = _copy(col)
+                target_table.append_column(c)
+                return c
+        else:
+            return None
+
+    return visitors.replacement_traverse(  # type: ignore[call-overload]
+        expression, {}, replace
+    )
+
+
+class _textual_index_element(sql.ColumnElement):
+    """Wrap around a sqlalchemy text() construct in such a way that
+    we appear like a column-oriented SQL expression to an Index
+    construct.
+
+    The issue here is that currently the Postgresql dialect, the biggest
+    recipient of functional indexes, keys all the index expressions to
+    the corresponding column expressions when rendering CREATE INDEX,
+    so the Index we create here needs to have a .columns collection that
+    is the same length as the .expressions collection.  Ultimately
+    SQLAlchemy should support text() expressions in indexes.
+
+    See SQLAlchemy issue 3174.
+
+    """
+
+    __visit_name__ = "_textual_idx_element"
+
+    def __init__(self, table: Table, text: TextClause) -> None:
+        self.table = table
+        self.text = text
+        self.key = text.text
+        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
+        table.append_column(self.fake_column)
+
+    def get_children(self, **kw):
+        return [self.fake_column]
+
+
+@compiles(_textual_index_element)
+def _render_textual_index_column(
+    element: _textual_index_element, compiler: SQLCompiler, **kw
+) -> str:
+    return compiler.process(element.text, **kw)
+
+
+class _literal_bindparam(BindParameter):
+    pass
+
+
+@compiles(_literal_bindparam)
+def _render_literal_bindparam(
+    element: _literal_bindparam, compiler: SQLCompiler, **kw
+) -> str:
+    return compiler.render_literal_bindparam(element, **kw)
+
+
+def _get_constraint_final_name(
+    constraint: Union[Index, Constraint], dialect: Optional[Dialect]
+) -> Optional[str]:
+    if constraint.name is None:
+        return None
+    assert dialect is not None
+    # for SQLAlchemy 1.4 we would like to have the option to expand
+    # the use of "deferred" names for constraints as well as to have
+    # some flexibility with "None" name and similar; make use of new
+    # SQLAlchemy API to return what would be the final compiled form of
+    # the name for this dialect.
+    return dialect.identifier_preparer.format_constraint(
+        constraint, _alembic_quote=False
+    )
+
+
+def _constraint_is_named(
+    constraint: Union[Constraint, Index], dialect: Optional[Dialect]
+) -> bool:
+    if constraint.name is None:
+        return False
+    assert dialect is not None
+    name = dialect.identifier_preparer.format_constraint(
+        constraint, _alembic_quote=False
+    )
+    return name is not None
+
+
+def is_expression_index(index: Index) -> bool:
+    for expr in index.expressions:
+        if is_expression(expr):
+            return True
+    return False
+
+
+def is_expression(expr: Any) -> bool:
+    while isinstance(expr, UnaryExpression):
+        expr = expr.element
+    if not isinstance(expr, ColumnClause) or expr.is_literal:
+        return True
+    return False