about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py630
1 files changed, 630 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py
new file mode 100644
index 00000000..8467c351
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/testing/suite/test_insert.py
@@ -0,0 +1,630 @@
+# testing/suite/test_insert.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from decimal import Decimal
+import uuid
+
+from . import testing
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import Double
+from ... import Float
+from ... import Identity
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import Numeric
+from ... import select
+from ... import String
+from ...types import LargeBinary
+from ...types import UUID
+from ...types import Uuid
+
+
+class LastrowidTest(fixtures.TablesTest):
+    run_deletes = "each"
+
+    __backend__ = True
+
+    __requires__ = "implements_get_lastrowid", "autoincrement_insert"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "autoinc_pk",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+            implicit_returning=False,
+        )
+
+        Table(
+            "manual_pk",
+            metadata,
+            Column("id", Integer, primary_key=True, autoincrement=False),
+            Column("data", String(50)),
+            implicit_returning=False,
+        )
+
+    def _assert_round_trip(self, table, conn):
+        row = conn.execute(table.select()).first()
+        eq_(
+            row,
+            (
+                conn.dialect.default_sequence_base,
+                "some data",
+            ),
+        )
+
+    def test_autoincrement_on_insert(self, connection):
+        connection.execute(
+            self.tables.autoinc_pk.insert(), dict(data="some data")
+        )
+        self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+    def test_last_inserted_id(self, connection):
+        r = connection.execute(
+            self.tables.autoinc_pk.insert(), dict(data="some data")
+        )
+        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+        eq_(r.inserted_primary_key, (pk,))
+
+    @requirements.dbapi_lastrowid
+    def test_native_lastrowid_autoinc(self, connection):
+        r = connection.execute(
+            self.tables.autoinc_pk.insert(), dict(data="some data")
+        )
+        lastrowid = r.lastrowid
+        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+        eq_(lastrowid, pk)
+
+
+class InsertBehaviorTest(fixtures.TablesTest):
+    run_deletes = "each"
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "autoinc_pk",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+        )
+        Table(
+            "manual_pk",
+            metadata,
+            Column("id", Integer, primary_key=True, autoincrement=False),
+            Column("data", String(50)),
+        )
+        Table(
+            "no_implicit_returning",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+            implicit_returning=False,
+        )
+        Table(
+            "includes_defaults",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+            Column("x", Integer, default=5),
+            Column(
+                "y",
+                Integer,
+                default=literal_column("2", type_=Integer) + literal(2),
+            ),
+        )
+
+    @testing.variation("style", ["plain", "return_defaults"])
+    @testing.variation("executemany", [True, False])
+    def test_no_results_for_non_returning_insert(
+        self, connection, style, executemany
+    ):
+        """test another INSERT issue found during #10453"""
+
+        table = self.tables.no_implicit_returning
+
+        stmt = table.insert()
+        if style.return_defaults:
+            stmt = stmt.return_defaults()
+
+        if executemany:
+            data = [
+                {"data": "d1"},
+                {"data": "d2"},
+                {"data": "d3"},
+                {"data": "d4"},
+                {"data": "d5"},
+            ]
+        else:
+            data = {"data": "d1"}
+
+        r = connection.execute(stmt, data)
+        assert not r.returns_rows
+
+    @requirements.autoincrement_insert
+    def test_autoclose_on_insert(self, connection):
+        r = connection.execute(
+            self.tables.autoinc_pk.insert(), dict(data="some data")
+        )
+        assert r._soft_closed
+        assert not r.closed
+        assert r.is_insert
+
+        # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
+        # an insert where the PK was taken from a row that the dialect
+        # selected, as is the case for mssql/pyodbc, will still report
+        # returns_rows as true because there's a cursor description.  in that
+        # case, the row had to have been consumed at least.
+        assert not r.returns_rows or r.fetchone() is None
+
+    @requirements.insert_returning
+    def test_autoclose_on_insert_implicit_returning(self, connection):
+        r = connection.execute(
+            # return_defaults() ensures RETURNING will be used,
+            # new in 2.0 as sqlite/mariadb offer both RETURNING and
+            # cursor.lastrowid
+            self.tables.autoinc_pk.insert().return_defaults(),
+            dict(data="some data"),
+        )
+        assert r._soft_closed
+        assert not r.closed
+        assert r.is_insert
+
+        # note we are experimenting with having this be True
+        # as of I8091919d45421e3f53029b8660427f844fee0228 .
+        # implicit returning has fetched the row, but it still is a
+        # "returns rows"
+        assert r.returns_rows
+
+        # and we should be able to fetchone() on it, we just get no row
+        eq_(r.fetchone(), None)
+
+        # and the keys, etc.
+        eq_(r.keys(), ["id"])
+
+        # but the dialect took in the row already.   not really sure
+        # what the best behavior is.
+
+    @requirements.empty_inserts
+    def test_empty_insert(self, connection):
+        r = connection.execute(self.tables.autoinc_pk.insert())
+        assert r._soft_closed
+        assert not r.closed
+
+        r = connection.execute(
+            self.tables.autoinc_pk.select().where(
+                self.tables.autoinc_pk.c.id != None
+            )
+        )
+        eq_(len(r.all()), 1)
+
+    @requirements.empty_inserts_executemany
+    def test_empty_insert_multiple(self, connection):
+        r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
+        assert r._soft_closed
+        assert not r.closed
+
+        r = connection.execute(
+            self.tables.autoinc_pk.select().where(
+                self.tables.autoinc_pk.c.id != None
+            )
+        )
+
+        eq_(len(r.all()), 3)
+
+    @requirements.insert_from_select
+    def test_insert_from_select_autoinc(self, connection):
+        src_table = self.tables.manual_pk
+        dest_table = self.tables.autoinc_pk
+        connection.execute(
+            src_table.insert(),
+            [
+                dict(id=1, data="data1"),
+                dict(id=2, data="data2"),
+                dict(id=3, data="data3"),
+            ],
+        )
+
+        result = connection.execute(
+            dest_table.insert().from_select(
+                ("data",),
+                select(src_table.c.data).where(
+                    src_table.c.data.in_(["data2", "data3"])
+                ),
+            )
+        )
+
+        eq_(result.inserted_primary_key, (None,))
+
+        result = connection.execute(
+            select(dest_table.c.data).order_by(dest_table.c.data)
+        )
+        eq_(result.fetchall(), [("data2",), ("data3",)])
+
+    @requirements.insert_from_select
+    def test_insert_from_select_autoinc_no_rows(self, connection):
+        src_table = self.tables.manual_pk
+        dest_table = self.tables.autoinc_pk
+
+        result = connection.execute(
+            dest_table.insert().from_select(
+                ("data",),
+                select(src_table.c.data).where(
+                    src_table.c.data.in_(["data2", "data3"])
+                ),
+            )
+        )
+        eq_(result.inserted_primary_key, (None,))
+
+        result = connection.execute(
+            select(dest_table.c.data).order_by(dest_table.c.data)
+        )
+
+        eq_(result.fetchall(), [])
+
+    @requirements.insert_from_select
+    def test_insert_from_select(self, connection):
+        table = self.tables.manual_pk
+        connection.execute(
+            table.insert(),
+            [
+                dict(id=1, data="data1"),
+                dict(id=2, data="data2"),
+                dict(id=3, data="data3"),
+            ],
+        )
+
+        connection.execute(
+            table.insert()
+            .inline()
+            .from_select(
+                ("id", "data"),
+                select(table.c.id + 5, table.c.data).where(
+                    table.c.data.in_(["data2", "data3"])
+                ),
+            )
+        )
+
+        eq_(
+            connection.execute(
+                select(table.c.data).order_by(table.c.data)
+            ).fetchall(),
+            [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
+        )
+
+    @requirements.insert_from_select
+    def test_insert_from_select_with_defaults(self, connection):
+        table = self.tables.includes_defaults
+        connection.execute(
+            table.insert(),
+            [
+                dict(id=1, data="data1"),
+                dict(id=2, data="data2"),
+                dict(id=3, data="data3"),
+            ],
+        )
+
+        connection.execute(
+            table.insert()
+            .inline()
+            .from_select(
+                ("id", "data"),
+                select(table.c.id + 5, table.c.data).where(
+                    table.c.data.in_(["data2", "data3"])
+                ),
+            )
+        )
+
+        eq_(
+            connection.execute(
+                select(table).order_by(table.c.data, table.c.id)
+            ).fetchall(),
+            [
+                (1, "data1", 5, 4),
+                (2, "data2", 5, 4),
+                (7, "data2", 5, 4),
+                (3, "data3", 5, 4),
+                (8, "data3", 5, 4),
+            ],
+        )
+
+
+class ReturningTest(fixtures.TablesTest):
+    run_create_tables = "each"
+    __requires__ = "insert_returning", "autoincrement_insert"
+    __backend__ = True
+
+    def _assert_round_trip(self, table, conn):
+        row = conn.execute(table.select()).first()
+        eq_(
+            row,
+            (
+                conn.dialect.default_sequence_base,
+                "some data",
+            ),
+        )
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "autoinc_pk",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+        )
+
+    @requirements.fetch_rows_post_commit
+    def test_explicit_returning_pk_autocommit(self, connection):
+        table = self.tables.autoinc_pk
+        r = connection.execute(
+            table.insert().returning(table.c.id), dict(data="some data")
+        )
+        pk = r.first()[0]
+        fetched_pk = connection.scalar(select(table.c.id))
+        eq_(fetched_pk, pk)
+
+    def test_explicit_returning_pk_no_autocommit(self, connection):
+        table = self.tables.autoinc_pk
+        r = connection.execute(
+            table.insert().returning(table.c.id), dict(data="some data")
+        )
+
+        pk = r.first()[0]
+        fetched_pk = connection.scalar(select(table.c.id))
+        eq_(fetched_pk, pk)
+
+    def test_autoincrement_on_insert_implicit_returning(self, connection):
+        connection.execute(
+            self.tables.autoinc_pk.insert(), dict(data="some data")
+        )
+        self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+    def test_last_inserted_id_implicit_returning(self, connection):
+        r = connection.execute(
+            self.tables.autoinc_pk.insert(), dict(data="some data")
+        )
+        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+        eq_(r.inserted_primary_key, (pk,))
+
+    @requirements.insert_executemany_returning
+    def test_insertmanyvalues_returning(self, connection):
+        r = connection.execute(
+            self.tables.autoinc_pk.insert().returning(
+                self.tables.autoinc_pk.c.id
+            ),
+            [
+                {"data": "d1"},
+                {"data": "d2"},
+                {"data": "d3"},
+                {"data": "d4"},
+                {"data": "d5"},
+            ],
+        )
+        rall = r.all()
+
+        pks = connection.execute(select(self.tables.autoinc_pk.c.id))
+
+        eq_(rall, pks.all())
+
+    @testing.combinations(
+        (Double(), 8.5514716, True),
+        (
+            Double(53),
+            8.5514716,
+            True,
+            testing.requires.float_or_double_precision_behaves_generically,
+        ),
+        (Float(), 8.5514, True),
+        (
+            Float(8),
+            8.5514,
+            True,
+            testing.requires.float_or_double_precision_behaves_generically,
+        ),
+        (
+            Numeric(precision=15, scale=12, asdecimal=False),
+            8.5514716,
+            True,
+            testing.requires.literal_float_coercion,
+        ),
+        (
+            Numeric(precision=15, scale=12, asdecimal=True),
+            Decimal("8.5514716"),
+            False,
+        ),
+        argnames="type_,value,do_rounding",
+    )
+    @testing.variation("sort_by_parameter_order", [True, False])
+    @testing.variation("multiple_rows", [True, False])
+    def test_insert_w_floats(
+        self,
+        connection,
+        metadata,
+        sort_by_parameter_order,
+        type_,
+        value,
+        do_rounding,
+        multiple_rows,
+    ):
+        """test #9701.
+
+        this tests insertmanyvalues as well as decimal / floating point
+        RETURNING types
+
+        """
+
+        t = Table(
+            # Oracle backends seems to be getting confused if
+            # this table is named the same as the one
+            # in test_imv_returning_datatypes.  use a different name
+            "f_t",
+            metadata,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("value", type_),
+        )
+
+        t.create(connection)
+
+        result = connection.execute(
+            t.insert().returning(
+                t.c.id,
+                t.c.value,
+                sort_by_parameter_order=bool(sort_by_parameter_order),
+            ),
+            (
+                [{"value": value} for i in range(10)]
+                if multiple_rows
+                else {"value": value}
+            ),
+        )
+
+        if multiple_rows:
+            i_range = range(1, 11)
+        else:
+            i_range = range(1, 2)
+
+        # we want to test only that we are getting floating points back
+        # with some degree of the original value maintained, that it is not
+        # being truncated to an integer.  there's too much variation in how
+        # drivers return floats, which should not be relied upon to be
+        # exact, for us to just compare as is (works for PG drivers but not
+        # others) so we use rounding here.  There's precedent for this
+        # in suite/test_types.py::NumericTest as well
+
+        if do_rounding:
+            eq_(
+                {(id_, round(val_, 5)) for id_, val_ in result},
+                {(id_, round(value, 5)) for id_ in i_range},
+            )
+
+            eq_(
+                {
+                    round(val_, 5)
+                    for val_ in connection.scalars(select(t.c.value))
+                },
+                {round(value, 5)},
+            )
+        else:
+            eq_(
+                set(result),
+                {(id_, value) for id_ in i_range},
+            )
+
+            eq_(
+                set(connection.scalars(select(t.c.value))),
+                {value},
+            )
+
+    @testing.combinations(
+        (
+            "non_native_uuid",
+            Uuid(native_uuid=False),
+            uuid.uuid4(),
+        ),
+        (
+            "non_native_uuid_str",
+            Uuid(as_uuid=False, native_uuid=False),
+            str(uuid.uuid4()),
+        ),
+        (
+            "generic_native_uuid",
+            Uuid(native_uuid=True),
+            uuid.uuid4(),
+            testing.requires.uuid_data_type,
+        ),
+        (
+            "generic_native_uuid_str",
+            Uuid(as_uuid=False, native_uuid=True),
+            str(uuid.uuid4()),
+            testing.requires.uuid_data_type,
+        ),
+        ("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type),
+        (
+            "LargeBinary1",
+            LargeBinary(),
+            b"this is binary",
+        ),
+        ("LargeBinary2", LargeBinary(), b"7\xe7\x9f"),
+        argnames="type_,value",
+        id_="iaa",
+    )
+    @testing.variation("sort_by_parameter_order", [True, False])
+    @testing.variation("multiple_rows", [True, False])
+    @testing.requires.insert_returning
+    def test_imv_returning_datatypes(
+        self,
+        connection,
+        metadata,
+        sort_by_parameter_order,
+        type_,
+        value,
+        multiple_rows,
+    ):
+        """test #9739, #9808 (similar to #9701).
+
+        this tests insertmanyvalues in conjunction with various datatypes.
+
+        These tests are particularly for the asyncpg driver which needs
+        most types to be explicitly cast for the new IMV format
+
+        """
+        t = Table(
+            "d_t",
+            metadata,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("value", type_),
+        )
+
+        t.create(connection)
+
+        result = connection.execute(
+            t.insert().returning(
+                t.c.id,
+                t.c.value,
+                sort_by_parameter_order=bool(sort_by_parameter_order),
+            ),
+            (
+                [{"value": value} for i in range(10)]
+                if multiple_rows
+                else {"value": value}
+            ),
+        )
+
+        if multiple_rows:
+            i_range = range(1, 11)
+        else:
+            i_range = range(1, 2)
+
+        eq_(
+            set(result),
+            {(id_, value) for id_ in i_range},
+        )
+
+        eq_(
+            set(connection.scalars(select(t.c.value))),
+            {value},
+        )
+
+
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")