From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- .../numpy/array_api/_searching_functions.py | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .venv/lib/python3.12/site-packages/numpy/array_api/_searching_functions.py (limited to '.venv/lib/python3.12/site-packages/numpy/array_api/_searching_functions.py') diff --git a/.venv/lib/python3.12/site-packages/numpy/array_api/_searching_functions.py b/.venv/lib/python3.12/site-packages/numpy/array_api/_searching_functions.py new file mode 100644 index 00000000..a1f4b0c9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/numpy/array_api/_searching_functions.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _result_type, _real_numeric_dtypes + +from typing import Optional, Tuple + +import numpy as np + + +def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.argmax `. + + See its docstring for more information. + """ + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argmax") + return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + + +def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.argmin `. + + See its docstring for more information. + """ + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argmin") + return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + + +def nonzero(x: Array, /) -> Tuple[Array, ...]: + """ + Array API compatible wrapper for :py:func:`np.nonzero `. + + See its docstring for more information. + """ + return tuple(Array._new(i) for i in np.nonzero(x._array)) + + +def where(condition: Array, x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.where `. + + See its docstring for more information. + """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.where(condition._array, x1._array, x2._array)) -- cgit 1.4.1