about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py')
-rw-r--r--.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py112
1 files changed, 112 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py b/.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py
new file mode 100644
index 00000000..556bde7d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/numpy/array_api/_manipulation_functions.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from ._array_object import Array
+from ._data_type_functions import result_type
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+# Note: the function name is different here
+def concat(
+    arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
+
+    See its docstring for more information.
+    """
+    # Note: Casting rules here are different from the np.concatenate default
+    # (no for scalars with axis=None, no cross-kind casting)
+    dtype = result_type(*arrays)
+    arrays = tuple(a._array for a in arrays)
+    return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
+
+
+def expand_dims(x: Array, /, *, axis: int) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
+
+    See its docstring for more information.
+    """
+    return Array._new(np.expand_dims(x._array, axis))
+
+
+def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
+
+    See its docstring for more information.
+    """
+    return Array._new(np.flip(x._array, axis=axis))
+
+
+# Note: The function name is different here (see also matrix_transpose).
+# Unlike transpose(), the axes argument is required.
+def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
+
+    See its docstring for more information.
+    """
+    return Array._new(np.transpose(x._array, axes))
+
+
+# Note: the optional argument is called 'shape', not 'newshape'
+def reshape(x: Array, 
+            /, 
+            shape: Tuple[int, ...],
+            *,
+            copy: Optional[Bool] = None) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
+
+    See its docstring for more information.
+    """
+
+    data = x._array
+    if copy:
+        data = np.copy(data)
+
+    reshaped = np.reshape(data, shape)
+
+    if copy is False and not np.shares_memory(data, reshaped):
+        raise AttributeError("Incompatible shape for in-place modification.")
+
+    return Array._new(reshaped)
+
+
+def roll(
+    x: Array,
+    /,
+    shift: Union[int, Tuple[int, ...]],
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
+
+    See its docstring for more information.
+    """
+    return Array._new(np.roll(x._array, shift, axis=axis))
+
+
+def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
+
+    See its docstring for more information.
+    """
+    return Array._new(np.squeeze(x._array, axis=axis))
+
+
+def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
+
+    See its docstring for more information.
+    """
+    # Call result type here just to raise on disallowed type combinations
+    result_type(*arrays)
+    arrays = tuple(a._array for a in arrays)
+    return Array._new(np.stack(arrays, axis=axis))