aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/pydash/helpers.py
blob: ee9accb15a68e01a1ccdf9655c979ac2bc319c6c (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""Generic utility methods not part of main API."""

from __future__ import annotations

import builtins
from collections.abc import Hashable, Iterable, Mapping, Sequence
from decimal import Decimal
from functools import wraps
import inspect
import operator
import warnings

import pydash as pyd


#: Marker type used to tell "no value was supplied" apart from an explicit ``None``.
#: It is a dedicated class so that the marker has a type of its own.
class Unset:
    pass


#: Shared :class:`Unset` instance. Code in this module compares against it by identity
#: (``value is UNSET``) to detect that no value or default was supplied.
UNSET = Unset()

#: Tuple of number types.
NUMBER_TYPES = (int, float, Decimal)

#: Reverse lookup of the ``builtins`` module namespace: keys are the builtin objects and values
#: are the string names they are bound to. Only hashable objects are included since unhashable
#: ones cannot be used as dictionary keys.
BUILTINS = {value: key for key, value in builtins.__dict__.items() if isinstance(value, Hashable)}

#: Object keys that are restricted from access via path access.
RESTRICTED_KEYS = ("__globals__", "__builtins__")

#: Inspect signature parameter kinds that correspond to positional arguments.
POSITIONAL_PARAMETERS = (
    inspect.Parameter.VAR_POSITIONAL,
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def callit(iteratee, *args, **kwargs):
    """
    Call `iteratee` with only as many leading positional arguments as it supports.

    Args:
        iteratee: Callable to invoke.
        *args: Candidate positional arguments; a prefix of these is passed to `iteratee`.
        **kwargs: Only ``argcount`` is consulted. When present it is used as the number of
            arguments `iteratee` accepts; otherwise the count is looked up with
            :func:`getargcount`. Keyword arguments are never forwarded to `iteratee`.

    Returns:
        Whatever `iteratee` returns.
    """
    supplied = len(args)

    if "argcount" in kwargs:
        accepted = kwargs["argcount"]
    else:
        accepted = getargcount(iteratee, supplied)

    # Never pass more arguments than were given nor more than the callable accepts.
    limit = min(supplied, accepted)
    return iteratee(*args[:limit])


def getargcount(iteratee, maxargs):
    """
    Return the number of positional arguments that should be passed to `iteratee`.

    Args:
        iteratee: Callable whose argument count is wanted.
        maxargs: Maximum number of arguments the caller is able to supply.

    Returns:
        The argument count to use when calling `iteratee`.
    """
    if hasattr(iteratee, "_argcount"):
        # Fast path: the creator of the iteratee already recorded its argument count. A recorded
        # ``None`` means the count is unknown (e.g. wrappers around functions outside of our
        # control), in which case every available argument is assumed to be supported.
        preset = iteratee._argcount
        if preset is None:
            return maxargs
        return preset

    if isinstance(iteratee, type) or pyd.is_builtin(iteratee):
        # Types and builtins only ever receive a single argument.
        return 1

    try:
        return _getargcount(iteratee, maxargs)
    except TypeError:  # pragma: no cover
        # Inspection failed; fall back to passing a single argument.
        return 1


def _getargcount(iteratee, maxargs):
    """
    Inspect `iteratee` and return how many positional arguments it accepts.

    Args:
        iteratee: Callable to inspect.
        maxargs: Value returned when the callable accepts any number of positional arguments or
            its count cannot be determined.

    Returns:
        Number of positional arguments to pass to `iteratee`.
    """
    try:
        # inspect.signature is preferred because it accounts for the bound "self" argument of
        # callable classes.
        signature = inspect.signature(iteratee)
    except (TypeError, ValueError, AttributeError):  # pragma: no cover
        signature = None

    if signature is not None:
        kinds = [parameter.kind for parameter in signature.parameters.values()]
        # VAR_POSITIONAL corresponds to *args; parameters are only counted when there is no such
        # catch-all for positional arguments.
        if inspect.Parameter.VAR_POSITIONAL not in kinds:
            return sum(1 for kind in kinds if kind in POSITIONAL_PARAMETERS)

    # Signatures were added to these operator helpers in Python 3.12.3 and 3.11.9, but their
    # instances are incorrectly reported as accepting varargs when they only accept a single
    # argument.
    if isinstance(iteratee, (operator.itemgetter, operator.attrgetter, operator.methodcaller)):
        return 1

    spec = inspect.getfullargspec(iteratee)
    if spec and not spec.varargs:  # pragma: no cover
        # Use inspected arg count.
        return len(spec.args)

    # Assume all args are handleable.
    return maxargs


def iteriteratee(obj, iteratee=None, reverse=False):
    """
    Iterate over `obj`, applying `iteratee` to every entry.

    Args:
        obj: Collection to iterate over.
        iteratee: Iteratee specification converted with ``pyd.iteratee``; when ``None`` the
            identity function is used.
        reverse: Whether to walk the entries in reverse order.

    Yields:
        Tuples of ``(iteratee result, item, key, obj)``.
    """
    if iteratee is not None:
        callback = pyd.iteratee(iteratee)
        nargs = getargcount(callback, maxargs=3)
    else:
        callback, nargs = pyd.identity, 1

    pairs = iterator(obj)
    if reverse:
        # The pairs must be materialized before they can be reversed.
        pairs = reversed(tuple(pairs))

    for index, element in pairs:
        result = callit(callback, element, index, obj, argcount=nargs)
        yield result, element, index, obj


def iterator(obj):
    """
    Return an iterable of ``(key, value)`` pairs appropriate for the type of `obj`.

    Mappings and objects exposing ``iteritems``/``items`` yield their own pairs, other iterables
    are enumerated, and anything else falls back to the entries of its ``__dict__`` (empty when
    it has none).
    """
    if isinstance(obj, Mapping):
        return obj.items()

    if hasattr(obj, "iteritems"):
        return obj.iteritems()  # noqa: B301

    if hasattr(obj, "items"):
        return iter(obj.items())

    if isinstance(obj, Iterable):
        return enumerate(obj)

    attributes = getattr(obj, "__dict__", {})
    return attributes.items()


def base_get(obj, key, default=UNSET):
    """
    Look up `key` on `obj`, falling back to `default` when one is supplied.

    Args:
        obj: Sequence, mapping, or arbitrary object to retrieve the value from.
        key: Key, index, or attribute name identifying the value.
        default: Value returned when `key` cannot be found on `obj`.

    Returns:
        `obj[key]`, `obj.key`, or `default`.

    Raises:
        KeyError: If `obj` is missing key, index, or attribute and no default value provided.
    """
    if isinstance(obj, dict):
        getter = _base_get_dict
    else:
        is_namedtuple = isinstance(obj, tuple) and hasattr(obj, "_fields")
        if is_namedtuple or not isinstance(obj, (Mapping, Sequence)):
            # Attribute access is avoided for mappings/sequences so that class methods and
            # attributes are not returned for them, but it is allowed for namedtuples.
            getter = _base_get_object
        else:
            getter = _base_get_item

    found = getter(obj, key, default=default)

    if found is UNSET:
        # Nothing was found and the caller gave no default to fall back on.
        raise KeyError(f'Object "{repr(obj)}" does not have key "{key}"')

    return found


def _base_get_dict(obj, key, default=UNSET):
    """Return ``obj[key]`` from a dict, retrying with ``int(key)`` before giving up and returning
    `default`."""
    found = obj.get(key, UNSET)
    if found is not UNSET:
        return found

    if isinstance(key, int):
        # Already an integer, so there is no alternate key to try.
        return default

    try:
        # Try integer key fallback.
        return obj.get(int(key), default)
    except Exception:
        return default


def _base_get_item(obj, key, default=UNSET):
    """Return ``obj[key]``, retrying with ``obj[int(key)]`` for non-integer keys, or `default`
    when neither subscript succeeds."""
    try:
        return obj[key]
    except Exception:
        if isinstance(key, int):
            # No integer conversion left to try.
            return default

    try:
        return obj[int(key)]
    except Exception:
        return default


def _base_get_object(obj, key, default=UNSET):
    """Return the value for `key` using item access first and attribute access second, or
    `default` when both fail.

    Raises:
        KeyError: If item access fails and `key` is a restricted key.
    """
    found = _base_get_item(obj, key, default=UNSET)
    if found is not UNSET:
        return found

    # Restricted keys must never be resolved through getattr.
    _raise_if_restricted_key(key)

    try:
        return getattr(obj, key)
    except Exception:
        return default


def _raise_if_restricted_key(key):
    """Raise ``KeyError`` when `key` is listed in ``RESTRICTED_KEYS``; otherwise do nothing."""
    if key not in RESTRICTED_KEYS:
        return

    # Prevent access to restricted keys for security reasons.
    raise KeyError(f"access to restricted key {key!r} is not allowed")


def base_set(obj, key, value, allow_override=True):
    """
    Assign `value` to `key` on `obj` and return `obj`.

    Dicts are assigned by key and lists by integer index. For a list, an index at or past the end
    appends the value, first padding the list with ``None`` when the index is beyond the next
    free position. Any other non-``None`` object is assigned with ``setattr``.

    Args:
        obj: Object to assign value to.
        key: Key, index, or attribute name to assign to.
        value: Value to assign.
        allow_override: Whether to allow overriding a previously set key.

    Returns:
        The same `obj` that was passed in.

    Raises:
        KeyError: If an attribute assignment targets a restricted key.
    """
    if isinstance(obj, dict):
        if allow_override or key not in obj:
            obj[key] = value
        return obj

    if isinstance(obj, list):
        index = int(key)
        size = len(obj)

        if index < size:
            if allow_override:
                obj[index] = value
            return obj

        if index > size:
            # Fill the gap with None so that the append below lands exactly on `index`.
            obj.extend([None] * (index - size))
        obj.append(value)
        return obj

    may_assign = allow_override or not hasattr(obj, key)
    if may_assign and obj is not None:
        _raise_if_restricted_key(key)
        setattr(obj, key, value)

    return obj


def cmp(a, b):  # pragma: no cover
    """
    Three-way comparison standing in for the ``cmp`` builtin that Python 3 removed.

    ``None`` is ordered before every other value and equal to itself.

    Note: Mainly used for comparison during sorting.

    Returns:
        ``-1`` if `a` sorts before `b`, ``1`` if it sorts after, ``0`` otherwise.
    """
    if a is None or b is None:
        # A missing side counts as the smaller one; two missing sides are equal.
        return (a is not None) - (b is not None)

    greater = a > b
    lesser = a < b
    return greater - lesser


def parse_iteratee(iteratee_keyword, *args, **kwargs):
    """
    Find the iteratee passed either as a keyword argument or as the last positional argument.

    Args:
        iteratee_keyword: Name of the keyword argument that may carry the iteratee.
        *args: Positional arguments whose last element may be the iteratee.
        **kwargs: Keyword arguments that may contain `iteratee_keyword`.

    Returns:
        Tuple of ``(iteratee, args)`` where ``args`` has the trailing iteratee removed when it
        was taken from the positional arguments.
    """
    iteratee = kwargs.get(iteratee_keyword)

    if not args:
        # Nothing positional to inspect (e.g. the iteratee was supplied by keyword only);
        # indexing ``args[-1]`` here would raise IndexError.
        return iteratee, args

    last_arg = args[-1]

    # Only treat the last positional argument as the iteratee when none was given by keyword and
    # the argument is callable, a dict, a string, or None.
    if iteratee is None and (
        callable(last_arg) or isinstance(last_arg, (dict, str)) or last_arg is None
    ):
        iteratee = last_arg
        args = args[:-1]

    return iteratee, args


class iterator_with_default:
    """
    Iterator wrapper that yields a default value when the wrapped collection is empty.

    The default is handed out at most once, on the first ``next`` call, and only when the
    underlying iterator has nothing to return. Iteration stops when an ``UNSET`` value is
    produced.
    """

    def __init__(self, collection, default):
        self.iter = iter(collection)
        self.default = default

    def __iter__(self):
        return self

    def next_default(self):
        """Return the stored default and replace it with ``UNSET`` so it is not offered again."""
        pending, self.default = self.default, UNSET
        return pending

    def __next__(self):
        # The default is taken (and cleared) before the underlying iterator is advanced.
        fallback = self.next_default()
        item = next(self.iter, fallback)
        if item is UNSET:
            raise StopIteration
        return item

    next = __next__


def deprecated(func):  # pragma: no cover
    """
    Decorator that marks `func` as deprecated.

    Every call to the decorated function emits a ``DeprecationWarning`` and then delegates to
    `func` with the same arguments.

    Args:
        func: Function to wrap.

    Returns:
        Wrapper function carrying the metadata of `func`.
    """

    @wraps(func)
    def emit_and_call(*args, **kwargs):
        message = f"Call to deprecated function {func.__name__}."
        warnings.warn(message, category=DeprecationWarning, stacklevel=3)
        return func(*args, **kwargs)

    return emit_and_call