diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/numpy/array_api/_set_functions.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/array_api/_set_functions.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/numpy/array_api/_set_functions.py | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/array_api/_set_functions.py b/.venv/lib/python3.12/site-packages/numpy/array_api/_set_functions.py new file mode 100644 index 00000000..0b4132cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/numpy/array_api/_set_functions.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from ._array_object import Array + +from typing import NamedTuple + +import numpy as np + +# Note: np.unique() is split into four functions in the array API: +# unique_all, unique_counts, unique_inverse, and unique_values (this is done +# to remove polymorphic return types). + +# Note: The various unique() functions are supposed to return multiple NaNs. +# This does not match the NumPy behavior, however, this is currently left as a +# TODO in this implementation as this behavior may be reverted in np.unique(). +# See https://github.com/numpy/numpy/issues/20326. + +# Note: The functions here return a namedtuple (np.unique() returns a normal +# tuple). + +class UniqueAllResult(NamedTuple): + values: Array + indices: Array + inverse_indices: Array + counts: Array + + +class UniqueCountsResult(NamedTuple): + values: Array + counts: Array + + +class UniqueInverseResult(NamedTuple): + values: Array + inverse_indices: Array + + +def unique_all(x: Array, /) -> UniqueAllResult: + """ + Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. + + See its docstring for more information. + """ + values, indices, inverse_indices, counts = np.unique( + x._array, + return_counts=True, + return_index=True, + return_inverse=True, + equal_nan=False, + ) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueAllResult( + Array._new(values), + Array._new(indices), + Array._new(inverse_indices), + Array._new(counts), + ) + + +def unique_counts(x: Array, /) -> UniqueCountsResult: + res = np.unique( + x._array, + return_counts=True, + return_index=False, + return_inverse=False, + equal_nan=False, + ) + + return UniqueCountsResult(*[Array._new(i) for i in res]) + + +def unique_inverse(x: Array, /) -> UniqueInverseResult: + """ + Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. + + See its docstring for more information. + """ + values, inverse_indices = np.unique( + x._array, + return_counts=False, + return_index=False, + return_inverse=True, + equal_nan=False, + ) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) + + +def unique_values(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. + + See its docstring for more information. + """ + res = np.unique( + x._array, + return_counts=False, + return_index=False, + return_inverse=False, + equal_nan=False, + ) + return Array._new(res) |