aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py')
-rw-r--r--.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py306
1 files changed, 306 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py b/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py
new file mode 100644
index 00000000..17d732f8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/alembic/testing/fixtures.py
@@ -0,0 +1,306 @@
+from __future__ import annotations
+
+import configparser
+from contextlib import contextmanager
+import io
+import re
+from typing import Any
+from typing import Dict
+
+from sqlalchemy import Column
+from sqlalchemy import create_mock_engine
+from sqlalchemy import inspect
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
+from sqlalchemy.testing import config
+from sqlalchemy.testing import mock
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.fixtures import FutureEngineMixin
+from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
+from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
+
+import alembic
+from .assertions import _get_dialect
+from ..environment import EnvironmentContext
+from ..migration import MigrationContext
+from ..operations import Operations
+from ..util import sqla_compat
+from ..util.sqla_compat import sqla_2
+
+
+testing_config = configparser.ConfigParser()
+testing_config.read(["test.cfg"])
+
+
+class TestBase(SQLAlchemyTestBase):
+ is_sqlalchemy_future = sqla_2
+
+ @testing.fixture()
+ def ops_context(self, migration_context):
+ with migration_context.begin_transaction(_per_migration=True):
+ yield Operations(migration_context)
+
+ @testing.fixture
+ def migration_context(self, connection):
+ return MigrationContext.configure(
+ connection, opts=dict(transaction_per_migration=True)
+ )
+
+ @testing.fixture
+ def as_sql_migration_context(self, connection):
+ return MigrationContext.configure(
+ connection, opts=dict(transaction_per_migration=True, as_sql=True)
+ )
+
+ @testing.fixture
+ def connection(self):
+ with config.db.connect() as conn:
+ yield conn
+
+
+class TablesTest(TestBase, SQLAlchemyTablesTest):
+ pass
+
+
+FutureEngineMixin.is_sqlalchemy_future = True
+
+
+def capture_db(dialect="postgresql://"):
+ buf = []
+
+ def dump(sql, *multiparams, **params):
+ buf.append(str(sql.compile(dialect=engine.dialect)))
+
+ engine = create_mock_engine(dialect, dump)
+ return engine, buf
+
+
+_engs: Dict[Any, Any] = {}
+
+
+@contextmanager
+def capture_context_buffer(**kw):
+ if kw.pop("bytes_io", False):
+ buf = io.BytesIO()
+ else:
+ buf = io.StringIO()
+
+ kw.update({"dialect_name": "sqlite", "output_buffer": buf})
+ conf = EnvironmentContext.configure
+
+ def configure(*arg, **opt):
+ opt.update(**kw)
+ return conf(*arg, **opt)
+
+ with mock.patch.object(EnvironmentContext, "configure", configure):
+ yield buf
+
+
+@contextmanager
+def capture_engine_context_buffer(**kw):
+ from .env import _sqlite_file_db
+ from sqlalchemy import event
+
+ buf = io.StringIO()
+
+ eng = _sqlite_file_db()
+
+ conn = eng.connect()
+
+ @event.listens_for(conn, "before_cursor_execute")
+ def bce(conn, cursor, statement, parameters, context, executemany):
+ buf.write(statement + "\n")
+
+ kw.update({"connection": conn})
+ conf = EnvironmentContext.configure
+
+ def configure(*arg, **opt):
+ opt.update(**kw)
+ return conf(*arg, **opt)
+
+ with mock.patch.object(EnvironmentContext, "configure", configure):
+ yield buf
+
+
+def op_fixture(
+ dialect="default",
+ as_sql=False,
+ naming_convention=None,
+ literal_binds=False,
+ native_boolean=None,
+):
+ opts = {}
+ if naming_convention:
+ opts["target_metadata"] = MetaData(naming_convention=naming_convention)
+
+ class buffer_:
+ def __init__(self):
+ self.lines = []
+
+ def write(self, msg):
+ msg = msg.strip()
+ msg = re.sub(r"[\n\t]", "", msg)
+ if as_sql:
+ # the impl produces soft tabs,
+ # so search for blocks of 4 spaces
+ msg = re.sub(r" ", "", msg)
+ msg = re.sub(r"\;\n*$", "", msg)
+
+ self.lines.append(msg)
+
+ def flush(self):
+ pass
+
+ buf = buffer_()
+
+ class ctx(MigrationContext):
+ def get_buf(self):
+ return buf
+
+ def clear_assertions(self):
+ buf.lines[:] = []
+
+ def assert_(self, *sql):
+ # TODO: make this more flexible about
+ # whitespace and such
+ eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql])
+
+ def assert_contains(self, sql):
+ for stmt in buf.lines:
+ if re.sub(r"[\n\t]", "", sql) in stmt:
+ return
+ else:
+ assert False, "Could not locate fragment %r in %r" % (
+ sql,
+ buf.lines,
+ )
+
+ if as_sql:
+ opts["as_sql"] = as_sql
+ if literal_binds:
+ opts["literal_binds"] = literal_binds
+
+ ctx_dialect = _get_dialect(dialect)
+ if native_boolean is not None:
+ ctx_dialect.supports_native_boolean = native_boolean
+ # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
+ # which breaks assumptions in the alembic test suite
+ ctx_dialect.non_native_boolean_check_constraint = True
+ if not as_sql:
+
+ def execute(stmt, *multiparam, **param):
+ if isinstance(stmt, str):
+ stmt = text(stmt)
+ assert stmt.supports_execution
+ sql = str(stmt.compile(dialect=ctx_dialect))
+
+ buf.write(sql)
+
+ connection = mock.Mock(dialect=ctx_dialect, execute=execute)
+ else:
+ opts["output_buffer"] = buf
+ connection = None
+ context = ctx(ctx_dialect, connection, opts)
+
+ alembic.op._proxy = Operations(context)
+ return context
+
+
+class AlterColRoundTripFixture:
+ # since these tests are about syntax, use more recent SQLAlchemy as some of
+ # the type / server default compare logic might not work on older
+ # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
+
+ __requires__ = ("alter_column",)
+
+ def setUp(self):
+ self.conn = config.db.connect()
+ self.ctx = MigrationContext.configure(self.conn)
+ self.op = Operations(self.ctx)
+ self.metadata = MetaData()
+
+ def _compare_type(self, t1, t2):
+ c1 = Column("q", t1)
+ c2 = Column("q", t2)
+ assert not self.ctx.impl.compare_type(
+ c1, c2
+ ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)
+
+ def _compare_server_default(self, t1, s1, t2, s2):
+ c1 = Column("q", t1, server_default=s1)
+ c2 = Column("q", t2, server_default=s2)
+ assert not self.ctx.impl.compare_server_default(
+ c1, c2, s2, s1
+ ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
+
+ def tearDown(self):
+ sqla_compat._safe_rollback_connection_transaction(self.conn)
+ with self.conn.begin():
+ self.metadata.drop_all(self.conn)
+ self.conn.close()
+
+ def _run_alter_col(self, from_, to_, compare=None):
+ column = Column(
+ from_.get("name", "colname"),
+ from_.get("type", String(10)),
+ nullable=from_.get("nullable", True),
+ server_default=from_.get("server_default", None),
+ # comment=from_.get("comment", None)
+ )
+ t = Table("x", self.metadata, column)
+
+ with sqla_compat._ensure_scope_for_ddl(self.conn):
+ t.create(self.conn)
+ insp = inspect(self.conn)
+ old_col = insp.get_columns("x")[0]
+
+ # TODO: conditional comment support
+ self.op.alter_column(
+ "x",
+ column.name,
+ existing_type=column.type,
+ existing_server_default=(
+ column.server_default
+ if column.server_default is not None
+ else False
+ ),
+ existing_nullable=True if column.nullable else False,
+ # existing_comment=column.comment,
+ nullable=to_.get("nullable", None),
+ # modify_comment=False,
+ server_default=to_.get("server_default", False),
+ new_column_name=to_.get("name", None),
+ type_=to_.get("type", None),
+ )
+
+ insp = inspect(self.conn)
+ new_col = insp.get_columns("x")[0]
+
+ if compare is None:
+ compare = to_
+
+ eq_(
+ new_col["name"],
+ compare["name"] if "name" in compare else column.name,
+ )
+ self._compare_type(
+ new_col["type"], compare.get("type", old_col["type"])
+ )
+ eq_(new_col["nullable"], compare.get("nullable", column.nullable))
+ self._compare_server_default(
+ new_col["type"],
+ new_col.get("default", None),
+ compare.get("type", old_col["type"]),
+ (
+ compare["server_default"].text
+ if "server_default" in compare
+ else (
+ column.server_default.arg.text
+ if column.server_default is not None
+ else None
+ )
+ ),
+ )