about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py516
1 files changed, 516 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py
new file mode 100644
index 00000000..81c7138c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/assertsql.py
@@ -0,0 +1,516 @@
+# testing/assertsql.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import collections
+import contextlib
+import itertools
+import re
+
+from .. import event
+from ..engine import url
+from ..engine.default import DefaultDialect
+from ..schema import BaseDDLElement
+
+
+class AssertRule:
+    is_consumed = False
+    errormessage = None
+    consume_statement = True
+
+    def process_statement(self, execute_observed):
+        pass
+
+    def no_more_statements(self):
+        assert False, (
+            "All statements are complete, but pending "
+            "assertion rules remain"
+        )
+
+
+class SQLMatchRule(AssertRule):
+    pass
+
+
+class CursorSQL(SQLMatchRule):
+    def __init__(self, statement, params=None, consume_statement=True):
+        self.statement = statement
+        self.params = params
+        self.consume_statement = consume_statement
+
+    def process_statement(self, execute_observed):
+        stmt = execute_observed.statements[0]
+        if self.statement != stmt.statement or (
+            self.params is not None and self.params != stmt.parameters
+        ):
+            self.consume_statement = True
+            self.errormessage = (
+                "Testing for exact SQL %s parameters %s received %s %s"
+                % (
+                    self.statement,
+                    self.params,
+                    stmt.statement,
+                    stmt.parameters,
+                )
+            )
+        else:
+            execute_observed.statements.pop(0)
+            self.is_consumed = True
+            if not execute_observed.statements:
+                self.consume_statement = True
+
+
+class CompiledSQL(SQLMatchRule):
+    def __init__(
+        self, statement, params=None, dialect="default", enable_returning=True
+    ):
+        self.statement = statement
+        self.params = params
+        self.dialect = dialect
+        self.enable_returning = enable_returning
+
+    def _compare_sql(self, execute_observed, received_statement):
+        stmt = re.sub(r"[\n\t]", "", self.statement)
+        return received_statement == stmt
+
+    def _compile_dialect(self, execute_observed):
+        if self.dialect == "default":
+            dialect = DefaultDialect()
+            # this is currently what tests are expecting
+            # dialect.supports_default_values = True
+            dialect.supports_default_metavalue = True
+
+            if self.enable_returning:
+                dialect.insert_returning = dialect.update_returning = (
+                    dialect.delete_returning
+                ) = True
+                dialect.use_insertmanyvalues = True
+                dialect.supports_multivalues_insert = True
+                dialect.update_returning_multifrom = True
+                dialect.delete_returning_multifrom = True
+                # dialect.favor_returning_over_lastrowid = True
+                # dialect.insert_null_pk_still_autoincrements = True
+
+                # this is calculated but we need it to be True for this
+                # to look like all the current RETURNING dialects
+                assert dialect.insert_executemany_returning
+
+            return dialect
+        else:
+            return url.URL.create(self.dialect).get_dialect()()
+
+    def _received_statement(self, execute_observed):
+        """reconstruct the statement and params in terms
+        of a target dialect, which for CompiledSQL is just DefaultDialect."""
+
+        context = execute_observed.context
+        compare_dialect = self._compile_dialect(execute_observed)
+
+        # received_statement runs a full compile().  we should not need to
+        # consider extracted_parameters; if we do this indicates some state
+        # is being sent from a previous cached query, which some misbehaviors
+        # in the ORM can cause, see #6881
+        cache_key = None  # execute_observed.context.compiled.cache_key
+        extracted_parameters = (
+            None  # execute_observed.context.extracted_parameters
+        )
+
+        if "schema_translate_map" in context.execution_options:
+            map_ = context.execution_options["schema_translate_map"]
+        else:
+            map_ = None
+
+        if isinstance(execute_observed.clauseelement, BaseDDLElement):
+            compiled = execute_observed.clauseelement.compile(
+                dialect=compare_dialect,
+                schema_translate_map=map_,
+            )
+        else:
+            compiled = execute_observed.clauseelement.compile(
+                cache_key=cache_key,
+                dialect=compare_dialect,
+                column_keys=context.compiled.column_keys,
+                for_executemany=context.compiled.for_executemany,
+                schema_translate_map=map_,
+            )
+        _received_statement = re.sub(r"[\n\t]", "", str(compiled))
+        parameters = execute_observed.parameters
+
+        if not parameters:
+            _received_parameters = [
+                compiled.construct_params(
+                    extracted_parameters=extracted_parameters
+                )
+            ]
+        else:
+            _received_parameters = [
+                compiled.construct_params(
+                    m, extracted_parameters=extracted_parameters
+                )
+                for m in parameters
+            ]
+
+        return _received_statement, _received_parameters
+
+    def process_statement(self, execute_observed):
+        context = execute_observed.context
+
+        _received_statement, _received_parameters = self._received_statement(
+            execute_observed
+        )
+        params = self._all_params(context)
+
+        equivalent = self._compare_sql(execute_observed, _received_statement)
+
+        if equivalent:
+            if params is not None:
+                all_params = list(params)
+                all_received = list(_received_parameters)
+                while all_params and all_received:
+                    param = dict(all_params.pop(0))
+
+                    for idx, received in enumerate(list(all_received)):
+                        # do a positive compare only
+                        for param_key in param:
+                            # a key in param did not match current
+                            # 'received'
+                            if (
+                                param_key not in received
+                                or received[param_key] != param[param_key]
+                            ):
+                                break
+                        else:
+                            # all keys in param matched 'received';
+                            # onto next param
+                            del all_received[idx]
+                            break
+                    else:
+                        # param did not match any entry
+                        # in all_received
+                        equivalent = False
+                        break
+                if all_params or all_received:
+                    equivalent = False
+
+        if equivalent:
+            self.is_consumed = True
+            self.errormessage = None
+        else:
+            self.errormessage = self._failure_message(
+                execute_observed, params
+            ) % {
+                "received_statement": _received_statement,
+                "received_parameters": _received_parameters,
+            }
+
+    def _all_params(self, context):
+        if self.params:
+            if callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+            if not isinstance(params, list):
+                params = [params]
+            return params
+        else:
+            return None
+
+    def _failure_message(self, execute_observed, expected_params):
+        return (
+            "Testing for compiled statement\n%r partial params %s, "
+            "received\n%%(received_statement)r with params "
+            "%%(received_parameters)r"
+            % (
+                self.statement.replace("%", "%%"),
+                repr(expected_params).replace("%", "%%"),
+            )
+        )
+
+
+class RegexSQL(CompiledSQL):
+    def __init__(
+        self, regex, params=None, dialect="default", enable_returning=False
+    ):
+        SQLMatchRule.__init__(self)
+        self.regex = re.compile(regex)
+        self.orig_regex = regex
+        self.params = params
+        self.dialect = dialect
+        self.enable_returning = enable_returning
+
+    def _failure_message(self, execute_observed, expected_params):
+        return (
+            "Testing for compiled statement ~%r partial params %s, "
+            "received %%(received_statement)r with params "
+            "%%(received_parameters)r"
+            % (
+                self.orig_regex.replace("%", "%%"),
+                repr(expected_params).replace("%", "%%"),
+            )
+        )
+
+    def _compare_sql(self, execute_observed, received_statement):
+        return bool(self.regex.match(received_statement))
+
+
+class DialectSQL(CompiledSQL):
+    def _compile_dialect(self, execute_observed):
+        return execute_observed.context.dialect
+
+    def _compare_no_space(self, real_stmt, received_stmt):
+        stmt = re.sub(r"[\n\t]", "", real_stmt)
+        return received_stmt == stmt
+
+    def _received_statement(self, execute_observed):
+        received_stmt, received_params = super()._received_statement(
+            execute_observed
+        )
+
+        # TODO: why do we need this part?
+        for real_stmt in execute_observed.statements:
+            if self._compare_no_space(real_stmt.statement, received_stmt):
+                break
+        else:
+            raise AssertionError(
+                "Can't locate compiled statement %r in list of "
+                "statements actually invoked" % received_stmt
+            )
+
+        return received_stmt, execute_observed.context.compiled_parameters
+
+    def _dialect_adjusted_statement(self, dialect):
+        paramstyle = dialect.paramstyle
+        stmt = re.sub(r"[\n\t]", "", self.statement)
+
+        # temporarily escape out PG double colons
+        stmt = stmt.replace("::", "!!")
+
+        if paramstyle == "pyformat":
+            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
+        else:
+            # positional params
+            repl = None
+            if paramstyle == "qmark":
+                repl = "?"
+            elif paramstyle == "format":
+                repl = r"%s"
+            elif paramstyle.startswith("numeric"):
+                counter = itertools.count(1)
+
+                num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
+
+                def repl(m):
+                    return f"{num_identifier}{next(counter)}"
+
+            stmt = re.sub(r":([\w_]+)", repl, stmt)
+
+        # put them back
+        stmt = stmt.replace("!!", "::")
+
+        return stmt
+
+    def _compare_sql(self, execute_observed, received_statement):
+        stmt = self._dialect_adjusted_statement(
+            execute_observed.context.dialect
+        )
+        return received_statement == stmt
+
+    def _failure_message(self, execute_observed, expected_params):
+        return (
+            "Testing for compiled statement\n%r partial params %s, "
+            "received\n%%(received_statement)r with params "
+            "%%(received_parameters)r"
+            % (
+                self._dialect_adjusted_statement(
+                    execute_observed.context.dialect
+                ).replace("%", "%%"),
+                repr(expected_params).replace("%", "%%"),
+            )
+        )
+
+
+class CountStatements(AssertRule):
+    def __init__(self, count):
+        self.count = count
+        self._statement_count = 0
+
+    def process_statement(self, execute_observed):
+        self._statement_count += 1
+
+    def no_more_statements(self):
+        if self.count != self._statement_count:
+            assert False, "desired statement count %d does not match %d" % (
+                self.count,
+                self._statement_count,
+            )
+
+
+class AllOf(AssertRule):
+    def __init__(self, *rules):
+        self.rules = set(rules)
+
+    def process_statement(self, execute_observed):
+        for rule in list(self.rules):
+            rule.errormessage = None
+            rule.process_statement(execute_observed)
+            if rule.is_consumed:
+                self.rules.discard(rule)
+                if not self.rules:
+                    self.is_consumed = True
+                break
+            elif not rule.errormessage:
+                # rule is not done yet
+                self.errormessage = None
+                break
+        else:
+            self.errormessage = list(self.rules)[0].errormessage
+
+
+class EachOf(AssertRule):
+    def __init__(self, *rules):
+        self.rules = list(rules)
+
+    def process_statement(self, execute_observed):
+        if not self.rules:
+            self.is_consumed = True
+            self.consume_statement = False
+
+        while self.rules:
+            rule = self.rules[0]
+            rule.process_statement(execute_observed)
+            if rule.is_consumed:
+                self.rules.pop(0)
+            elif rule.errormessage:
+                self.errormessage = rule.errormessage
+            if rule.consume_statement:
+                break
+
+        if not self.rules:
+            self.is_consumed = True
+
+    def no_more_statements(self):
+        if self.rules and not self.rules[0].is_consumed:
+            self.rules[0].no_more_statements()
+        elif self.rules:
+            super().no_more_statements()
+
+
+class Conditional(EachOf):
+    def __init__(self, condition, rules, else_rules):
+        if condition:
+            super().__init__(*rules)
+        else:
+            super().__init__(*else_rules)
+
+
+class Or(AllOf):
+    def process_statement(self, execute_observed):
+        for rule in self.rules:
+            rule.process_statement(execute_observed)
+            if rule.is_consumed:
+                self.is_consumed = True
+                break
+        else:
+            self.errormessage = list(self.rules)[0].errormessage
+
+
+class SQLExecuteObserved:
+    def __init__(self, context, clauseelement, multiparams, params):
+        self.context = context
+        self.clauseelement = clauseelement
+
+        if multiparams:
+            self.parameters = multiparams
+        elif params:
+            self.parameters = [params]
+        else:
+            self.parameters = []
+        self.statements = []
+
+    def __repr__(self):
+        return str(self.statements)
+
+
+class SQLCursorExecuteObserved(
+    collections.namedtuple(
+        "SQLCursorExecuteObserved",
+        ["statement", "parameters", "context", "executemany"],
+    )
+):
+    pass
+
+
+class SQLAsserter:
+    def __init__(self):
+        self.accumulated = []
+
+    def _close(self):
+        self._final = self.accumulated
+        del self.accumulated
+
+    def assert_(self, *rules):
+        rule = EachOf(*rules)
+
+        observed = list(self._final)
+        while observed:
+            statement = observed.pop(0)
+            rule.process_statement(statement)
+            if rule.is_consumed:
+                break
+            elif rule.errormessage:
+                assert False, rule.errormessage
+        if observed:
+            assert False, "Additional SQL statements remain:\n%s" % observed
+        elif not rule.is_consumed:
+            rule.no_more_statements()
+
+
+@contextlib.contextmanager
+def assert_engine(engine):
+    asserter = SQLAsserter()
+
+    orig = []
+
+    @event.listens_for(engine, "before_execute")
+    def connection_execute(
+        conn, clauseelement, multiparams, params, execution_options
+    ):
+        # grab the original statement + params before any cursor
+        # execution
+        orig[:] = clauseelement, multiparams, params
+
+    @event.listens_for(engine, "after_cursor_execute")
+    def cursor_execute(
+        conn, cursor, statement, parameters, context, executemany
+    ):
+        if not context:
+            return
+        # then grab real cursor statements and associate them all
+        # around a single context
+        if (
+            asserter.accumulated
+            and asserter.accumulated[-1].context is context
+        ):
+            obs = asserter.accumulated[-1]
+        else:
+            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
+            asserter.accumulated.append(obs)
+        obs.statements.append(
+            SQLCursorExecuteObserved(
+                statement, parameters, context, executemany
+            )
+        )
+
+    try:
+        yield asserter
+    finally:
+        event.remove(engine, "after_cursor_execute", cursor_execute)
+        event.remove(engine, "before_execute", connection_execute)
+        asserter._close()