about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py')
-rw-r--r--.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py351
1 files changed, 351 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py b/.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py
new file mode 100644
index 00000000..3b014d37
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/numpy/array_api/_creation_functions.py
@@ -0,0 +1,351 @@
+from __future__ import annotations
+
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from ._typing import (
+        Array,
+        Device,
+        Dtype,
+        NestedSequence,
+        SupportsBufferProtocol,
+    )
+    from collections.abc import Sequence
+from ._dtypes import _all_dtypes
+
+import numpy as np
+
+
+def _check_valid_dtype(dtype):
+    # Note: Only spelling dtypes as the dtype objects is supported.
+
+    # We use this instead of "dtype in _all_dtypes" because the dtype objects
+    # define equality with the sorts of things we want to disallow.
+    for d in (None,) + _all_dtypes:
+        if dtype is d:
+            return
+    raise ValueError("dtype must be one of the supported dtypes")
+
+
+def asarray(
+    obj: Union[
+        Array,
+        bool,
+        int,
+        float,
+        NestedSequence[bool | int | float],
+        SupportsBufferProtocol,
+    ],
+    /,
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+    copy: Optional[Union[bool, np._CopyMode]] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
+
+    See its docstring for more information.
+    """
+    # _array_object imports in this file are inside the functions to avoid
+    # circular imports
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    if copy in (False, np._CopyMode.IF_NEEDED):
+        # Note: copy=False is not yet implemented in np.asarray
+        raise NotImplementedError("copy=False is not yet implemented")
+    if isinstance(obj, Array):
+        if dtype is not None and obj.dtype != dtype:
+            copy = True
+        if copy in (True, np._CopyMode.ALWAYS):
+            return Array._new(np.array(obj._array, copy=True, dtype=dtype))
+        return obj
+    if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
+        # Give a better error message in this case. NumPy would convert this
+        # to an object array. TODO: This won't handle large integers in lists.
+        raise OverflowError("Integer out of bounds for array dtypes")
+    res = np.asarray(obj, dtype=dtype)
+    return Array._new(res)
+
+
+def arange(
+    start: Union[int, float],
+    /,
+    stop: Optional[Union[int, float]] = None,
+    step: Union[int, float] = 1,
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
+
+
+def empty(
+    shape: Union[int, Tuple[int, ...]],
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.empty(shape, dtype=dtype))
+
+
+def empty_like(
+    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.empty_like(x._array, dtype=dtype))
+
+
+def eye(
+    n_rows: int,
+    n_cols: Optional[int] = None,
+    /,
+    *,
+    k: int = 0,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
+
+
+def from_dlpack(x: object, /) -> Array:
+    from ._array_object import Array
+
+    return Array._new(np.from_dlpack(x))
+
+
+def full(
+    shape: Union[int, Tuple[int, ...]],
+    fill_value: Union[int, float],
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    if isinstance(fill_value, Array) and fill_value.ndim == 0:
+        fill_value = fill_value._array
+    res = np.full(shape, fill_value, dtype=dtype)
+    if res.dtype not in _all_dtypes:
+        # This will happen if the fill value is not something that NumPy
+        # coerces to one of the acceptable dtypes.
+        raise TypeError("Invalid input to full")
+    return Array._new(res)
+
+
+def full_like(
+    x: Array,
+    /,
+    fill_value: Union[int, float],
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    res = np.full_like(x._array, fill_value, dtype=dtype)
+    if res.dtype not in _all_dtypes:
+        # This will happen if the fill value is not something that NumPy
+        # coerces to one of the acceptable dtypes.
+        raise TypeError("Invalid input to full_like")
+    return Array._new(res)
+
+
+def linspace(
+    start: Union[int, float],
+    stop: Union[int, float],
+    /,
+    num: int,
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+    endpoint: bool = True,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
+
+
+def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
+    """
+    Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    # Note: unlike np.meshgrid, only inputs with all the same dtype are
+    # allowed
+
+    if len({a.dtype for a in arrays}) > 1:
+        raise ValueError("meshgrid inputs must all have the same dtype")
+
+    return [
+        Array._new(array)
+        for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
+    ]
+
+
+def ones(
+    shape: Union[int, Tuple[int, ...]],
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.ones(shape, dtype=dtype))
+
+
+def ones_like(
+    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.ones_like(x._array, dtype=dtype))
+
+
+def tril(x: Array, /, *, k: int = 0) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    if x.ndim < 2:
+        # Note: Unlike np.tril, x must be at least 2-D
+        raise ValueError("x must be at least 2-dimensional for tril")
+    return Array._new(np.tril(x._array, k=k))
+
+
+def triu(x: Array, /, *, k: int = 0) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    if x.ndim < 2:
+        # Note: Unlike np.triu, x must be at least 2-D
+        raise ValueError("x must be at least 2-dimensional for triu")
+    return Array._new(np.triu(x._array, k=k))
+
+
+def zeros(
+    shape: Union[int, Tuple[int, ...]],
+    *,
+    dtype: Optional[Dtype] = None,
+    device: Optional[Device] = None,
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.zeros(shape, dtype=dtype))
+
+
+def zeros_like(
+    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
+    """
+    Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
+
+    See its docstring for more information.
+    """
+    from ._array_object import Array
+
+    _check_valid_dtype(dtype)
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return Array._new(np.zeros_like(x._array, dtype=dtype))