diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/dataclasses_json/undefined.py | 280 |
1 files changed, 280 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py b/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py new file mode 100644 index 00000000..cb8b2cfc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/dataclasses_json/undefined.py @@ -0,0 +1,280 @@ +import abc +import dataclasses +import functools +import inspect +import sys +from dataclasses import Field, fields +from typing import Any, Callable, Dict, Optional, Tuple, Union, Type, get_type_hints +from enum import Enum + +from marshmallow.exceptions import ValidationError # type: ignore + +from dataclasses_json.utils import CatchAllVar + +KnownParameters = Dict[str, Any] +UnknownParameters = Dict[str, Any] + + +class _UndefinedParameterAction(abc.ABC): + @staticmethod + @abc.abstractmethod + def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]: + """ + Return the parameters to initialize the class with. + """ + pass + + @staticmethod + def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]: + """ + Return the parameters that will be written to the output dict + """ + return kvs + + @staticmethod + def handle_dump(obj) -> Dict[Any, Any]: + """ + Return the parameters that will be added to the schema dump. + """ + return {} + + @staticmethod + def create_init(obj) -> Callable: + return obj.__init__ + + @staticmethod + def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \ + Tuple[KnownParameters, UnknownParameters]: + """ + Returns a 2 dictionaries: defined and undefined parameters + """ + class_fields = fields(cls) + field_names = [field.name for field in class_fields] + unknown_given_parameters = {k: v for k, v in kvs.items() if + k not in field_names} + known_given_parameters = {k: v for k, v in kvs.items() if + k in field_names} + return known_given_parameters, unknown_given_parameters + + +class _RaiseUndefinedParameters(_UndefinedParameterAction): + """ + This action raises UndefinedParameterError if it encounters an undefined + parameter during initialization. + """ + + @staticmethod + def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]: + known, unknown = \ + _UndefinedParameterAction._separate_defined_undefined_kvs( + cls=cls, kvs=kvs) + if len(unknown) > 0: + raise UndefinedParameterError( + f"Received undefined initialization arguments {unknown}") + return known + + +CatchAll = Optional[CatchAllVar] + + +class _IgnoreUndefinedParameters(_UndefinedParameterAction): + """ + This action does nothing when it encounters undefined parameters. + The undefined parameters can not be retrieved after the class has been + created. + """ + + @staticmethod + def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]: + known_given_parameters, _ = \ + _UndefinedParameterAction._separate_defined_undefined_kvs( + cls=cls, kvs=kvs) + return known_given_parameters + + @staticmethod + def create_init(obj) -> Callable: + original_init = obj.__init__ + init_signature = inspect.signature(original_init) + + @functools.wraps(obj.__init__) + def _ignore_init(self, *args, **kwargs): + known_kwargs, _ = \ + _CatchAllUndefinedParameters._separate_defined_undefined_kvs( + obj, kwargs) + num_params_takeable = len( + init_signature.parameters) - 1 # don't count self + num_args_takeable = num_params_takeable - len(known_kwargs) + + args = args[:num_args_takeable] + bound_parameters = init_signature.bind_partial(self, *args, + **known_kwargs) + bound_parameters.apply_defaults() + + arguments = bound_parameters.arguments + arguments.pop("self", None) + final_parameters = \ + _IgnoreUndefinedParameters.handle_from_dict(obj, arguments) + original_init(self, **final_parameters) + + return _ignore_init + + +class _CatchAllUndefinedParameters(_UndefinedParameterAction): + """ + This class allows to add a field of type utils.CatchAll which acts as a + dictionary into which all + undefined parameters will be written. + These parameters are not affected by LetterCase. + If no undefined parameters are given, this dictionary will be empty. + """ + + class _SentinelNoDefault: + pass + + @staticmethod + def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]: + known, unknown = _UndefinedParameterAction \ + ._separate_defined_undefined_kvs(cls=cls, kvs=kvs) + catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field( + cls=cls) + + if catch_all_field.name in known: + + already_parsed = isinstance(known[catch_all_field.name], dict) + default_value = _CatchAllUndefinedParameters._get_default( + catch_all_field=catch_all_field) + received_default = default_value == known[catch_all_field.name] + + value_to_write: Any + if received_default and len(unknown) == 0: + value_to_write = default_value + elif received_default and len(unknown) > 0: + value_to_write = unknown + elif already_parsed: + # Did not receive default + value_to_write = known[catch_all_field.name] + if len(unknown) > 0: + value_to_write.update(unknown) + else: + error_message = f"Received input field with " \ + f"same name as catch-all field: " \ + f"'{catch_all_field.name}': " \ + f"'{known[catch_all_field.name]}'" + raise UndefinedParameterError(error_message) + else: + value_to_write = unknown + + known[catch_all_field.name] = value_to_write + return known + + @staticmethod + def _get_default(catch_all_field: Field) -> Any: + # access to the default factory currently causes + # a false-positive mypy error (16. Dec 2019): + # https://github.com/python/mypy/issues/6910 + + # noinspection PyProtectedMember + has_default = not isinstance(catch_all_field.default, + dataclasses._MISSING_TYPE) + # noinspection PyProtectedMember + has_default_factory = not isinstance(catch_all_field.default_factory, + # type: ignore + dataclasses._MISSING_TYPE) + # TODO: black this for proper formatting + default_value: Union[ + Type[_CatchAllUndefinedParameters._SentinelNoDefault], Any] = _CatchAllUndefinedParameters\ + ._SentinelNoDefault + + if has_default: + default_value = catch_all_field.default + elif has_default_factory: + # This might be unwanted if the default factory constructs + # something expensive, + # because we have to construct it again just for this test + default_value = catch_all_field.default_factory() # type: ignore + + return default_value + + @staticmethod + def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]: + catch_all_field = \ + _CatchAllUndefinedParameters._get_catch_all_field(obj.__class__) + undefined_parameters = kvs.pop(catch_all_field.name) + if isinstance(undefined_parameters, dict): + kvs.update( + undefined_parameters) # If desired handle letter case here + return kvs + + @staticmethod + def handle_dump(obj) -> Dict[Any, Any]: + catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field( + cls=obj) + return getattr(obj, catch_all_field.name) + + @staticmethod + def create_init(obj) -> Callable: + original_init = obj.__init__ + init_signature = inspect.signature(original_init) + + @functools.wraps(obj.__init__) + def _catch_all_init(self, *args, **kwargs): + known_kwargs, unknown_kwargs = \ + _CatchAllUndefinedParameters._separate_defined_undefined_kvs( + obj, kwargs) + num_params_takeable = len( + init_signature.parameters) - 1 # don't count self + if _CatchAllUndefinedParameters._get_catch_all_field( + obj).name not in known_kwargs: + num_params_takeable -= 1 + num_args_takeable = num_params_takeable - len(known_kwargs) + + args, unknown_args = args[:num_args_takeable], args[ + num_args_takeable:] + bound_parameters = init_signature.bind_partial(self, *args, + **known_kwargs) + + unknown_args = {f"_UNKNOWN{i}": v for i, v in + enumerate(unknown_args)} + arguments = bound_parameters.arguments + arguments.update(unknown_args) + arguments.update(unknown_kwargs) + arguments.pop("self", None) + final_parameters = _CatchAllUndefinedParameters.handle_from_dict( + obj, arguments) + original_init(self, **final_parameters) + + return _catch_all_init + + @staticmethod + def _get_catch_all_field(cls) -> Field: + cls_globals = vars(sys.modules[cls.__module__]) + types = get_type_hints(cls, globalns=cls_globals) + catch_all_fields = list( + filter(lambda f: types[f.name] == Optional[CatchAllVar], fields(cls))) + number_of_catch_all_fields = len(catch_all_fields) + if number_of_catch_all_fields == 0: + raise UndefinedParameterError( + "No field of type dataclasses_json.CatchAll defined") + elif number_of_catch_all_fields > 1: + raise UndefinedParameterError( + f"Multiple catch-all fields supplied: " + f"{number_of_catch_all_fields}.") + else: + return catch_all_fields[0] + + +class Undefined(Enum): + """ + Choose the behavior what happens when an undefined parameter is encountered + during class initialization. + """ + INCLUDE = _CatchAllUndefinedParameters + RAISE = _RaiseUndefinedParameters + EXCLUDE = _IgnoreUndefinedParameters + + +class UndefinedParameterError(ValidationError): + """ + Raised when something has gone wrong handling undefined parameters. + """ + pass |
