about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/numpy/core/tests/test_extint128.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/core/tests/test_extint128.py')
-rw-r--r--.venv/lib/python3.12/site-packages/numpy/core/tests/test_extint128.py219
1 files changed, 219 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/core/tests/test_extint128.py b/.venv/lib/python3.12/site-packages/numpy/core/tests/test_extint128.py
new file mode 100644
index 00000000..3b64915f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/numpy/core/tests/test_extint128.py
@@ -0,0 +1,219 @@
+import itertools
+import contextlib
+import operator
+import pytest
+
+import numpy as np
+import numpy.core._multiarray_tests as mt
+
+from numpy.testing import assert_raises, assert_equal
+
+
+INT64_MAX = np.iinfo(np.int64).max
+INT64_MIN = np.iinfo(np.int64).min
+INT64_MID = 2**32
+
+# int128 is not two's complement, the sign bit is separate
+INT128_MAX = 2**128 - 1
+INT128_MIN = -INT128_MAX
+INT128_MID = 2**64
+
+INT64_VALUES = (
+    [INT64_MIN + j for j in range(20)] +
+    [INT64_MAX - j for j in range(20)] +
+    [INT64_MID + j for j in range(-20, 20)] +
+    [2*INT64_MID + j for j in range(-20, 20)] +
+    [INT64_MID//2 + j for j in range(-20, 20)] +
+    list(range(-70, 70))
+)
+
+INT128_VALUES = (
+    [INT128_MIN + j for j in range(20)] +
+    [INT128_MAX - j for j in range(20)] +
+    [INT128_MID + j for j in range(-20, 20)] +
+    [2*INT128_MID + j for j in range(-20, 20)] +
+    [INT128_MID//2 + j for j in range(-20, 20)] +
+    list(range(-70, 70)) +
+    [False]  # negative zero
+)
+
+INT64_POS_VALUES = [x for x in INT64_VALUES if x > 0]
+
+
+@contextlib.contextmanager
+def exc_iter(*args):
+    """
+    Iterate over Cartesian product of *args, and if an exception is raised,
+    add information of the current iterate.
+    """
+
+    value = [None]
+
+    def iterate():
+        for v in itertools.product(*args):
+            value[0] = v
+            yield v
+
+    try:
+        yield iterate()
+    except Exception:
+        import traceback
+        msg = "At: %r\n%s" % (repr(value[0]),
+                              traceback.format_exc())
+        raise AssertionError(msg)
+
+
+def test_safe_binop():
+    # Test checked arithmetic routines
+
+    ops = [
+        (operator.add, 1),
+        (operator.sub, 2),
+        (operator.mul, 3)
+    ]
+
+    with exc_iter(ops, INT64_VALUES, INT64_VALUES) as it:
+        for xop, a, b in it:
+            pyop, op = xop
+            c = pyop(a, b)
+
+            if not (INT64_MIN <= c <= INT64_MAX):
+                assert_raises(OverflowError, mt.extint_safe_binop, a, b, op)
+            else:
+                d = mt.extint_safe_binop(a, b, op)
+                if c != d:
+                    # assert_equal is slow
+                    assert_equal(d, c)
+
+
+def test_to_128():
+    with exc_iter(INT64_VALUES) as it:
+        for a, in it:
+            b = mt.extint_to_128(a)
+            if a != b:
+                assert_equal(b, a)
+
+
+def test_to_64():
+    with exc_iter(INT128_VALUES) as it:
+        for a, in it:
+            if not (INT64_MIN <= a <= INT64_MAX):
+                assert_raises(OverflowError, mt.extint_to_64, a)
+            else:
+                b = mt.extint_to_64(a)
+                if a != b:
+                    assert_equal(b, a)
+
+
+def test_mul_64_64():
+    with exc_iter(INT64_VALUES, INT64_VALUES) as it:
+        for a, b in it:
+            c = a * b
+            d = mt.extint_mul_64_64(a, b)
+            if c != d:
+                assert_equal(d, c)
+
+
+def test_add_128():
+    with exc_iter(INT128_VALUES, INT128_VALUES) as it:
+        for a, b in it:
+            c = a + b
+            if not (INT128_MIN <= c <= INT128_MAX):
+                assert_raises(OverflowError, mt.extint_add_128, a, b)
+            else:
+                d = mt.extint_add_128(a, b)
+                if c != d:
+                    assert_equal(d, c)
+
+
+def test_sub_128():
+    with exc_iter(INT128_VALUES, INT128_VALUES) as it:
+        for a, b in it:
+            c = a - b
+            if not (INT128_MIN <= c <= INT128_MAX):
+                assert_raises(OverflowError, mt.extint_sub_128, a, b)
+            else:
+                d = mt.extint_sub_128(a, b)
+                if c != d:
+                    assert_equal(d, c)
+
+
+def test_neg_128():
+    with exc_iter(INT128_VALUES) as it:
+        for a, in it:
+            b = -a
+            c = mt.extint_neg_128(a)
+            if b != c:
+                assert_equal(c, b)
+
+
+def test_shl_128():
+    with exc_iter(INT128_VALUES) as it:
+        for a, in it:
+            if a < 0:
+                b = -(((-a) << 1) & (2**128-1))
+            else:
+                b = (a << 1) & (2**128-1)
+            c = mt.extint_shl_128(a)
+            if b != c:
+                assert_equal(c, b)
+
+
+def test_shr_128():
+    with exc_iter(INT128_VALUES) as it:
+        for a, in it:
+            if a < 0:
+                b = -((-a) >> 1)
+            else:
+                b = a >> 1
+            c = mt.extint_shr_128(a)
+            if b != c:
+                assert_equal(c, b)
+
+
+def test_gt_128():
+    with exc_iter(INT128_VALUES, INT128_VALUES) as it:
+        for a, b in it:
+            c = a > b
+            d = mt.extint_gt_128(a, b)
+            if c != d:
+                assert_equal(d, c)
+
+
+@pytest.mark.slow
+def test_divmod_128_64():
+    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
+        for a, b in it:
+            if a >= 0:
+                c, cr = divmod(a, b)
+            else:
+                c, cr = divmod(-a, b)
+                c = -c
+                cr = -cr
+
+            d, dr = mt.extint_divmod_128_64(a, b)
+
+            if c != d or d != dr or b*d + dr != a:
+                assert_equal(d, c)
+                assert_equal(dr, cr)
+                assert_equal(b*d + dr, a)
+
+
+def test_floordiv_128_64():
+    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
+        for a, b in it:
+            c = a // b
+            d = mt.extint_floordiv_128_64(a, b)
+
+            if c != d:
+                assert_equal(d, c)
+
+
+def test_ceildiv_128_64():
+    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
+        for a, b in it:
+            c = (a + b - 1) // b
+            d = mt.extint_ceildiv_128_64(a, b)
+
+            if c != d:
+                assert_equal(d, c)