about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py')
-rw-r--r--.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py306
1 files changed, 306 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py b/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py
new file mode 100644
index 00000000..17d732f8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py
@@ -0,0 +1,306 @@
+from __future__ import annotations
+
+import configparser
+from contextlib import contextmanager
+import io
+import re
+from typing import Any
+from typing import Dict
+
+from sqlalchemy import Column
+from sqlalchemy import create_mock_engine
+from sqlalchemy import inspect
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
+from sqlalchemy.testing import config
+from sqlalchemy.testing import mock
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.fixtures import FutureEngineMixin
+from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
+from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
+
+import alembic
+from .assertions import _get_dialect
+from ..environment import EnvironmentContext
+from ..migration import MigrationContext
+from ..operations import Operations
+from ..util import sqla_compat
+from ..util.sqla_compat import sqla_2
+
+
+testing_config = configparser.ConfigParser()
+testing_config.read(["test.cfg"])
+
+
+class TestBase(SQLAlchemyTestBase):
+    is_sqlalchemy_future = sqla_2
+
+    @testing.fixture()
+    def ops_context(self, migration_context):
+        with migration_context.begin_transaction(_per_migration=True):
+            yield Operations(migration_context)
+
+    @testing.fixture
+    def migration_context(self, connection):
+        return MigrationContext.configure(
+            connection, opts=dict(transaction_per_migration=True)
+        )
+
+    @testing.fixture
+    def as_sql_migration_context(self, connection):
+        return MigrationContext.configure(
+            connection, opts=dict(transaction_per_migration=True, as_sql=True)
+        )
+
+    @testing.fixture
+    def connection(self):
+        with config.db.connect() as conn:
+            yield conn
+
+
+class TablesTest(TestBase, SQLAlchemyTablesTest):
+    pass
+
+
+FutureEngineMixin.is_sqlalchemy_future = True
+
+
+def capture_db(dialect="postgresql://"):
+    buf = []
+
+    def dump(sql, *multiparams, **params):
+        buf.append(str(sql.compile(dialect=engine.dialect)))
+
+    engine = create_mock_engine(dialect, dump)
+    return engine, buf
+
+
+_engs: Dict[Any, Any] = {}
+
+
+@contextmanager
+def capture_context_buffer(**kw):
+    if kw.pop("bytes_io", False):
+        buf = io.BytesIO()
+    else:
+        buf = io.StringIO()
+
+    kw.update({"dialect_name": "sqlite", "output_buffer": buf})
+    conf = EnvironmentContext.configure
+
+    def configure(*arg, **opt):
+        opt.update(**kw)
+        return conf(*arg, **opt)
+
+    with mock.patch.object(EnvironmentContext, "configure", configure):
+        yield buf
+
+
+@contextmanager
+def capture_engine_context_buffer(**kw):
+    from .env import _sqlite_file_db
+    from sqlalchemy import event
+
+    buf = io.StringIO()
+
+    eng = _sqlite_file_db()
+
+    conn = eng.connect()
+
+    @event.listens_for(conn, "before_cursor_execute")
+    def bce(conn, cursor, statement, parameters, context, executemany):
+        buf.write(statement + "\n")
+
+    kw.update({"connection": conn})
+    conf = EnvironmentContext.configure
+
+    def configure(*arg, **opt):
+        opt.update(**kw)
+        return conf(*arg, **opt)
+
+    with mock.patch.object(EnvironmentContext, "configure", configure):
+        yield buf
+
+
+def op_fixture(
+    dialect="default",
+    as_sql=False,
+    naming_convention=None,
+    literal_binds=False,
+    native_boolean=None,
+):
+    opts = {}
+    if naming_convention:
+        opts["target_metadata"] = MetaData(naming_convention=naming_convention)
+
+    class buffer_:
+        def __init__(self):
+            self.lines = []
+
+        def write(self, msg):
+            msg = msg.strip()
+            msg = re.sub(r"[\n\t]", "", msg)
+            if as_sql:
+                # the impl produces soft tabs,
+                # so search for blocks of 4 spaces
+                msg = re.sub(r"    ", "", msg)
+                msg = re.sub(r"\;\n*$", "", msg)
+
+            self.lines.append(msg)
+
+        def flush(self):
+            pass
+
+    buf = buffer_()
+
+    class ctx(MigrationContext):
+        def get_buf(self):
+            return buf
+
+        def clear_assertions(self):
+            buf.lines[:] = []
+
+        def assert_(self, *sql):
+            # TODO: make this more flexible about
+            # whitespace and such
+            eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql])
+
+        def assert_contains(self, sql):
+            for stmt in buf.lines:
+                if re.sub(r"[\n\t]", "", sql) in stmt:
+                    return
+            else:
+                assert False, "Could not locate fragment %r in %r" % (
+                    sql,
+                    buf.lines,
+                )
+
+    if as_sql:
+        opts["as_sql"] = as_sql
+    if literal_binds:
+        opts["literal_binds"] = literal_binds
+
+    ctx_dialect = _get_dialect(dialect)
+    if native_boolean is not None:
+        ctx_dialect.supports_native_boolean = native_boolean
+        # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
+        # which breaks assumptions in the alembic test suite
+        ctx_dialect.non_native_boolean_check_constraint = True
+    if not as_sql:
+
+        def execute(stmt, *multiparam, **param):
+            if isinstance(stmt, str):
+                stmt = text(stmt)
+            assert stmt.supports_execution
+            sql = str(stmt.compile(dialect=ctx_dialect))
+
+            buf.write(sql)
+
+        connection = mock.Mock(dialect=ctx_dialect, execute=execute)
+    else:
+        opts["output_buffer"] = buf
+        connection = None
+    context = ctx(ctx_dialect, connection, opts)
+
+    alembic.op._proxy = Operations(context)
+    return context
+
+
+class AlterColRoundTripFixture:
+    # since these tests are about syntax, use more recent SQLAlchemy as some of
+    # the type / server default compare logic might not work on older
+    # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
+
+    __requires__ = ("alter_column",)
+
+    def setUp(self):
+        self.conn = config.db.connect()
+        self.ctx = MigrationContext.configure(self.conn)
+        self.op = Operations(self.ctx)
+        self.metadata = MetaData()
+
+    def _compare_type(self, t1, t2):
+        c1 = Column("q", t1)
+        c2 = Column("q", t2)
+        assert not self.ctx.impl.compare_type(
+            c1, c2
+        ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)
+
+    def _compare_server_default(self, t1, s1, t2, s2):
+        c1 = Column("q", t1, server_default=s1)
+        c2 = Column("q", t2, server_default=s2)
+        assert not self.ctx.impl.compare_server_default(
+            c1, c2, s2, s1
+        ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
+
+    def tearDown(self):
+        sqla_compat._safe_rollback_connection_transaction(self.conn)
+        with self.conn.begin():
+            self.metadata.drop_all(self.conn)
+        self.conn.close()
+
+    def _run_alter_col(self, from_, to_, compare=None):
+        column = Column(
+            from_.get("name", "colname"),
+            from_.get("type", String(10)),
+            nullable=from_.get("nullable", True),
+            server_default=from_.get("server_default", None),
+            # comment=from_.get("comment", None)
+        )
+        t = Table("x", self.metadata, column)
+
+        with sqla_compat._ensure_scope_for_ddl(self.conn):
+            t.create(self.conn)
+            insp = inspect(self.conn)
+            old_col = insp.get_columns("x")[0]
+
+            # TODO: conditional comment support
+            self.op.alter_column(
+                "x",
+                column.name,
+                existing_type=column.type,
+                existing_server_default=(
+                    column.server_default
+                    if column.server_default is not None
+                    else False
+                ),
+                existing_nullable=True if column.nullable else False,
+                # existing_comment=column.comment,
+                nullable=to_.get("nullable", None),
+                # modify_comment=False,
+                server_default=to_.get("server_default", False),
+                new_column_name=to_.get("name", None),
+                type_=to_.get("type", None),
+            )
+
+        insp = inspect(self.conn)
+        new_col = insp.get_columns("x")[0]
+
+        if compare is None:
+            compare = to_
+
+        eq_(
+            new_col["name"],
+            compare["name"] if "name" in compare else column.name,
+        )
+        self._compare_type(
+            new_col["type"], compare.get("type", old_col["type"])
+        )
+        eq_(new_col["nullable"], compare.get("nullable", column.nullable))
+        self._compare_server_default(
+            new_col["type"],
+            new_col.get("default", None),
+            compare.get("type", old_col["type"]),
+            (
+                compare["server_default"].text
+                if "server_default" in compare
+                else (
+                    column.server_default.arg.text
+                    if column.server_default is not None
+                    else None
+                )
+            ),
+        )