"""Generic utility methods not part of main API."""

from __future__ import annotations

import builtins
from collections.abc import Hashable, Iterable, Mapping, Sequence
from decimal import Decimal
from functools import wraps
import inspect
import operator
import warnings

import pydash as pyd


#: Singleton object that differentiates between an explicit ``None`` value and an unset value.
#: As a class so it has its own type
class Unset: ...


UNSET = Unset()

#: Tuple of number types.
NUMBER_TYPES = (int, float, Decimal)

#: Dictionary of builtins with keys as the builtin function and values as the string name.
BUILTINS = {value: key for key, value in builtins.__dict__.items() if isinstance(value, Hashable)}

#: Object keys that are restricted from access via path access.
RESTRICTED_KEYS = ("__globals__", "__builtins__")

#: Inspect signature parameter kinds that correspond to positional arguments.
POSITIONAL_PARAMETERS = (
    inspect.Parameter.VAR_POSITIONAL,
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def callit(iteratee, *args, **kwargs):
    """
    Call `iteratee` with only as many leading positional arguments as it supports.

    Args:
        iteratee: Callable to invoke.
        *args: Candidate positional arguments. Only the first ``min(len(args), argcount)`` of
            them are passed to `iteratee`.
        **kwargs: Only the ``argcount`` key is read. When present it is used as the supported
            argument count; otherwise the count is computed with :func:`getargcount`. Keyword
            arguments are never forwarded to `iteratee`.

    Returns:
        The return value of `iteratee`.
    """
    maxargs = len(args)
    # Avoid ``kwargs.get("argcount", getargcount(...))`` since that would always pay for the
    # signature inspection even when the caller already supplied the count.
    argcount = kwargs["argcount"] if "argcount" in kwargs else getargcount(iteratee, maxargs)
    argstop = min([maxargs, argcount])

    return iteratee(*args[:argstop])


def getargcount(iteratee, maxargs):
    """
    Return argument count of iteratee function.

    Args:
        iteratee: Callable to inspect.
        maxargs: Number of arguments the caller is able to supply. Used as the answer whenever
            the real count is unknown.

    Returns:
        Number of positional arguments to pass to `iteratee`.
    """
    if hasattr(iteratee, "_argcount"):
        # Optimization feature where argcount of iteratee is known and properly
        # set by initiator.
        # It should always be right, but it can be `None` for the function wrappers
        # in `pydash.function` as the wrapped functions are out of our control and
        # can support an unknown number of arguments.
        argcount = iteratee._argcount
        return argcount if argcount is not None else maxargs

    if isinstance(iteratee, type) or pyd.is_builtin(iteratee):
        # Only pass single argument to type iteratees or builtins.
        return 1

    try:
        return _getargcount(iteratee, maxargs)
    except TypeError:  # pragma: no cover
        # Inspection failed; fall back to passing a single argument.
        return 1


def _getargcount(iteratee, maxargs):
    """
    Inspect `iteratee` and return how many positional arguments it accepts.

    Falls back to `maxargs` when the callable accepts ``*args`` or when no count could be
    determined.
    """
    argcount = None

    try:
        # PY2: inspect.signature was added in Python 3.
        # Try to use inspect.signature when possible since it works better for our purpose of
        # getting the iteratee argcount since it takes into account the "self" argument in callable
        # classes.
        sig = inspect.signature(iteratee)
    except (TypeError, ValueError, AttributeError):  # pragma: no cover
        pass
    else:
        # VAR_POSITIONAL corresponds to *args so we only want to count parameters if there isn't a
        # catch-all for positional args.
        params = list(sig.parameters.values())
        if not any(param.kind == inspect.Parameter.VAR_POSITIONAL for param in params):
            positional_params = [p for p in params if p.kind in POSITIONAL_PARAMETERS]
            argcount = len(positional_params)

    if argcount is None:
        # Signatures were added to these operator methods in Python 3.12.3 and 3.11.9 but their
        # instance objects are incorrectly reported as accepting varargs when they only accept a
        # single argument.
        if isinstance(iteratee, (operator.itemgetter, operator.attrgetter, operator.methodcaller)):
            argcount = 1
        else:
            argspec = inspect.getfullargspec(iteratee)
            if argspec and not argspec.varargs:  # pragma: no cover
                # Use inspected arg count.
                argcount = len(argspec.args)

    if argcount is None:
        # Assume all args are handleable.
        argcount = maxargs

    return argcount


def iteriteratee(obj, iteratee=None, reverse=False):
    """
    Return iterative iteratee based on collection type.

    Args:
        obj: Collection to iterate over (see :func:`iterator` for the supported kinds).
        iteratee: Iteratee applied to each element. ``None`` means the identity function.
        reverse: Whether to iterate over the collection in reverse order.

    Yields:
        Tuples of ``(iteratee_result, item, key, obj)``.
    """
    if iteratee is None:
        cbk = pyd.identity
        argcount = 1
    else:
        cbk = pyd.iteratee(iteratee)
        # The callback is offered up to three arguments: (item, key, obj).
        argcount = getargcount(cbk, maxargs=3)

    items = iterator(obj)

    if reverse:
        # Materialize first since not every (key, item) iterable supports ``reversed``.
        items = reversed(tuple(items))

    for key, item in items:
        yield callit(cbk, item, key, obj, argcount=argcount), item, key, obj


def iterator(obj):
    """
    Return iterative based on object type.

    The result iterates over ``(key, value)`` pairs: mapping-like objects provide their items,
    other iterables are enumerated, and anything else falls back to the items of its
    ``__dict__`` (or nothing when it has none).
    """
    if isinstance(obj, Mapping):
        return obj.items()
    elif hasattr(obj, "iteritems"):
        return obj.iteritems()  # noqa: B301
    elif hasattr(obj, "items"):
        return iter(obj.items())
    elif isinstance(obj, Iterable):
        return enumerate(obj)
    else:
        return getattr(obj, "__dict__", {}).items()


def base_get(obj, key, default=UNSET):
    """
    Safely get an item by `key` from a sequence or mapping object when `default` provided.

    Args:
        obj: Sequence or mapping to retrieve item from.
        key: Key or index identifying which item to retrieve.
        default: Default value to return if `key` not found in `obj`.

    Returns:
        `obj[key]`, `obj.key`, or `default`.

    Raises:
        KeyError: If `obj` is missing key, index, or attribute and no default value provided.
            Also raised when attribute access is attempted with one of ``RESTRICTED_KEYS``.
    """
    if isinstance(obj, dict):
        value = _base_get_dict(obj, key, default=default)
    elif not isinstance(obj, (Mapping, Sequence)) or (
        isinstance(obj, tuple) and hasattr(obj, "_fields")
    ):
        # Don't use getattr for dict/list objects since we don't want class methods/attributes
        # returned for them but do allow getattr for namedtuple.
        value = _base_get_object(obj, key, default=default)
    else:
        value = _base_get_item(obj, key, default=default)

    if value is UNSET:
        # Raise if there's no default provided.
        raise KeyError(f'Object "{repr(obj)}" does not have key "{key}"')

    return value


def _base_get_dict(obj, key, default=UNSET):
    """Look up `key` in a ``dict``, retrying with ``int(key)`` before returning `default`."""
    value = obj.get(key, UNSET)
    if value is UNSET:
        value = default
        if not isinstance(key, int):
            # Try integer key fallback.
            try:
                value = obj.get(int(key), default)
            except Exception:
                # Key is not convertible to an integer; keep the default.
                pass
    return value


def _base_get_item(obj, key, default=UNSET):
    """Return ``obj[key]``, retrying with ``obj[int(key)]`` before returning `default`."""
    try:
        return obj[key]
    except Exception:
        pass

    if not isinstance(key, int):
        try:
            return obj[int(key)]
        except Exception:
            pass

    return default


def _base_get_object(obj, key, default=UNSET):
    """
    Return ``obj[key]`` when item access works, otherwise ``getattr(obj, key)`` or `default`.

    Raises:
        KeyError: If item access fails and `key` is one of ``RESTRICTED_KEYS``.
    """
    value = _base_get_item(obj, key, default=UNSET)
    if value is UNSET:
        # Only attribute access is guarded; item access above is not restricted.
        _raise_if_restricted_key(key)
        value = default
        try:
            value = getattr(obj, key)
        except Exception:
            pass
    return value


def _raise_if_restricted_key(key):
    """Raise ``KeyError`` if `key` is one of ``RESTRICTED_KEYS``."""
    # Prevent access to restricted keys for security reasons.
    if key in RESTRICTED_KEYS:
        raise KeyError(f"access to restricted key {key!r} is not allowed")


def base_set(obj, key, value, allow_override=True):
    """
    Set an object's `key` to `value`.

    If `obj` is a ``list`` and the `key` is the next available index position, append to list;
    otherwise, pad the list with ``None`` and then append to the list. For objects that are
    neither ``dict`` nor ``list`` (and not ``None``) the value is assigned as an attribute.

    Args:
        obj: Object to assign value to.
        key: Key or index to assign to.
        value: Value to assign.
        allow_override: Whether to allow overriding a previously set key.

    Returns:
        The same `obj` that was passed in.

    Raises:
        KeyError: If an attribute assignment is attempted with one of ``RESTRICTED_KEYS``.
    """
    if isinstance(obj, dict):
        if allow_override or key not in obj:
            obj[key] = value
    elif isinstance(obj, list):
        key = int(key)

        if key < len(obj):
            if allow_override:
                obj[key] = value
        else:
            if key > len(obj):
                # Pad list object with None values up to the index key, so we can append the value
                # into the key index.
                obj[:] = (obj + [None] * key)[:key]
            obj.append(value)
    elif (allow_override or not hasattr(obj, key)) and obj is not None:
        _raise_if_restricted_key(key)
        setattr(obj, key, value)

    return obj


def cmp(a, b):  # pragma: no cover
    """
    Replacement for built-in function ``cmp`` that was removed in Python 3.

    ``None`` compares equal to ``None`` and less than any other value.

    Returns:
        ``-1`` if `a` sorts before `b`, ``1`` if it sorts after, and ``0`` otherwise.

    Note:
        Mainly used for comparison during sorting.
    """
    if a is None and b is None:
        return 0
    elif a is None:
        return -1
    elif b is None:
        return 1
    return (a > b) - (a < b)


def parse_iteratee(iteratee_keyword, *args, **kwargs):
    """
    Try to find iteratee function passed in either as a keyword argument or as the last
    positional argument in `args`.

    Args:
        iteratee_keyword: Name of the keyword argument that may hold the iteratee.
        *args: Positional arguments whose last element may be the iteratee.
        **kwargs: Keyword arguments that may contain `iteratee_keyword`.

    Returns:
        Tuple of ``(iteratee, args)``. When the iteratee was taken from the last positional
        argument, that argument is removed from the returned `args`.
    """
    iteratee = kwargs.get(iteratee_keyword)

    # Only inspect the last positional argument when there is one; an empty `args` simply
    # means no positional iteratee was supplied.
    if iteratee is None and args:
        last_arg = args[-1]
        if callable(last_arg) or isinstance(last_arg, (dict, str)) or last_arg is None:
            iteratee = last_arg
            args = args[:-1]

    return iteratee, args


class iterator_with_default:
    """
    A wrapper around an iterator object that provides a default.

    Iterating yields the items of `collection`. If `collection` is empty, `default` is yielded
    once instead, unless it is ``UNSET``. An ``UNSET`` item in `collection` ends iteration
    because ``UNSET`` is used internally as the exhaustion marker.
    """

    def __init__(self, collection, default):
        self.iter = iter(collection)
        self.default = default

    def __iter__(self):
        return self

    def next_default(self):
        """Return the stored default and clear it so it can be handed out at most once."""
        ret = self.default
        self.default = UNSET
        return ret

    def __next__(self):
        # The default is consumed on the first call regardless of whether it ends up being
        # used, so it can only ever stand in for the first item.
        ret = next(self.iter, self.next_default())
        if ret is UNSET:
            raise StopIteration
        return ret

    next = __next__


def deprecated(func):  # pragma: no cover
    """
    This is a decorator which can be used to mark functions as deprecated.

    It will result in a warning being emitted when the function is used.

    Args:
        func: Function to mark as deprecated.

    Returns:
        Wrapper that emits a ``DeprecationWarning`` and then calls `func` with the same
        arguments, returning its result.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"Call to deprecated function {func.__name__}.",
            category=DeprecationWarning,
            stacklevel=3,
        )
        return func(*args, **kwargs)

    return wrapper