about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py')
-rw-r--r--.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py280
1 files changed, 280 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py b/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py
new file mode 100644
index 00000000..cb8b2cfc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py
@@ -0,0 +1,280 @@
+import abc
+import dataclasses
+import functools
+import inspect
+import sys
+from dataclasses import Field, fields
+from typing import Any, Callable, Dict, Optional, Tuple, Union, Type, get_type_hints
+from enum import Enum
+
+from marshmallow.exceptions import ValidationError  # type: ignore
+
+from dataclasses_json.utils import CatchAllVar
+
+KnownParameters = Dict[str, Any]
+UnknownParameters = Dict[str, Any]
+
+
+class _UndefinedParameterAction(abc.ABC):
+    @staticmethod
+    @abc.abstractmethod
+    def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
+        """
+        Return the parameters to initialize the class with.
+        """
+        pass
+
+    @staticmethod
+    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
+        """
+        Return the parameters that will be written to the output dict
+        """
+        return kvs
+
+    @staticmethod
+    def handle_dump(obj) -> Dict[Any, Any]:
+        """
+        Return the parameters that will be added to the schema dump.
+        """
+        return {}
+
+    @staticmethod
+    def create_init(obj) -> Callable:
+        return obj.__init__
+
+    @staticmethod
+    def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
+            Tuple[KnownParameters, UnknownParameters]:
+        """
+        Returns a 2 dictionaries: defined and undefined parameters
+        """
+        class_fields = fields(cls)
+        field_names = [field.name for field in class_fields]
+        unknown_given_parameters = {k: v for k, v in kvs.items() if
+                                    k not in field_names}
+        known_given_parameters = {k: v for k, v in kvs.items() if
+                                  k in field_names}
+        return known_given_parameters, unknown_given_parameters
+
+
+class _RaiseUndefinedParameters(_UndefinedParameterAction):
+    """
+    This action raises UndefinedParameterError if it encounters an undefined
+    parameter during initialization.
+    """
+
+    @staticmethod
+    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
+        known, unknown = \
+            _UndefinedParameterAction._separate_defined_undefined_kvs(
+                cls=cls, kvs=kvs)
+        if len(unknown) > 0:
+            raise UndefinedParameterError(
+                f"Received undefined initialization arguments {unknown}")
+        return known
+
+
+CatchAll = Optional[CatchAllVar]
+
+
+class _IgnoreUndefinedParameters(_UndefinedParameterAction):
+    """
+    This action does nothing when it encounters undefined parameters.
+    The undefined parameters can not be retrieved after the class has been
+    created.
+    """
+
+    @staticmethod
+    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
+        known_given_parameters, _ = \
+            _UndefinedParameterAction._separate_defined_undefined_kvs(
+                cls=cls, kvs=kvs)
+        return known_given_parameters
+
+    @staticmethod
+    def create_init(obj) -> Callable:
+        original_init = obj.__init__
+        init_signature = inspect.signature(original_init)
+
+        @functools.wraps(obj.__init__)
+        def _ignore_init(self, *args, **kwargs):
+            known_kwargs, _ = \
+                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
+                    obj, kwargs)
+            num_params_takeable = len(
+                init_signature.parameters) - 1  # don't count self
+            num_args_takeable = num_params_takeable - len(known_kwargs)
+
+            args = args[:num_args_takeable]
+            bound_parameters = init_signature.bind_partial(self, *args,
+                                                           **known_kwargs)
+            bound_parameters.apply_defaults()
+
+            arguments = bound_parameters.arguments
+            arguments.pop("self", None)
+            final_parameters = \
+                _IgnoreUndefinedParameters.handle_from_dict(obj, arguments)
+            original_init(self, **final_parameters)
+
+        return _ignore_init
+
+
+class _CatchAllUndefinedParameters(_UndefinedParameterAction):
+    """
+    This class allows to add a field of type utils.CatchAll which acts as a
+    dictionary into which all
+    undefined parameters will be written.
+    These parameters are not affected by LetterCase.
+    If no undefined parameters are given, this dictionary will be empty.
+    """
+
+    class _SentinelNoDefault:
+        pass
+
+    @staticmethod
+    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
+        known, unknown = _UndefinedParameterAction \
+            ._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
+        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
+            cls=cls)
+
+        if catch_all_field.name in known:
+
+            already_parsed = isinstance(known[catch_all_field.name], dict)
+            default_value = _CatchAllUndefinedParameters._get_default(
+                catch_all_field=catch_all_field)
+            received_default = default_value == known[catch_all_field.name]
+
+            value_to_write: Any
+            if received_default and len(unknown) == 0:
+                value_to_write = default_value
+            elif received_default and len(unknown) > 0:
+                value_to_write = unknown
+            elif already_parsed:
+                # Did not receive default
+                value_to_write = known[catch_all_field.name]
+                if len(unknown) > 0:
+                    value_to_write.update(unknown)
+            else:
+                error_message = f"Received input field with " \
+                                f"same name as catch-all field: " \
+                                f"'{catch_all_field.name}': " \
+                                f"'{known[catch_all_field.name]}'"
+                raise UndefinedParameterError(error_message)
+        else:
+            value_to_write = unknown
+
+        known[catch_all_field.name] = value_to_write
+        return known
+
+    @staticmethod
+    def _get_default(catch_all_field: Field) -> Any:
+        # access to the default factory currently causes
+        # a false-positive mypy error (16. Dec 2019):
+        # https://github.com/python/mypy/issues/6910
+
+        # noinspection PyProtectedMember
+        has_default = not isinstance(catch_all_field.default,
+                                     dataclasses._MISSING_TYPE)
+        # noinspection PyProtectedMember
+        has_default_factory = not isinstance(catch_all_field.default_factory,
+                                             # type: ignore
+                                             dataclasses._MISSING_TYPE)
+        # TODO: black this for proper formatting
+        default_value: Union[
+            Type[_CatchAllUndefinedParameters._SentinelNoDefault], Any] = _CatchAllUndefinedParameters\
+            ._SentinelNoDefault
+
+        if has_default:
+            default_value = catch_all_field.default
+        elif has_default_factory:
+            # This might be unwanted if the default factory constructs
+            # something expensive,
+            # because we have to construct it again just for this test
+            default_value = catch_all_field.default_factory()  # type: ignore
+
+        return default_value
+
+    @staticmethod
+    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
+        catch_all_field = \
+            _CatchAllUndefinedParameters._get_catch_all_field(obj.__class__)
+        undefined_parameters = kvs.pop(catch_all_field.name)
+        if isinstance(undefined_parameters, dict):
+            kvs.update(
+                undefined_parameters)  # If desired handle letter case here
+        return kvs
+
+    @staticmethod
+    def handle_dump(obj) -> Dict[Any, Any]:
+        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
+            cls=obj)
+        return getattr(obj, catch_all_field.name)
+
+    @staticmethod
+    def create_init(obj) -> Callable:
+        original_init = obj.__init__
+        init_signature = inspect.signature(original_init)
+
+        @functools.wraps(obj.__init__)
+        def _catch_all_init(self, *args, **kwargs):
+            known_kwargs, unknown_kwargs = \
+                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
+                    obj, kwargs)
+            num_params_takeable = len(
+                init_signature.parameters) - 1  # don't count self
+            if _CatchAllUndefinedParameters._get_catch_all_field(
+                    obj).name not in known_kwargs:
+                num_params_takeable -= 1
+            num_args_takeable = num_params_takeable - len(known_kwargs)
+
+            args, unknown_args = args[:num_args_takeable], args[
+                                                           num_args_takeable:]
+            bound_parameters = init_signature.bind_partial(self, *args,
+                                                           **known_kwargs)
+
+            unknown_args = {f"_UNKNOWN{i}": v for i, v in
+                            enumerate(unknown_args)}
+            arguments = bound_parameters.arguments
+            arguments.update(unknown_args)
+            arguments.update(unknown_kwargs)
+            arguments.pop("self", None)
+            final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
+                obj, arguments)
+            original_init(self, **final_parameters)
+
+        return _catch_all_init
+
+    @staticmethod
+    def _get_catch_all_field(cls) -> Field:
+        cls_globals = vars(sys.modules[cls.__module__])
+        types = get_type_hints(cls, globalns=cls_globals)
+        catch_all_fields = list(
+            filter(lambda f: types[f.name] == Optional[CatchAllVar], fields(cls)))
+        number_of_catch_all_fields = len(catch_all_fields)
+        if number_of_catch_all_fields == 0:
+            raise UndefinedParameterError(
+                "No field of type dataclasses_json.CatchAll defined")
+        elif number_of_catch_all_fields > 1:
+            raise UndefinedParameterError(
+                f"Multiple catch-all fields supplied: "
+                f"{number_of_catch_all_fields}.")
+        else:
+            return catch_all_fields[0]
+
+
+class Undefined(Enum):
+    """
+    Choose the behavior what happens when an undefined parameter is encountered
+    during class initialization.
+    """
+    INCLUDE = _CatchAllUndefinedParameters
+    RAISE = _RaiseUndefinedParameters
+    EXCLUDE = _IgnoreUndefinedParameters
+
+
+class UndefinedParameterError(ValidationError):
+    """
+    Raised when something has gone wrong handling undefined parameters.
+    """
+    pass