diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/f2py/tests/test_callback.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/numpy/f2py/tests/test_callback.py | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/f2py/tests/test_callback.py b/.venv/lib/python3.12/site-packages/numpy/f2py/tests/test_callback.py new file mode 100644 index 00000000..5b6c294d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/numpy/f2py/tests/test_callback.py @@ -0,0 +1,243 @@ +import math +import textwrap +import sys +import pytest +import threading +import traceback +import time + +import numpy as np +from numpy.testing import IS_PYPY +from . import util + + +class TestF77Callback(util.F2PyTest): + sources = [util.getpath("tests", "src", "callback", "foo.f")] + + @pytest.mark.parametrize("name", "t,t2".split(",")) + def test_all(self, name): + self.check_function(name) + + @pytest.mark.xfail(IS_PYPY, + reason="PyPy cannot modify tp_doc after PyType_Ready") + def test_docstring(self): + expected = textwrap.dedent("""\ + a = t(fun,[fun_extra_args]) + + Wrapper for ``t``. + + Parameters + ---------- + fun : call-back function + + Other Parameters + ---------------- + fun_extra_args : input tuple, optional + Default: () + + Returns + ------- + a : int + + Notes + ----- + Call-back functions:: + + def fun(): return a + Return objects: + a : int + """) + assert self.module.t.__doc__ == expected + + def check_function(self, name): + t = getattr(self.module, name) + r = t(lambda: 4) + assert r == 4 + r = t(lambda a: 5, fun_extra_args=(6, )) + assert r == 5 + r = t(lambda a: a, fun_extra_args=(6, )) + assert r == 6 + r = t(lambda a: 5 + a, fun_extra_args=(7, )) + assert r == 12 + r = t(lambda a: math.degrees(a), fun_extra_args=(math.pi, )) + assert r == 180 + r = t(math.degrees, fun_extra_args=(math.pi, )) + assert r == 180 + + r = t(self.module.func, fun_extra_args=(6, )) + assert r == 17 + r = t(self.module.func0) + assert r == 11 + r = t(self.module.func0._cpointer) + assert r == 11 + + class A: + def __call__(self): + return 7 + + def mth(self): + return 9 + + a = A() + r = t(a) + assert r == 7 + r = t(a.mth) + assert r == 9 + + @pytest.mark.skipif(sys.platform == 'win32', + reason='Fails with MinGW64 Gfortran (Issue #9673)') + def test_string_callback(self): + def callback(code): + if code == "r": + return 0 + else: + return 1 + + f = getattr(self.module, "string_callback") + r = f(callback) + assert r == 0 + + @pytest.mark.skipif(sys.platform == 'win32', + reason='Fails with MinGW64 Gfortran (Issue #9673)') + def test_string_callback_array(self): + # See gh-10027 + cu1 = np.zeros((1, ), "S8") + cu2 = np.zeros((1, 8), "c") + cu3 = np.array([""], "S8") + + def callback(cu, lencu): + if cu.shape != (lencu,): + return 1 + if cu.dtype != "S8": + return 2 + if not np.all(cu == b""): + return 3 + return 0 + + f = getattr(self.module, "string_callback_array") + for cu in [cu1, cu2, cu3]: + res = f(callback, cu, cu.size) + assert res == 0 + + def test_threadsafety(self): + # Segfaults if the callback handling is not threadsafe + + errors = [] + + def cb(): + # Sleep here to make it more likely for another thread + # to call their callback at the same time. + time.sleep(1e-3) + + # Check reentrancy + r = self.module.t(lambda: 123) + assert r == 123 + + return 42 + + def runner(name): + try: + for j in range(50): + r = self.module.t(cb) + assert r == 42 + self.check_function(name) + except Exception: + errors.append(traceback.format_exc()) + + threads = [ + threading.Thread(target=runner, args=(arg, )) + for arg in ("t", "t2") for n in range(20) + ] + + for t in threads: + t.start() + + for t in threads: + t.join() + + errors = "\n\n".join(errors) + if errors: + raise AssertionError(errors) + + def test_hidden_callback(self): + try: + self.module.hidden_callback(2) + except Exception as msg: + assert str(msg).startswith("Callback global_f not defined") + + try: + self.module.hidden_callback2(2) + except Exception as msg: + assert str(msg).startswith("cb: Callback global_f not defined") + + self.module.global_f = lambda x: x + 1 + r = self.module.hidden_callback(2) + assert r == 3 + + self.module.global_f = lambda x: x + 2 + r = self.module.hidden_callback(2) + assert r == 4 + + del self.module.global_f + try: + self.module.hidden_callback(2) + except Exception as msg: + assert str(msg).startswith("Callback global_f not defined") + + self.module.global_f = lambda x=0: x + 3 + r = self.module.hidden_callback(2) + assert r == 5 + + # reproducer of gh18341 + r = self.module.hidden_callback2(2) + assert r == 3 + + +class TestF77CallbackPythonTLS(TestF77Callback): + """ + Callback tests using Python thread-local storage instead of + compiler-provided + """ + + options = ["-DF2PY_USE_PYTHON_TLS"] + + +class TestF90Callback(util.F2PyTest): + sources = [util.getpath("tests", "src", "callback", "gh17797.f90")] + + def test_gh17797(self): + def incr(x): + return x + 123 + + y = np.array([1, 2, 3], dtype=np.int64) + r = self.module.gh17797(incr, y) + assert r == 123 + 1 + 2 + 3 + + +class TestGH18335(util.F2PyTest): + """The reproduction of the reported issue requires specific input that + extensions may break the issue conditions, so the reproducer is + implemented as a separate test class. Do not extend this test with + other tests! + """ + sources = [util.getpath("tests", "src", "callback", "gh18335.f90")] + + def test_gh18335(self): + def foo(x): + x[0] += 1 + + r = self.module.gh18335(foo) + assert r == 123 + 1 + + +class TestGH25211(util.F2PyTest): + sources = [util.getpath("tests", "src", "callback", "gh25211.f"), + util.getpath("tests", "src", "callback", "gh25211.pyf")] + module_name = "callback2" + + def test_gh18335(self): + def bar(x): + return x*x + + res = self.module.foo(bar) + assert res == 110 |