author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz

    two version of R2R are here  (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_results.py')

-rw-r--r--  .venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_results.py | 502
1 file changed, 502 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_results.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_results.py
new file mode 100644
index 00000000..a6179d85
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_results.py
@@ -0,0 +1,502 @@
+# testing/suite/test_results.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+import datetime
+import re
+
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import DateTime
+from ... import func
+from ... import Integer
+from ... import select
+from ... import sql
+from ... import String
+from ... import testing
+from ... import text
+
+
+class RowFetchTest(fixtures.TablesTest):
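+    """Tests for the basic ways a result row can be accessed: by
+    attribute, by string key or Column object via ._mapping, and by
+    integer index."""
+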
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "plain_pk",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        Table(
+            "has_dates",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("today", DateTime),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            cls.tables.plain_pk.insert(),
+            [
+                {"id": 1, "data": "d1"},
+                {"id": 2, "data": "d2"},
+                {"id": 3, "data": "d3"},
+            ],
+        )
+
+        connection.execute(
+            cls.tables.has_dates.insert(),
+            [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
+        )
+
+    def test_via_attr(self, connection):
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row.id, 1)
+        eq_(row.data, "d1")
+
+    def test_via_string(self, connection):
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row._mapping["id"], 1)
+        eq_(row._mapping["data"], "d1")
+
+    def test_via_int(self, connection):
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row[0], 1)
+        eq_(row[1], "d1")
+
+    def test_via_col_object(self, connection):
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row._mapping[self.tables.plain_pk.c.id], 1)
+        eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
+
+    @requirements.duplicate_names_in_cursor_description
+    def test_row_with_dupe_names(self, connection):
+        result = connection.execute(
+            select(
+                self.tables.plain_pk.c.data,
+                self.tables.plain_pk.c.data.label("data"),
+            ).order_by(self.tables.plain_pk.c.id)
+        )
+        row = result.first()
+        eq_(result.keys(), ["data", "data"])
+        eq_(row, ("d1", "d1"))
+
+    def test_row_w_scalar_select(self, connection):
+        """test that a scalar select as a column is returned as such
+        and that type conversion works OK.
+
+        (this is half a SQLAlchemy Core test and half to catch database
+        backends that may have unusual behavior with scalar selects.)
+
+        """
+        datetable = self.tables.has_dates
+        s = select(datetable.alias("x").c.today).scalar_subquery()
+        s2 = select(datetable.c.id, s.label("somelabel"))
+        row = connection.execute(s2).first()
+
+        eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
+
+
+class PercentSchemaNamesTest(fixtures.TablesTest):
+    """tests using percent signs, spaces in table and column names.
+
+    This didn't work for PostgreSQL / MySQL drivers for a long time
+    but is now supported.
+
+    """
+
+    __requires__ = ("percent_schema_names",)
+
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        cls.tables.percent_table = Table(
+            "percent%table",
+            metadata,
+            Column("percent%", Integer),
+            Column("spaces % more spaces", Integer),
+        )
+        cls.tables.lightweight_percent_table = sql.table(
+            "percent%table",
+            sql.column("percent%"),
+            sql.column("spaces % more spaces"),
+        )
+
+    def test_single_roundtrip(self, connection):
+        percent_table = self.tables.percent_table
+        for params in [
+            {"percent%": 5, "spaces % more spaces": 12},
+            {"percent%": 7, "spaces % more spaces": 11},
+            {"percent%": 9, "spaces % more spaces": 10},
+            {"percent%": 11, "spaces % more spaces": 9},
+        ]:
+            connection.execute(percent_table.insert(), params)
+        self._assert_table(connection)
+
+    def test_executemany_roundtrip(self, connection):
+        percent_table = self.tables.percent_table
+        connection.execute(
+            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+        )
+        connection.execute(
+            percent_table.insert(),
+            [
+                {"percent%": 7, "spaces % more spaces": 11},
+                {"percent%": 9, "spaces % more spaces": 10},
+                {"percent%": 11, "spaces % more spaces": 9},
+            ],
+        )
+        self._assert_table(connection)
+
+    @requirements.insert_executemany_returning
+    def test_executemany_returning_roundtrip(self, connection):
+        percent_table = self.tables.percent_table
+        connection.execute(
+            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+        )
+        result = connection.execute(
+            percent_table.insert().returning(
+                percent_table.c["percent%"],
+                percent_table.c["spaces % more spaces"],
+            ),
+            [
+                {"percent%": 7, "spaces % more spaces": 11},
+                {"percent%": 9, "spaces % more spaces": 10},
+                {"percent%": 11, "spaces % more spaces": 9},
+            ],
+        )
+        eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
+        self._assert_table(connection)
+
+    def _assert_table(self, conn):
+        percent_table = self.tables.percent_table
+        lightweight_percent_table = self.tables.lightweight_percent_table
+
+        for table in (
+            percent_table,
+            percent_table.alias(),
+            lightweight_percent_table,
+            lightweight_percent_table.alias(),
+        ):
+            eq_(
+                list(
+                    conn.execute(table.select().order_by(table.c["percent%"]))
+                ),
+                [(5, 12), (7, 11), (9, 10), (11, 9)],
+            )
+
+            eq_(
+                list(
+                    conn.execute(
+                        table.select()
+                        .where(table.c["spaces % more spaces"].in_([9, 10]))
+                        .order_by(table.c["percent%"])
+                    )
+                ),
+                [(9, 10), (11, 9)],
+            )
+
+            row = conn.execute(
+                table.select().order_by(table.c["percent%"])
+            ).first()
+            eq_(row._mapping["percent%"], 5)
+            eq_(row._mapping["spaces % more spaces"], 12)
+
+            eq_(row._mapping[table.c["percent%"]], 5)
+            eq_(row._mapping[table.c["spaces % more spaces"]], 12)
+
+        conn.execute(
+            percent_table.update().values(
+                {percent_table.c["spaces % more spaces"]: 15}
+            )
+        )
+
+        eq_(
+            list(
+                conn.execute(
+                    percent_table.select().order_by(
+                        percent_table.c["percent%"]
+                    )
+                )
+            ),
+            [(5, 15), (7, 15), (9, 15), (11, 15)],
+        )
+
+
+class ServerSideCursorsTest(
+    fixtures.TestBase, testing.AssertsExecutionResults
+):
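+    """Tests that the stream_results execution option (and the deprecated
+    create_engine(server_side_cursors=True) flag) actually produce an
+    unbuffered, server-side DBAPI cursor."""
+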
+    __requires__ = ("server_side_cursors",)
+
+    __backend__ = True
+
+    def _is_server_side(self, cursor):
+        # TODO: this is a huge issue as it prevents these tests from being
+        # usable by third party dialects.
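+        # there is no DBAPI-standard way to ask a cursor whether it is
+        # server-side, so each branch below inspects a driver-specific
+        # attribute or cursor subclass instead.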
+        if self.engine.dialect.driver == "psycopg2":
+            return bool(cursor.name)
+        elif self.engine.dialect.driver == "pymysql":
+            sscursor = __import__("pymysql.cursors").cursors.SSCursor
+            return isinstance(cursor, sscursor)
+        elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
+            return cursor.server_side
+        elif self.engine.dialect.driver == "mysqldb":
+            sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
+            return isinstance(cursor, sscursor)
+        elif self.engine.dialect.driver == "mariadbconnector":
+            return not cursor.buffered
+        elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
+            return cursor.server_side
+        elif self.engine.dialect.driver == "pg8000":
+            return getattr(cursor, "server_side", False)
+        elif self.engine.dialect.driver == "psycopg":
+            return bool(getattr(cursor, "name", False))
+        elif self.engine.dialect.driver == "oracledb":
+            return getattr(cursor, "server_side", False)
+        else:
+            return False
+
+    def _fixture(self, server_side_cursors):
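+        # the engine-wide server_side_cursors flag is deprecated, so the
+        # True case must be constructed under expect_deprecated; the False
+        # case emits no warning and is built directly.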
+        if server_side_cursors:
+            with testing.expect_deprecated(
+                "The create_engine.server_side_cursors parameter is "
+                "deprecated and will be removed in a future release.  "
+                "Please use the Connection.execution_options.stream_results "
+                "parameter."
+            ):
+                self.engine = engines.testing_engine(
+                    options={"server_side_cursors": server_side_cursors}
+                )
+        else:
+            self.engine = engines.testing_engine(
+                options={"server_side_cursors": server_side_cursors}
+            )
+        return self.engine
+
+    def stringify(self, str_):
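+        # rewrite literal "SELECT <n>" strings as the dialect's compiled form
+        # of select(<n>), so backends that disallow a bare SELECT of a
+        # constant (e.g. Oracle, which requires FROM DUAL) can still run it.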
+        return re.compile(r"SELECT (\d+)", re.I).sub(
+            lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
+        )
+
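+    # per id_="iaaa" below, the first element of each tuple names the test
+    # variant; the remaining three map onto engine_ss_arg, statement and
+    # cursor_ss_status.  callable statements are resolved via stringify().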
+    @testing.combinations(
+        ("global_string", True, lambda stringify: stringify("select 1"), True),
+        (
+            "global_text",
+            True,
+            lambda stringify: text(stringify("select 1")),
+            True,
+        ),
+        ("global_expr", True, select(1), True),
+        (
+            "global_off_explicit",
+            False,
+            lambda stringify: text(stringify("select 1")),
+            False,
+        ),
+        (
+            "stmt_option",
+            False,
+            select(1).execution_options(stream_results=True),
+            True,
+        ),
+        (
+            "stmt_option_disabled",
+            True,
+            select(1).execution_options(stream_results=False),
+            False,
+        ),
+        ("for_update_expr", True, select(1).with_for_update(), True),
+        # TODO: need a real requirement for this, or don't use this test
+        (
+            "for_update_string",
+            True,
+            lambda stringify: stringify("SELECT 1 FOR UPDATE"),
+            True,
+            testing.skip_if(["sqlite", "mssql"]),
+        ),
+        (
+            "text_no_ss",
+            False,
+            lambda stringify: text(stringify("select 42")),
+            False,
+        ),
+        (
+            "text_ss_option",
+            False,
+            lambda stringify: text(stringify("select 42")).execution_options(
+                stream_results=True
+            ),
+            True,
+        ),
+        id_="iaaa",
+        argnames="engine_ss_arg, statement, cursor_ss_status",
+    )
+    def test_ss_cursor_status(
+        self, engine_ss_arg, statement, cursor_ss_status
+    ):
+        engine = self._fixture(engine_ss_arg)
+        with engine.begin() as conn:
+            if callable(statement):
+                statement = testing.resolve_lambda(
+                    statement, stringify=self.stringify
+                )
+
+            if isinstance(statement, str):
+                result = conn.exec_driver_sql(statement)
+            else:
+                result = conn.execute(statement)
+            eq_(self._is_server_side(result.cursor), cursor_ss_status)
+            result.close()
+
+    def test_conn_option(self):
+        engine = self._fixture(False)
+
+        with engine.connect() as conn:
+            # should be enabled for this one
+            result = conn.execution_options(
+                stream_results=True
+            ).exec_driver_sql(self.stringify("select 1"))
+            assert self._is_server_side(result.cursor)
+
+            # the connection has autobegun, which means at the end of the
+            # block, we will roll back, which on MySQL at least will fail
+            # with "Commands out of sync" if the result set
+            # is not closed, so we close it first.
+            #
+            # fun fact!  why did we not have this result.close() in this test
+            # before 2.0? don't we roll back in the connection pool
+            # unconditionally? yes!  and in fact if you run this test in 1.4
+            # with stdout shown, there is in fact "Exception during reset or
+            # similar" with "Commands out sync" emitted a warning!  2.0's
+            # architecture finds and fixes what was previously an expensive
+            # silent error condition.
+            result.close()
+
+    def test_stmt_enabled_conn_option_disabled(self):
+        engine = self._fixture(False)
+
+        s = select(1).execution_options(stream_results=True)
+
+        with engine.connect() as conn:
+            # the connection-level option takes precedence; not server-side here
+            result = conn.execution_options(stream_results=False).execute(s)
+            assert not self._is_server_side(result.cursor)
+
+    def test_aliases_and_ss(self):
+        engine = self._fixture(False)
+        s1 = (
+            select(sql.literal_column("1").label("x"))
+            .execution_options(stream_results=True)
+            .subquery()
+        )
+
+        # options don't propagate out when the subquery is used as a FROM clause
+        with engine.begin() as conn:
+            result = conn.execute(s1.select())
+            assert not self._is_server_side(result.cursor)
+            result.close()
+
+        s2 = select(1).select_from(s1)
+        with engine.begin() as conn:
+            result = conn.execute(s2)
+            assert not self._is_server_side(result.cursor)
+            result.close()
+
+    def test_roundtrip_fetchall(self, metadata):
+        md = self.metadata
+
+        engine = self._fixture(True)
+        test_table = Table(
+            "test_table",
+            md,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+        )
+
+        with engine.begin() as connection:
+            test_table.create(connection, checkfirst=True)
+            connection.execute(test_table.insert(), dict(data="data1"))
+            connection.execute(test_table.insert(), dict(data="data2"))
+            eq_(
+                connection.execute(
+                    test_table.select().order_by(test_table.c.id)
+                ).fetchall(),
+                [(1, "data1"), (2, "data2")],
+            )
+            connection.execute(
+                test_table.update()
+                .where(test_table.c.id == 2)
+                .values(data=test_table.c.data + " updated")
+            )
+            eq_(
+                connection.execute(
+                    test_table.select().order_by(test_table.c.id)
+                ).fetchall(),
+                [(1, "data1"), (2, "data2 updated")],
+            )
+            connection.execute(test_table.delete())
+            eq_(
+                connection.scalar(
+                    select(func.count("*")).select_from(test_table)
+                ),
+                0,
+            )
+
+    def test_roundtrip_fetchmany(self, metadata):
+        md = self.metadata
+
+        engine = self._fixture(True)
+        test_table = Table(
+            "test_table",
+            md,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+        )
+
+        with engine.begin() as connection:
+            test_table.create(connection, checkfirst=True)
+            connection.execute(
+                test_table.insert(),
+                [dict(data="data%d" % i) for i in range(1, 20)],
+            )
+
+            result = connection.execute(
+                test_table.select().order_by(test_table.c.id)
+            )
+
+            eq_(
+                result.fetchmany(5),
+                [(i, "data%d" % i) for i in range(1, 6)],
+            )
+            eq_(
+                result.fetchmany(10),
+                [(i, "data%d" % i) for i in range(6, 16)],
+            )
+            eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])