diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/numpy/typing/mypy_plugin.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/typing/mypy_plugin.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/numpy/typing/mypy_plugin.py | 196 |
1 files changed, 196 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/typing/mypy_plugin.py b/.venv/lib/python3.12/site-packages/numpy/typing/mypy_plugin.py new file mode 100644 index 00000000..8ec96370 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/numpy/typing/mypy_plugin.py @@ -0,0 +1,196 @@ +"""A mypy_ plugin for managing a number of platform-specific annotations. +Its functionality can be split into three distinct parts: + +* Assigning the (platform-dependent) precisions of certain `~numpy.number` + subclasses, including the likes of `~numpy.int_`, `~numpy.intp` and + `~numpy.longlong`. See the documentation on + :ref:`scalar types <arrays.scalars.built-in>` for a comprehensive overview + of the affected classes. Without the plugin the precision of all relevant + classes will be inferred as `~typing.Any`. +* Removing all extended-precision `~numpy.number` subclasses that are + unavailable for the platform in question. Most notably this includes the + likes of `~numpy.float128` and `~numpy.complex256`. Without the plugin *all* + extended-precision types will, as far as mypy is concerned, be available + to all platforms. +* Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`. + Without the plugin the type will default to `ctypes.c_int64`. + + .. versionadded:: 1.22 + +Examples +-------- +To enable the plugin, one must add it to their mypy `configuration file`_: + +.. code-block:: ini + + [mypy] + plugins = numpy.typing.mypy_plugin + +.. _mypy: http://mypy-lang.org/ +.. _configuration file: https://mypy.readthedocs.io/en/stable/config_file.html + +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Final, TYPE_CHECKING, Callable + +import numpy as np + +try: + import mypy.types + from mypy.types import Type + from mypy.plugin import Plugin, AnalyzeTypeContext + from mypy.nodes import MypyFile, ImportFrom, Statement + from mypy.build import PRI_MED + + _HookFunc = Callable[[AnalyzeTypeContext], Type] + MYPY_EX: None | ModuleNotFoundError = None +except ModuleNotFoundError as ex: + MYPY_EX = ex + +__all__: list[str] = [] + + +def _get_precision_dict() -> dict[str, str]: + names = [ + ("_NBitByte", np.byte), + ("_NBitShort", np.short), + ("_NBitIntC", np.intc), + ("_NBitIntP", np.intp), + ("_NBitInt", np.int_), + ("_NBitLongLong", np.longlong), + + ("_NBitHalf", np.half), + ("_NBitSingle", np.single), + ("_NBitDouble", np.double), + ("_NBitLongDouble", np.longdouble), + ] + ret = {} + for name, typ in names: + n: int = 8 * typ().dtype.itemsize + ret[f'numpy._typing._nbit.{name}'] = f"numpy._{n}Bit" + return ret + + +def _get_extended_precision_list() -> list[str]: + extended_names = [ + "uint128", + "uint256", + "int128", + "int256", + "float80", + "float96", + "float128", + "float256", + "complex160", + "complex192", + "complex256", + "complex512", + ] + return [i for i in extended_names if hasattr(np, i)] + + +def _get_c_intp_name() -> str: + # Adapted from `np.core._internal._getintp_ctype` + char = np.dtype('p').char + if char == 'i': + return "c_int" + elif char == 'l': + return "c_long" + elif char == 'q': + return "c_longlong" + else: + return "c_long" + + +#: A dictionary mapping type-aliases in `numpy._typing._nbit` to +#: concrete `numpy.typing.NBitBase` subclasses. +_PRECISION_DICT: Final = _get_precision_dict() + +#: A list with the names of all extended precision `np.number` subclasses. +_EXTENDED_PRECISION_LIST: Final = _get_extended_precision_list() + +#: The name of the ctypes quivalent of `np.intp` +_C_INTP: Final = _get_c_intp_name() + + +def _hook(ctx: AnalyzeTypeContext) -> Type: + """Replace a type-alias with a concrete ``NBitBase`` subclass.""" + typ, _, api = ctx + name = typ.name.split(".")[-1] + name_new = _PRECISION_DICT[f"numpy._typing._nbit.{name}"] + return api.named_type(name_new) + + +if TYPE_CHECKING or MYPY_EX is None: + def _index(iterable: Iterable[Statement], id: str) -> int: + """Identify the first ``ImportFrom`` instance the specified `id`.""" + for i, value in enumerate(iterable): + if getattr(value, "id", None) == id: + return i + raise ValueError("Failed to identify a `ImportFrom` instance " + f"with the following id: {id!r}") + + def _override_imports( + file: MypyFile, + module: str, + imports: list[tuple[str, None | str]], + ) -> None: + """Override the first `module`-based import with new `imports`.""" + # Construct a new `from module import y` statement + import_obj = ImportFrom(module, 0, names=imports) + import_obj.is_top_level = True + + # Replace the first `module`-based import statement with `import_obj` + for lst in [file.defs, file.imports]: # type: list[Statement] + i = _index(lst, module) + lst[i] = import_obj + + class _NumpyPlugin(Plugin): + """A mypy plugin for handling versus numpy-specific typing tasks.""" + + def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc: + """Set the precision of platform-specific `numpy.number` + subclasses. + + For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`. + """ + if fullname in _PRECISION_DICT: + return _hook + return None + + def get_additional_deps( + self, file: MypyFile + ) -> list[tuple[int, str, int]]: + """Handle all import-based overrides. + + * Import platform-specific extended-precision `numpy.number` + subclasses (*e.g.* `numpy.float96`, `numpy.float128` and + `numpy.complex256`). + * Import the appropriate `ctypes` equivalent to `numpy.intp`. + + """ + ret = [(PRI_MED, file.fullname, -1)] + + if file.fullname == "numpy": + _override_imports( + file, "numpy._typing._extended_precision", + imports=[(v, v) for v in _EXTENDED_PRECISION_LIST], + ) + elif file.fullname == "numpy.ctypeslib": + _override_imports( + file, "ctypes", + imports=[(_C_INTP, "_c_intp")], + ) + return ret + + def plugin(version: str) -> type[_NumpyPlugin]: + """An entry-point for mypy.""" + return _NumpyPlugin + +else: + def plugin(version: str) -> type[_NumpyPlugin]: + """An entry-point for mypy.""" + raise MYPY_EX |