about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/sql.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/sql.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/sql.py503
1 files changed, 503 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/sql.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/sql.py
new file mode 100644
index 00000000..44cf21c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/fixtures/sql.py
@@ -0,0 +1,503 @@
+# testing/fixtures/sql.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from __future__ import annotations
+
+import itertools
+import random
+import re
+import sys
+
+import sqlalchemy as sa
+from .base import TestBase
+from .. import config
+from .. import mock
+from ..assertions import eq_
+from ..assertions import ne_
+from ..util import adict
+from ..util import drop_all_tables_from_metadata
+from ... import event
+from ... import util
+from ...schema import sort_tables_and_constraints
+from ...sql import visitors
+from ...sql.elements import ClauseElement
+
+
+class TablesTest(TestBase):
+    # 'once', None
+    run_setup_bind = "once"
+
+    # 'once', 'each', None
+    run_define_tables = "once"
+
+    # 'once', 'each', None
+    run_create_tables = "once"
+
+    # 'once', 'each', None
+    run_inserts = "each"
+
+    # 'each', None
+    run_deletes = "each"
+
+    # 'once', None
+    run_dispose_bind = None
+
+    bind = None
+    _tables_metadata = None
+    tables = None
+    other = None
+    sequences = None
+
+    @config.fixture(autouse=True, scope="class")
+    def _setup_tables_test_class(self):
+        cls = self.__class__
+        cls._init_class()
+
+        cls._setup_once_tables()
+
+        cls._setup_once_inserts()
+
+        yield
+
+        cls._teardown_once_metadata_bind()
+
+    @config.fixture(autouse=True, scope="function")
+    def _setup_tables_test_instance(self):
+        self._setup_each_tables()
+        self._setup_each_inserts()
+
+        yield
+
+        self._teardown_each_tables()
+
+    @property
+    def tables_test_metadata(self):
+        return self._tables_metadata
+
+    @classmethod
+    def _init_class(cls):
+        if cls.run_define_tables == "each":
+            if cls.run_create_tables == "once":
+                cls.run_create_tables = "each"
+            assert cls.run_inserts in ("each", None)
+
+        cls.other = adict()
+        cls.tables = adict()
+        cls.sequences = adict()
+
+        cls.bind = cls.setup_bind()
+        cls._tables_metadata = sa.MetaData()
+
+    @classmethod
+    def _setup_once_inserts(cls):
+        if cls.run_inserts == "once":
+            cls._load_fixtures()
+            with cls.bind.begin() as conn:
+                cls.insert_data(conn)
+
+    @classmethod
+    def _setup_once_tables(cls):
+        if cls.run_define_tables == "once":
+            cls.define_tables(cls._tables_metadata)
+            if cls.run_create_tables == "once":
+                cls._tables_metadata.create_all(cls.bind)
+            cls.tables.update(cls._tables_metadata.tables)
+            cls.sequences.update(cls._tables_metadata._sequences)
+
+    def _setup_each_tables(self):
+        if self.run_define_tables == "each":
+            self.define_tables(self._tables_metadata)
+            if self.run_create_tables == "each":
+                self._tables_metadata.create_all(self.bind)
+            self.tables.update(self._tables_metadata.tables)
+            self.sequences.update(self._tables_metadata._sequences)
+        elif self.run_create_tables == "each":
+            self._tables_metadata.create_all(self.bind)
+
+    def _setup_each_inserts(self):
+        if self.run_inserts == "each":
+            self._load_fixtures()
+            with self.bind.begin() as conn:
+                self.insert_data(conn)
+
+    def _teardown_each_tables(self):
+        if self.run_define_tables == "each":
+            self.tables.clear()
+            if self.run_create_tables == "each":
+                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+            self._tables_metadata.clear()
+        elif self.run_create_tables == "each":
+            drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+        savepoints = getattr(config.requirements, "savepoints", False)
+        if savepoints:
+            savepoints = savepoints.enabled
+
+        # no need to run deletes if tables are recreated on setup
+        if (
+            self.run_define_tables != "each"
+            and self.run_create_tables != "each"
+            and self.run_deletes == "each"
+        ):
+            with self.bind.begin() as conn:
+                for table in reversed(
+                    [
+                        t
+                        for (t, fks) in sort_tables_and_constraints(
+                            self._tables_metadata.tables.values()
+                        )
+                        if t is not None
+                    ]
+                ):
+                    try:
+                        if savepoints:
+                            with conn.begin_nested():
+                                conn.execute(table.delete())
+                        else:
+                            conn.execute(table.delete())
+                    except sa.exc.DBAPIError as ex:
+                        print(
+                            ("Error emptying table %s: %r" % (table, ex)),
+                            file=sys.stderr,
+                        )
+
+    @classmethod
+    def _teardown_once_metadata_bind(cls):
+        if cls.run_create_tables:
+            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+        if cls.run_dispose_bind == "once":
+            cls.dispose_bind(cls.bind)
+
+        cls._tables_metadata.bind = None
+
+        if cls.run_setup_bind is not None:
+            cls.bind = None
+
+    @classmethod
+    def setup_bind(cls):
+        return config.db
+
+    @classmethod
+    def dispose_bind(cls, bind):
+        if hasattr(bind, "dispose"):
+            bind.dispose()
+        elif hasattr(bind, "close"):
+            bind.close()
+
+    @classmethod
+    def define_tables(cls, metadata):
+        pass
+
+    @classmethod
+    def fixtures(cls):
+        return {}
+
+    @classmethod
+    def insert_data(cls, connection):
+        pass
+
+    def sql_count_(self, count, fn):
+        self.assert_sql_count(self.bind, fn, count)
+
+    def sql_eq_(self, callable_, statements):
+        self.assert_sql(self.bind, callable_, statements)
+
+    @classmethod
+    def _load_fixtures(cls):
+        """Insert rows as represented by the fixtures() method."""
+        headers, rows = {}, {}
+        for table, data in cls.fixtures().items():
+            if len(data) < 2:
+                continue
+            if isinstance(table, str):
+                table = cls.tables[table]
+            headers[table] = data[0]
+            rows[table] = data[1:]
+        for table, fks in sort_tables_and_constraints(
+            cls._tables_metadata.tables.values()
+        ):
+            if table is None:
+                continue
+            if table not in headers:
+                continue
+            with cls.bind.begin() as conn:
+                conn.execute(
+                    table.insert(),
+                    [
+                        dict(zip(headers[table], column_values))
+                        for column_values in rows[table]
+                    ],
+                )
+
+
+class NoCache:
+    @config.fixture(autouse=True, scope="function")
+    def _disable_cache(self):
+        _cache = config.db._compiled_cache
+        config.db._compiled_cache = None
+        yield
+        config.db._compiled_cache = _cache
+
+
+class RemovesEvents:
+    @util.memoized_property
+    def _event_fns(self):
+        return set()
+
+    def event_listen(self, target, name, fn, **kw):
+        self._event_fns.add((target, name, fn))
+        event.listen(target, name, fn, **kw)
+
+    @config.fixture(autouse=True, scope="function")
+    def _remove_events(self):
+        yield
+        for key in self._event_fns:
+            event.remove(*key)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+    run_inserts = run_deletes = None
+
+    __backend__ = True
+    __requires__ = ("computed_columns", "table_reflection")
+
+    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
+    def normalize(self, text):
+        return self.regexp.sub("", text).lower()
+
+    @classmethod
+    def define_tables(cls, metadata):
+        from ... import Integer
+        from ... import testing
+        from ...schema import Column
+        from ...schema import Computed
+        from ...schema import Table
+
+        Table(
+            "computed_default_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("normal", Integer),
+            Column("computed_col", Integer, Computed("normal + 42")),
+            Column("with_default", Integer, server_default="42"),
+        )
+
+        t = Table(
+            "computed_column_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("normal", Integer),
+            Column("computed_no_flag", Integer, Computed("normal + 42")),
+        )
+
+        if testing.requires.schemas.enabled:
+            t2 = Table(
+                "computed_column_table",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("normal", Integer),
+                Column("computed_no_flag", Integer, Computed("normal / 42")),
+                schema=config.test_schema,
+            )
+
+        if testing.requires.computed_columns_virtual.enabled:
+            t.append_column(
+                Column(
+                    "computed_virtual",
+                    Integer,
+                    Computed("normal + 2", persisted=False),
+                )
+            )
+            if testing.requires.schemas.enabled:
+                t2.append_column(
+                    Column(
+                        "computed_virtual",
+                        Integer,
+                        Computed("normal / 2", persisted=False),
+                    )
+                )
+        if testing.requires.computed_columns_stored.enabled:
+            t.append_column(
+                Column(
+                    "computed_stored",
+                    Integer,
+                    Computed("normal - 42", persisted=True),
+                )
+            )
+            if testing.requires.schemas.enabled:
+                t2.append_column(
+                    Column(
+                        "computed_stored",
+                        Integer,
+                        Computed("normal * 42", persisted=True),
+                    )
+                )
+
+
+class CacheKeyFixture:
+    def _compare_equal(self, a, b, compare_values):
+        a_key = a._generate_cache_key()
+        b_key = b._generate_cache_key()
+
+        if a_key is None:
+            assert a._annotations.get("nocache")
+
+            assert b_key is None
+        else:
+            eq_(a_key.key, b_key.key)
+            eq_(hash(a_key.key), hash(b_key.key))
+
+            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+                assert a_param.compare(b_param, compare_values=compare_values)
+        return a_key, b_key
+
+    def _run_cache_key_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            if a == b:
+                a_key, b_key = self._compare_equal(
+                    case_a[a], case_b[b], compare_values
+                )
+                if a_key is None:
+                    continue
+            else:
+                a_key = case_a[a]._generate_cache_key()
+                b_key = case_b[b]._generate_cache_key()
+
+                if a_key is None or b_key is None:
+                    if a_key is None:
+                        assert case_a[a]._annotations.get("nocache")
+                    if b_key is None:
+                        assert case_b[b]._annotations.get("nocache")
+                    continue
+
+                if a_key.key == b_key.key:
+                    for a_param, b_param in zip(
+                        a_key.bindparams, b_key.bindparams
+                    ):
+                        if not a_param.compare(
+                            b_param, compare_values=compare_values
+                        ):
+                            break
+                    else:
+                        # this fails unconditionally since we could not
+                        # find bound parameter values that differed.
+                        # Usually we intended to get two distinct keys here
+                        # so the failure will be more descriptive using the
+                        # ne_() assertion.
+                        ne_(a_key.key, b_key.key)
+                else:
+                    ne_(a_key.key, b_key.key)
+
+            # ClauseElement-specific test to ensure the cache key
+            # collected all the bound parameters that aren't marked
+            # as "literal execute"
+            if isinstance(case_a[a], ClauseElement) and isinstance(
+                case_b[b], ClauseElement
+            ):
+                assert_a_params = []
+                assert_b_params = []
+
+                for elem in visitors.iterate(case_a[a]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_a_params.append(elem)
+
+                for elem in visitors.iterate(case_b[b]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_b_params.append(elem)
+
+                # note we're asserting the order of the params as well as
+                # if there are dupes or not.  ordering has to be
+                # deterministic and matches what a traversal would provide.
+                eq_(
+                    sorted(a_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_a_params), key=lambda b: b.key
+                    ),
+                )
+                eq_(
+                    sorted(b_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_b_params), key=lambda b: b.key
+                    ),
+                )
+
+    def _run_cache_key_equal_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            self._compare_equal(case_a[a], case_b[b], compare_values)
+
+
+def insertmanyvalues_fixture(
+    connection, randomize_rows=False, warn_on_downgraded=False
+):
+    dialect = connection.dialect
+    orig_dialect = dialect._deliver_insertmanyvalues_batches
+    orig_conn = connection._exec_insertmany_context
+
+    class RandomCursor:
+        __slots__ = ("cursor",)
+
+        def __init__(self, cursor):
+            self.cursor = cursor
+
+        # only this method is called by the deliver method.
+        # by not having the other methods we assert that those aren't being
+        # used
+
+        @property
+        def description(self):
+            return self.cursor.description
+
+        def fetchall(self):
+            rows = self.cursor.fetchall()
+            rows = list(rows)
+            random.shuffle(rows)
+            return rows
+
+    def _deliver_insertmanyvalues_batches(
+        connection,
+        cursor,
+        statement,
+        parameters,
+        generic_setinputsizes,
+        context,
+    ):
+        if randomize_rows:
+            cursor = RandomCursor(cursor)
+        for batch in orig_dialect(
+            connection,
+            cursor,
+            statement,
+            parameters,
+            generic_setinputsizes,
+            context,
+        ):
+            if warn_on_downgraded and batch.is_downgraded:
+                util.warn("Batches were downgraded for sorted INSERT")
+
+            yield batch
+
+    def _exec_insertmany_context(dialect, context):
+        with mock.patch.object(
+            dialect,
+            "_deliver_insertmanyvalues_batches",
+            new=_deliver_insertmanyvalues_batches,
+        ):
+            return orig_conn(dialect, context)
+
+    connection._exec_insertmany_context = _exec_insertmany_context