about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py989
1 files changed, 989 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py
new file mode 100644
index 00000000..8364c15f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertions.py
@@ -0,0 +1,989 @@
+# testing/assertions.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+from collections import defaultdict
+import contextlib
+from copy import copy
+from itertools import filterfalse
+import re
+import sys
+import warnings
+
+from . import assertsql
+from . import config
+from . import engines
+from . import mock
+from .exclusions import db_spec
+from .util import fail
+from .. import exc as sa_exc
+from .. import schema
+from .. import sql
+from .. import types as sqltypes
+from .. import util
+from ..engine import default
+from ..engine import url
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import decorator
+
+
+def expect_warnings(*messages, **kw):
+    """Context manager which expects one or more warnings.
+
+    With no arguments, squelches all SAWarning emitted via
+    sqlalchemy.util.warn and sqlalchemy.util.warn_limited.   Otherwise
+    pass string expressions that will match selected warnings via regex;
+    all non-matching warnings are sent through.
+
+    The expect version **asserts** that the warnings were in fact seen.
+
+    Note that the test suite sets SAWarning warnings to raise exceptions.
+
+    """  # noqa
+    return _expect_warnings_sqla_only(sa_exc.SAWarning, messages, **kw)
+
+
+@contextlib.contextmanager
+def expect_warnings_on(db, *messages, **kw):
+    """Context manager which expects one or more warnings on specific
+    dialects.
+
+    The expect version **asserts** that the warnings were in fact seen.
+
+    """
+    spec = db_spec(db)
+
+    if isinstance(db, str) and not spec(config._current):
+        yield
+    else:
+        with expect_warnings(*messages, **kw):
+            yield
+
+
+def emits_warning(*messages):
+    """Decorator form of expect_warnings().
+
+    Note that emits_warning does **not** assert that the warnings
+    were in fact seen.
+
+    """
+
+    @decorator
+    def decorate(fn, *args, **kw):
+        with expect_warnings(assert_=False, *messages):
+            return fn(*args, **kw)
+
+    return decorate
+
+
+def expect_deprecated(*messages, **kw):
+    return _expect_warnings_sqla_only(
+        sa_exc.SADeprecationWarning, messages, **kw
+    )
+
+
+def expect_deprecated_20(*messages, **kw):
+    return _expect_warnings_sqla_only(
+        sa_exc.Base20DeprecationWarning, messages, **kw
+    )
+
+
+def emits_warning_on(db, *messages):
+    """Mark a test as emitting a warning on a specific dialect.
+
+    With no arguments, squelches all SAWarning failures.  Or pass one or more
+    strings; these will be matched to the root of the warning description by
+    warnings.filterwarnings().
+
+    Note that emits_warning_on does **not** assert that the warnings
+    were in fact seen.
+
+    """
+
+    @decorator
+    def decorate(fn, *args, **kw):
+        with expect_warnings_on(db, assert_=False, *messages):
+            return fn(*args, **kw)
+
+    return decorate
+
+
+def uses_deprecated(*messages):
+    """Mark a test as immune from fatal deprecation warnings.
+
+    With no arguments, squelches all SADeprecationWarning failures.
+    Or pass one or more strings; these will be matched to the root
+    of the warning description by warnings.filterwarnings().
+
+    As a special case, you may pass a function name prefixed with //
+    and it will be re-written as needed to match the standard warning
+    verbiage emitted by the sqlalchemy.util.deprecated decorator.
+
+    Note that uses_deprecated does **not** assert that the warnings
+    were in fact seen.
+
+    """
+
+    @decorator
+    def decorate(fn, *args, **kw):
+        with expect_deprecated(*messages, assert_=False):
+            return fn(*args, **kw)
+
+    return decorate
+
+
+_FILTERS = None
+_SEEN = None
+_EXC_CLS = None
+
+
+def _expect_warnings_sqla_only(
+    exc_cls,
+    messages,
+    regex=True,
+    search_msg=False,
+    assert_=True,
+):
+    """SQLAlchemy internal use only _expect_warnings().
+
+    Alembic is using _expect_warnings() directly, and should be updated
+    to use this new interface.
+
+    """
+    return _expect_warnings(
+        exc_cls,
+        messages,
+        regex=regex,
+        search_msg=search_msg,
+        assert_=assert_,
+        raise_on_any_unexpected=True,
+    )
+
+
+@contextlib.contextmanager
+def _expect_warnings(
+    exc_cls,
+    messages,
+    regex=True,
+    search_msg=False,
+    assert_=True,
+    raise_on_any_unexpected=False,
+    squelch_other_warnings=False,
+):
+    global _FILTERS, _SEEN, _EXC_CLS
+
+    if regex or search_msg:
+        filters = [re.compile(msg, re.I | re.S) for msg in messages]
+    else:
+        filters = list(messages)
+
+    if _FILTERS is not None:
+        # nested call; update _FILTERS and _SEEN, return.  outer
+        # block will assert our messages
+        assert _SEEN is not None
+        assert _EXC_CLS is not None
+        _FILTERS.extend(filters)
+        _SEEN.update(filters)
+        _EXC_CLS += (exc_cls,)
+        yield
+    else:
+        seen = _SEEN = set(filters)
+        _FILTERS = filters
+        _EXC_CLS = (exc_cls,)
+
+        if raise_on_any_unexpected:
+
+            def real_warn(msg, *arg, **kw):
+                raise AssertionError("Got unexpected warning: %r" % msg)
+
+        else:
+            real_warn = warnings.warn
+
+        def our_warn(msg, *arg, **kw):
+            if isinstance(msg, _EXC_CLS):
+                exception = type(msg)
+                msg = str(msg)
+            elif arg:
+                exception = arg[0]
+            else:
+                exception = None
+
+            if not exception or not issubclass(exception, _EXC_CLS):
+                if not squelch_other_warnings:
+                    return real_warn(msg, *arg, **kw)
+                else:
+                    return
+
+            if not filters and not raise_on_any_unexpected:
+                return
+
+            for filter_ in filters:
+                if (
+                    (search_msg and filter_.search(msg))
+                    or (regex and filter_.match(msg))
+                    or (not regex and filter_ == msg)
+                ):
+                    seen.discard(filter_)
+                    break
+            else:
+                if not squelch_other_warnings:
+                    real_warn(msg, *arg, **kw)
+
+        with mock.patch("warnings.warn", our_warn):
+            try:
+                yield
+            finally:
+                _SEEN = _FILTERS = _EXC_CLS = None
+
+                if assert_:
+                    assert not seen, "Warnings were not seen: %s" % ", ".join(
+                        "%r" % (s.pattern if regex else s) for s in seen
+                    )
+
+
+def global_cleanup_assertions():
+    """Check things that have to be finalized at the end of a test suite.
+
+    Hardcoded at the moment, a modular system can be built here
+    to support things like PG prepared transactions, tables all
+    dropped, etc.
+
+    """
+    _assert_no_stray_pool_connections()
+
+
+def _assert_no_stray_pool_connections():
+    engines.testing_reaper.assert_all_closed()
+
+
+def int_within_variance(expected, received, variance):
+    deviance = int(expected * variance)
+    assert (
+        abs(received - expected) < deviance
+    ), "Given int value %s is not within %d%% of expected value %s" % (
+        received,
+        variance * 100,
+        expected,
+    )
+
+
+def eq_regex(a, b, msg=None):
+    assert re.match(b, a), msg or "%r !~ %r" % (a, b)
+
+
+def eq_(a, b, msg=None):
+    """Assert a == b, with repr messaging on failure."""
+    assert a == b, msg or "%r != %r" % (a, b)
+
+
+def ne_(a, b, msg=None):
+    """Assert a != b, with repr messaging on failure."""
+    assert a != b, msg or "%r == %r" % (a, b)
+
+
+def le_(a, b, msg=None):
+    """Assert a <= b, with repr messaging on failure."""
+    assert a <= b, msg or "%r != %r" % (a, b)
+
+
+def is_instance_of(a, b, msg=None):
+    assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
+
+
+def is_none(a, msg=None):
+    is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+    is_not(a, None, msg=msg)
+
+
+def is_true(a, msg=None):
+    is_(bool(a), True, msg=msg)
+
+
+def is_false(a, msg=None):
+    is_(bool(a), False, msg=msg)
+
+
+def is_(a, b, msg=None):
+    """Assert a is b, with repr messaging on failure."""
+    assert a is b, msg or "%r is not %r" % (a, b)
+
+
+def is_not(a, b, msg=None):
+    """Assert a is not b, with repr messaging on failure."""
+    assert a is not b, msg or "%r is %r" % (a, b)
+
+
+# deprecated.  See #5429
+is_not_ = is_not
+
+
+def in_(a, b, msg=None):
+    """Assert a in b, with repr messaging on failure."""
+    assert a in b, msg or "%r not in %r" % (a, b)
+
+
+def not_in(a, b, msg=None):
+    """Assert a in not b, with repr messaging on failure."""
+    assert a not in b, msg or "%r is in %r" % (a, b)
+
+
+# deprecated.  See #5429
+not_in_ = not_in
+
+
+def startswith_(a, fragment, msg=None):
+    """Assert a.startswith(fragment), with repr messaging on failure."""
+    assert a.startswith(fragment), msg or "%r does not start with %r" % (
+        a,
+        fragment,
+    )
+
+
+def eq_ignore_whitespace(a, b, msg=None):
+    a = re.sub(r"^\s+?|\n", "", a)
+    a = re.sub(r" {2,}", " ", a)
+    a = re.sub(r"\t", "", a)
+    b = re.sub(r"^\s+?|\n", "", b)
+    b = re.sub(r" {2,}", " ", b)
+    b = re.sub(r"\t", "", b)
+
+    assert a == b, msg or "%r != %r" % (a, b)
+
+
+def _assert_proper_exception_context(exception):
+    """assert that any exception we're catching does not have a __context__
+    without a __cause__, and that __suppress_context__ is never set.
+
+    Python 3 will report nested as exceptions as "during the handling of
+    error X, error Y occurred". That's not what we want to do.  we want
+    these exceptions in a cause chain.
+
+    """
+
+    if (
+        exception.__context__ is not exception.__cause__
+        and not exception.__suppress_context__
+    ):
+        assert False, (
+            "Exception %r was correctly raised but did not set a cause, "
+            "within context %r as its cause."
+            % (exception, exception.__context__)
+        )
+
+
+def assert_raises(except_cls, callable_, *args, **kw):
+    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+    return _assert_raises(except_cls, callable_, args, kw)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+    return _assert_raises(
+        except_cls, callable_, args, kwargs, msg=msg, check_context=True
+    )
+
+
+def assert_warns(except_cls, callable_, *args, **kwargs):
+    """legacy adapter function for functions that were previously using
+    assert_raises with SAWarning or similar.
+
+    has some workarounds to accommodate the fact that the callable completes
+    with this approach rather than stopping at the exception raise.
+
+
+    """
+    with _expect_warnings_sqla_only(except_cls, [".*"]):
+        return callable_(*args, **kwargs)
+
+
+def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
+    """legacy adapter function for functions that were previously using
+    assert_raises with SAWarning or similar.
+
+    has some workarounds to accommodate the fact that the callable completes
+    with this approach rather than stopping at the exception raise.
+
+    Also uses regex.search() to match the given message to the error string
+    rather than regex.match().
+
+    """
+    with _expect_warnings_sqla_only(
+        except_cls,
+        [msg],
+        search_msg=True,
+        regex=False,
+    ):
+        return callable_(*args, **kwargs)
+
+
+def assert_raises_message_context_ok(
+    except_cls, msg, callable_, *args, **kwargs
+):
+    return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+    except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
+    with _expect_raises(except_cls, msg, check_context) as ec:
+        callable_(*args, **kwargs)
+    return ec.error
+
+
+class _ErrorContainer:
+    error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+    if (
+        isinstance(except_cls, type)
+        and issubclass(except_cls, Warning)
+        or isinstance(except_cls, Warning)
+    ):
+        raise TypeError(
+            "Use expect_warnings for warnings, not "
+            "expect_raises / assert_raises"
+        )
+    ec = _ErrorContainer()
+    if check_context:
+        are_we_already_in_a_traceback = sys.exc_info()[0]
+    try:
+        yield ec
+        success = False
+    except except_cls as err:
+        ec.error = err
+        success = True
+        if msg is not None:
+            # I'm often pdbing here, and "err" above isn't
+            # in scope, so assign the string explicitly
+            error_as_string = str(err)
+            assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % (
+                msg,
+                error_as_string,
+            )
+        if check_context and not are_we_already_in_a_traceback:
+            _assert_proper_exception_context(err)
+        print(str(err).encode("utf-8"))
+
+    # it's generally a good idea to not carry traceback objects outside
+    # of the except: block, but in this case especially we seem to have
+    # hit some bug in either python 3.10.0b2 or greenlet or both which
+    # this seems to fix:
+    # https://github.com/python-greenlet/greenlet/issues/242
+    del ec
+
+    # assert outside the block so it works for AssertionError too !
+    assert success, "Callable did not raise an exception"
+
+
+def expect_raises(except_cls, check_context=True):
+    return _expect_raises(except_cls, check_context=check_context)
+
+
+def expect_raises_message(except_cls, msg, check_context=True):
+    return _expect_raises(except_cls, msg=msg, check_context=check_context)
+
+
+class AssertsCompiledSQL:
+    def assert_compile(
+        self,
+        clause,
+        result,
+        params=None,
+        checkparams=None,
+        for_executemany=False,
+        check_literal_execute=None,
+        check_post_param=None,
+        dialect=None,
+        checkpositional=None,
+        check_prefetch=None,
+        use_default_dialect=False,
+        allow_dialect_select=False,
+        supports_default_values=True,
+        supports_default_metavalue=True,
+        literal_binds=False,
+        render_postcompile=False,
+        schema_translate_map=None,
+        render_schema_translate=False,
+        default_schema_name=None,
+        from_linting=False,
+        check_param_order=True,
+        use_literal_execute_for_simple_int=False,
+    ):
+        if use_default_dialect:
+            dialect = default.DefaultDialect()
+            dialect.supports_default_values = supports_default_values
+            dialect.supports_default_metavalue = supports_default_metavalue
+        elif allow_dialect_select:
+            dialect = None
+        else:
+            if dialect is None:
+                dialect = getattr(self, "__dialect__", None)
+
+            if dialect is None:
+                dialect = config.db.dialect
+            elif dialect == "default" or dialect == "default_qmark":
+                if dialect == "default":
+                    dialect = default.DefaultDialect()
+                else:
+                    dialect = default.DefaultDialect("qmark")
+                dialect.supports_default_values = supports_default_values
+                dialect.supports_default_metavalue = supports_default_metavalue
+            elif dialect == "default_enhanced":
+                dialect = default.StrCompileDialect()
+            elif isinstance(dialect, str):
+                dialect = url.URL.create(dialect).get_dialect()()
+
+        if default_schema_name:
+            dialect.default_schema_name = default_schema_name
+
+        kw = {}
+        compile_kwargs = {}
+
+        if schema_translate_map:
+            kw["schema_translate_map"] = schema_translate_map
+
+        if params is not None:
+            kw["column_keys"] = list(params)
+
+        if literal_binds:
+            compile_kwargs["literal_binds"] = True
+
+        if render_postcompile:
+            compile_kwargs["render_postcompile"] = True
+
+        if use_literal_execute_for_simple_int:
+            compile_kwargs["use_literal_execute_for_simple_int"] = True
+
+        if for_executemany:
+            kw["for_executemany"] = True
+
+        if render_schema_translate:
+            kw["render_schema_translate"] = True
+
+        if from_linting or getattr(self, "assert_from_linting", False):
+            kw["linting"] = sql.FROM_LINTING
+
+        from sqlalchemy import orm
+
+        if isinstance(clause, orm.Query):
+            stmt = clause._statement_20()
+            stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+            clause = stmt
+
+        if compile_kwargs:
+            kw["compile_kwargs"] = compile_kwargs
+
+        class DontAccess:
+            def __getattribute__(self, key):
+                raise NotImplementedError(
+                    "compiler accessed .statement; use "
+                    "compiler.current_executable"
+                )
+
+        class CheckCompilerAccess:
+            def __init__(self, test_statement):
+                self.test_statement = test_statement
+                self._annotations = {}
+                self.supports_execution = getattr(
+                    test_statement, "supports_execution", False
+                )
+
+                if self.supports_execution:
+                    self._execution_options = test_statement._execution_options
+
+                    if hasattr(test_statement, "_returning"):
+                        self._returning = test_statement._returning
+                    if hasattr(test_statement, "_inline"):
+                        self._inline = test_statement._inline
+                    if hasattr(test_statement, "_return_defaults"):
+                        self._return_defaults = test_statement._return_defaults
+
+            @property
+            def _variant_mapping(self):
+                return self.test_statement._variant_mapping
+
+            def _default_dialect(self):
+                return self.test_statement._default_dialect()
+
+            def compile(self, dialect, **kw):
+                return self.test_statement.compile.__func__(
+                    self, dialect=dialect, **kw
+                )
+
+            def _compiler(self, dialect, **kw):
+                return self.test_statement._compiler.__func__(
+                    self, dialect, **kw
+                )
+
+            def _compiler_dispatch(self, compiler, **kwargs):
+                if hasattr(compiler, "statement"):
+                    with mock.patch.object(
+                        compiler, "statement", DontAccess()
+                    ):
+                        return self.test_statement._compiler_dispatch(
+                            compiler, **kwargs
+                        )
+                else:
+                    return self.test_statement._compiler_dispatch(
+                        compiler, **kwargs
+                    )
+
+        # no construct can assume it's the "top level" construct in all cases
+        # as anything can be nested.  ensure constructs don't assume they
+        # are the "self.statement" element
+        c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
+
+        if isinstance(clause, sqltypes.TypeEngine):
+            cache_key_no_warnings = clause._static_cache_key
+            if cache_key_no_warnings:
+                hash(cache_key_no_warnings)
+        else:
+            cache_key_no_warnings = clause._generate_cache_key()
+            if cache_key_no_warnings:
+                hash(cache_key_no_warnings[0])
+
+        param_str = repr(getattr(c, "params", {}))
+        param_str = param_str.encode("utf-8").decode("ascii", "ignore")
+        print(("\nSQL String:\n" + str(c) + param_str).encode("utf-8"))
+
+        cc = re.sub(r"[\n\t]", "", str(c))
+
+        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+        if checkparams is not None:
+            if render_postcompile:
+                expanded_state = c.construct_expanded_state(
+                    params, escape_names=False
+                )
+                eq_(expanded_state.parameters, checkparams)
+            else:
+                eq_(c.construct_params(params), checkparams)
+        if checkpositional is not None:
+            if render_postcompile:
+                expanded_state = c.construct_expanded_state(
+                    params, escape_names=False
+                )
+                eq_(
+                    tuple(
+                        [
+                            expanded_state.parameters[x]
+                            for x in expanded_state.positiontup
+                        ]
+                    ),
+                    checkpositional,
+                )
+            else:
+                p = c.construct_params(params, escape_names=False)
+                eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+        if check_prefetch is not None:
+            eq_(c.prefetch, check_prefetch)
+        if check_literal_execute is not None:
+            eq_(
+                {
+                    c.bind_names[b]: b.effective_value
+                    for b in c.literal_execute_params
+                },
+                check_literal_execute,
+            )
+        if check_post_param is not None:
+            eq_(
+                {
+                    c.bind_names[b]: b.effective_value
+                    for b in c.post_compile_params
+                },
+                check_post_param,
+            )
+        if check_param_order and getattr(c, "params", None):
+
+            def get_dialect(paramstyle, positional):
+                cp = copy(dialect)
+                cp.paramstyle = paramstyle
+                cp.positional = positional
+                return cp
+
+            pyformat_dialect = get_dialect("pyformat", False)
+            pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
+            stmt = re.sub(r"[\n\t]", "", str(pyformat_c))
+
+            qmark_dialect = get_dialect("qmark", True)
+            qmark_c = clause.compile(dialect=qmark_dialect, **kw)
+            values = list(qmark_c.positiontup)
+            escaped = qmark_c.escaped_bind_names
+
+            for post_param in (
+                qmark_c.post_compile_params | qmark_c.literal_execute_params
+            ):
+                name = qmark_c.bind_names[post_param]
+                if name in values:
+                    values = [v for v in values if v != name]
+            positions = []
+            pos_by_value = defaultdict(list)
+            for v in values:
+                try:
+                    if v in pos_by_value:
+                        start = pos_by_value[v][-1]
+                    else:
+                        start = 0
+                    esc = escaped.get(v, v)
+                    pos = stmt.index("%%(%s)s" % (esc,), start) + 2
+                    positions.append(pos)
+                    pos_by_value[v].append(pos)
+                except ValueError:
+                    msg = "Expected to find bindparam %r in %r" % (v, stmt)
+                    assert False, msg
+
+            ordered = all(
+                positions[i - 1] < positions[i]
+                for i in range(1, len(positions))
+            )
+
+            expected = [v for _, v in sorted(zip(positions, values))]
+
+            msg = (
+                "Order of parameters %s does not match the order "
+                "in the statement %s. Statement %r" % (values, expected, stmt)
+            )
+
+            is_true(ordered, msg)
+
+
+class ComparesTables:
+    def assert_tables_equal(
+        self,
+        table,
+        reflected_table,
+        strict_types=False,
+        strict_constraints=True,
+    ):
+        assert len(table.c) == len(reflected_table.c)
+        for c, reflected_c in zip(table.c, reflected_table.c):
+            eq_(c.name, reflected_c.name)
+            assert reflected_c is reflected_table.c[c.name]
+
+            if strict_constraints:
+                eq_(c.primary_key, reflected_c.primary_key)
+                eq_(c.nullable, reflected_c.nullable)
+
+            if strict_types:
+                msg = "Type '%s' doesn't correspond to type '%s'"
+                assert isinstance(reflected_c.type, type(c.type)), msg % (
+                    reflected_c.type,
+                    c.type,
+                )
+            else:
+                self.assert_types_base(reflected_c, c)
+
+            if isinstance(c.type, sqltypes.String):
+                eq_(c.type.length, reflected_c.type.length)
+
+            if strict_constraints:
+                eq_(
+                    {f.column.name for f in c.foreign_keys},
+                    {f.column.name for f in reflected_c.foreign_keys},
+                )
+            if c.server_default:
+                assert isinstance(
+                    reflected_c.server_default, schema.FetchedValue
+                )
+
+        if strict_constraints:
+            assert len(table.primary_key) == len(reflected_table.primary_key)
+            for c in table.primary_key:
+                assert reflected_table.primary_key.columns[c.name] is not None
+
+    def assert_types_base(self, c1, c2):
+        assert c1.type._compare_type_affinity(
+            c2.type
+        ), "On column %r, type '%s' doesn't correspond to type '%s'" % (
+            c1.name,
+            c1.type,
+            c2.type,
+        )
+
+
+class AssertsExecutionResults:
+    def assert_result(self, result, class_, *objects):
+        result = list(result)
+        print(repr(result))
+        self.assert_list(result, class_, objects)
+
+    def assert_list(self, result, class_, list_):
+        self.assert_(
+            len(result) == len(list_),
+            "result list is not the same size as test list, "
+            + "for class "
+            + class_.__name__,
+        )
+        for i in range(0, len(list_)):
+            self.assert_row(class_, result[i], list_[i])
+
+    def assert_row(self, class_, rowobj, desc):
+        self.assert_(
+            rowobj.__class__ is class_, "item class is not " + repr(class_)
+        )
+        for key, value in desc.items():
+            if isinstance(value, tuple):
+                if isinstance(value[1], list):
+                    self.assert_list(getattr(rowobj, key), value[0], value[1])
+                else:
+                    self.assert_row(value[0], getattr(rowobj, key), value[1])
+            else:
+                self.assert_(
+                    getattr(rowobj, key) == value,
+                    "attribute %s value %s does not match %s"
+                    % (key, getattr(rowobj, key), value),
+                )
+
+    def assert_unordered_result(self, result, cls, *expected):
+        """As assert_result, but the order of objects is not considered.
+
+        The algorithm is very expensive but not a big deal for the small
+        numbers of rows that the test suite manipulates.
+        """
+
+        class immutabledict(dict):
+            def __hash__(self):
+                return id(self)
+
+        found = util.IdentitySet(result)
+        expected = {immutabledict(e) for e in expected}
+
+        for wrong in filterfalse(lambda o: isinstance(o, cls), found):
+            fail(
+                'Unexpected type "%s", expected "%s"'
+                % (type(wrong).__name__, cls.__name__)
+            )
+
+        if len(found) != len(expected):
+            fail(
+                'Unexpected object count "%s", expected "%s"'
+                % (len(found), len(expected))
+            )
+
+        NOVALUE = object()
+
+        def _compare_item(obj, spec):
+            for key, value in spec.items():
+                if isinstance(value, tuple):
+                    try:
+                        self.assert_unordered_result(
+                            getattr(obj, key), value[0], *value[1]
+                        )
+                    except AssertionError:
+                        return False
+                else:
+                    if getattr(obj, key, NOVALUE) != value:
+                        return False
+            return True
+
+        for expected_item in expected:
+            for found_item in found:
+                if _compare_item(found_item, expected_item):
+                    found.remove(found_item)
+                    break
+            else:
+                fail(
+                    "Expected %s instance with attributes %s not found."
+                    % (cls.__name__, repr(expected_item))
+                )
+        return True
+
+    def sql_execution_asserter(self, db=None):
+        if db is None:
+            from . import db as db
+
+        return assertsql.assert_engine(db)
+
+    def assert_sql_execution(self, db, callable_, *rules):
+        with self.sql_execution_asserter(db) as asserter:
+            result = callable_()
+        asserter.assert_(*rules)
+        return result
+
+    def assert_sql(self, db, callable_, rules):
+        newrules = []
+        for rule in rules:
+            if isinstance(rule, dict):
+                newrule = assertsql.AllOf(
+                    *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+                )
+            else:
+                newrule = assertsql.CompiledSQL(*rule)
+            newrules.append(newrule)
+
+        return self.assert_sql_execution(db, callable_, *newrules)
+
+    def assert_sql_count(self, db, callable_, count):
+        return self.assert_sql_execution(
+            db, callable_, assertsql.CountStatements(count)
+        )
+
+    @contextlib.contextmanager
+    def assert_execution(self, db, *rules):
+        with self.sql_execution_asserter(db) as asserter:
+            yield
+        asserter.assert_(*rules)
+
+    def assert_statement_count(self, db, count):
+        return self.assert_execution(db, assertsql.CountStatements(count))
+
+    @contextlib.contextmanager
+    def assert_statement_count_multi_db(self, dbs, counts):
+        recs = [
+            (self.sql_execution_asserter(db), db, count)
+            for (db, count) in zip(dbs, counts)
+        ]
+        asserters = []
+        for ctx, db, count in recs:
+            asserters.append(ctx.__enter__())
+        try:
+            yield
+        finally:
+            for asserter, (ctx, db, count) in zip(asserters, recs):
+                ctx.__exit__(None, None, None)
+                asserter.assert_(assertsql.CountStatements(count))
+
+
+class ComparesIndexes:
+    def compare_table_index_with_expected(
+        self, table: schema.Table, expected: list, dialect_name: str
+    ):
+        eq_(len(table.indexes), len(expected))
+        idx_dict = {idx.name: idx for idx in table.indexes}
+        for exp in expected:
+            idx = idx_dict[exp["name"]]
+            eq_(idx.unique, exp["unique"])
+            cols = [c for c in exp["column_names"] if c is not None]
+            eq_(len(idx.columns), len(cols))
+            for c in cols:
+                is_true(c in idx.columns)
+            exprs = exp.get("expressions")
+            if exprs:
+                eq_(len(idx.expressions), len(exprs))
+                for idx_exp, expr, col in zip(
+                    idx.expressions, exprs, exp["column_names"]
+                ):
+                    if col is None:
+                        eq_(idx_exp.text, expr)
+            if (
+                exp.get("dialect_options")
+                and f"{dialect_name}_include" in exp["dialect_options"]
+            ):
+                eq_(
+                    idx.dialect_options[dialect_name]["include"],
+                    exp["dialect_options"][f"{dialect_name}_include"],
+                )