about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/dataclasses_json/mm.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/dataclasses_json/mm.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/dataclasses_json/mm.py')
-rw-r--r--.venv/lib/python3.12/site-packages/dataclasses_json/mm.py399
1 files changed, 399 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/dataclasses_json/mm.py b/.venv/lib/python3.12/site-packages/dataclasses_json/mm.py
new file mode 100644
index 00000000..9cfacf1d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/dataclasses_json/mm.py
@@ -0,0 +1,399 @@
+# flake8: noqa
+
+import typing
+import warnings
+import sys
+from copy import deepcopy
+
+from dataclasses import MISSING, is_dataclass, fields as dc_fields
+from datetime import datetime
+from decimal import Decimal
+from uuid import UUID
+from enum import Enum
+
+from typing_inspect import is_union_type  # type: ignore
+
+from marshmallow import fields, Schema, post_load  # type: ignore
+from marshmallow.exceptions import ValidationError  # type: ignore
+
+from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
+                                   _ExtendedEncoder, _user_overrides_or_exts)
+from dataclasses_json.utils import (_is_collection, _is_optional,
+                                    _issubclass_safe, _timestamp_to_dt_aware,
+                                    _is_new_type, _get_type_origin,
+                                    _handle_undefined_parameters_safe,
+                                    CatchAllVar)
+
+
+class _TimestampField(fields.Field):
+    def _serialize(self, value, attr, obj, **kwargs):
+        if value is not None:
+            return value.timestamp()
+        else:
+            if not self.required:
+                return None
+            else:
+                raise ValidationError(self.default_error_messages["required"])
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        if value is not None:
+            return _timestamp_to_dt_aware(value)
+        else:
+            if not self.required:
+                return None
+            else:
+                raise ValidationError(self.default_error_messages["required"])
+
+
+class _IsoField(fields.Field):
+    def _serialize(self, value, attr, obj, **kwargs):
+        if value is not None:
+            return value.isoformat()
+        else:
+            if not self.required:
+                return None
+            else:
+                raise ValidationError(self.default_error_messages["required"])
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        if value is not None:
+            return datetime.fromisoformat(value)
+        else:
+            if not self.required:
+                return None
+            else:
+                raise ValidationError(self.default_error_messages["required"])
+
+
+class _UnionField(fields.Field):
+    def __init__(self, desc, cls, field, *args, **kwargs):
+        self.desc = desc
+        self.cls = cls
+        self.field = field
+        super().__init__(*args, **kwargs)
+
+    def _serialize(self, value, attr, obj, **kwargs):
+        if self.allow_none and value is None:
+            return None
+        for type_, schema_ in self.desc.items():
+            if _issubclass_safe(type(value), type_):
+                if is_dataclass(value):
+                    res = schema_._serialize(value, attr, obj, **kwargs)
+                    res['__type'] = str(type_.__name__)
+                    return res
+                break
+            elif isinstance(value, _get_type_origin(type_)):
+                return schema_._serialize(value, attr, obj, **kwargs)
+        else:
+            warnings.warn(
+                f'The type "{type(value).__name__}" (value: "{value}") '
+                f'is not in the list of possible types of typing.Union '
+                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
+                f'Value cannot be serialized properly.')
+        return super()._serialize(value, attr, obj, **kwargs)
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        tmp_value = deepcopy(value)
+        if isinstance(tmp_value, dict) and '__type' in tmp_value:
+            dc_name = tmp_value['__type']
+            for type_, schema_ in self.desc.items():
+                if is_dataclass(type_) and type_.__name__ == dc_name:
+                    del tmp_value['__type']
+                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
+        elif isinstance(tmp_value, dict):
+            warnings.warn(
+                f'Attempting to deserialize "dict" (value: "{tmp_value}) '
+                f'that does not have a "__type" type specifier field into'
+                f'(dataclass: {self.cls.__name__}, field: {self.field.name}).'
+                f'Deserialization may fail, or deserialization to wrong type may occur.'
+            )
+            return super()._deserialize(tmp_value, attr, data, **kwargs)
+        else:
+            for type_, schema_ in self.desc.items():
+                if isinstance(tmp_value, _get_type_origin(type_)):
+                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
+            else:
+                warnings.warn(
+                    f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
+                    f'is not in the list of possible types of typing.Union '
+                    f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
+                    f'Value cannot be deserialized properly.')
+            return super()._deserialize(tmp_value, attr, data, **kwargs)
+
+
+class _TupleVarLen(fields.List):
+    """
+    variable-length homogeneous tuples
+    """
+    def _deserialize(self, value, attr, data, **kwargs):
+        optional_list = super()._deserialize(value, attr, data, **kwargs)
+        return None if optional_list is None else tuple(optional_list)
+
+
+TYPES = {
+    typing.Mapping: fields.Mapping,
+    typing.MutableMapping: fields.Mapping,
+    typing.List: fields.List,
+    typing.Dict: fields.Dict,
+    typing.Tuple: fields.Tuple,
+    typing.Callable: fields.Function,
+    typing.Any: fields.Raw,
+    dict: fields.Dict,
+    list: fields.List,
+    tuple: fields.Tuple,
+    str: fields.Str,
+    int: fields.Int,
+    float: fields.Float,
+    bool: fields.Bool,
+    datetime: _TimestampField,
+    UUID: fields.UUID,
+    Decimal: fields.Decimal,
+    CatchAllVar: fields.Dict,
+}
+
+A = typing.TypeVar('A')
+JsonData = typing.Union[str, bytes, bytearray]
+TEncoded = typing.Dict[str, typing.Any]
+TOneOrMulti = typing.Union[typing.List[A], A]
+TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
+
+if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
+    class SchemaF(Schema, typing.Generic[A]):
+        """Lift Schema into a type constructor"""
+
+        def __init__(self, *args, **kwargs):
+            """
+            Raises exception because this class should not be inherited.
+            This class is helper only.
+            """
+
+            super().__init__(*args, **kwargs)
+            raise NotImplementedError()
+
+        @typing.overload
+        def dump(self, obj: typing.List[A], many: typing.Optional[bool] = None) -> typing.List[TEncoded]:  # type: ignore
+            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
+            pass
+
+        @typing.overload
+        def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
+            pass
+
+        def dump(self, obj: TOneOrMulti,    # type: ignore
+                 many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
+            pass
+
+        @typing.overload
+        def dumps(self, obj: typing.List[A], many: typing.Optional[bool] = None, *args,
+                  **kwargs) -> str:
+            pass
+
+        @typing.overload
+        def dumps(self, obj: A, many: typing.Optional[bool] = None, *args, **kwargs) -> str:
+            pass
+
+        def dumps(self, obj: TOneOrMulti, many: typing.Optional[bool] = None, *args,   # type: ignore
+                  **kwargs) -> str:
+            pass
+
+        @typing.overload  # type: ignore
+        def load(self, data: typing.List[TEncoded],
+                 many: bool = True, partial: typing.Optional[bool] = None,
+                 unknown: typing.Optional[str] = None) -> \
+                typing.List[A]:
+            # ignore the mypy error of the decorator because mm does not define lists as an allowed input type
+            pass
+
+        @typing.overload
+        def load(self, data: TEncoded,
+                 many: None = None, partial: typing.Optional[bool] = None,
+                 unknown: typing.Optional[str] = None) -> A:
+            pass
+
+        def load(self, data: TOneOrMultiEncoded,
+                 many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
+                 unknown: typing.Optional[str] = None) -> TOneOrMulti:
+            pass
+
+        @typing.overload  # type: ignore
+        def loads(self, json_data: JsonData,  # type: ignore
+                  many: typing.Optional[bool] = True, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
+                  **kwargs) -> typing.List[A]:
+            # ignore the mypy error of the decorator because mm does not define bytes as correct input data
+            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
+            # for the return type overlap
+            pass
+
+        def loads(self, json_data: JsonData,
+                  many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
+                  **kwargs) -> TOneOrMulti:
+            pass
+
+
+    SchemaType = SchemaF[A]
+else:
+    SchemaType = Schema
+
+
+def build_type(type_, options, mixin, field, cls):
+    def inner(type_, options):
+        while True:
+            if not _is_new_type(type_):
+                break
+
+            type_ = type_.__supertype__
+
+        if is_dataclass(type_):
+            if _issubclass_safe(type_, mixin):
+                options['field_many'] = bool(
+                    _is_supported_generic(field.type) and _is_collection(
+                        field.type))
+                return fields.Nested(type_.schema(), **options)
+            else:
+                warnings.warn(f"Nested dataclass field {field.name} of type "
+                              f"{field.type} detected in "
+                              f"{cls.__name__} that is not an instance of "
+                              f"dataclass_json. Did you mean to recursively "
+                              f"serialize this field? If so, make sure to "
+                              f"augment {type_} with either the "
+                              f"`dataclass_json` decorator or mixin.")
+                return fields.Field(**options)
+
+        origin = getattr(type_, '__origin__', type_)
+        args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
+                a is not type(None)]
+
+        if type_ == Ellipsis:
+            return type_
+
+        if _is_optional(type_):
+            options["allow_none"] = True
+        if origin is tuple:
+            if len(args) == 2 and args[1] == Ellipsis:
+                return _TupleVarLen(args[0], **options)
+            else:
+                return fields.Tuple(args, **options)
+        if origin in TYPES:
+            return TYPES[origin](*args, **options)
+
+        if _issubclass_safe(origin, Enum):
+            return fields.Enum(enum=origin, by_value=True, *args, **options)
+
+        if is_union_type(type_):
+            union_types = [a for a in getattr(type_, '__args__', []) if
+                           a is not type(None)]
+            union_desc = dict(zip(union_types, args))
+            return _UnionField(union_desc, cls, field, **options)
+
+        warnings.warn(
+            f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
+            f"It's advised to pass the correct marshmallow type to `mm_field`.")
+        return fields.Field(**options)
+
+    return inner(type_, options)
+
+
+def schema(cls, mixin, infer_missing):
+    schema = {}
+    overrides = _user_overrides_or_exts(cls)
+    # TODO check the undefined parameters and add the proper schema action
+    #  https://marshmallow.readthedocs.io/en/stable/quickstart.html
+    for field in dc_fields(cls):
+        metadata = overrides[field.name]
+        if metadata.mm_field is not None:
+            schema[field.name] = metadata.mm_field
+        else:
+            type_ = field.type
+            options: typing.Dict[str, typing.Any] = {}
+            missing_key = 'missing' if infer_missing else 'default'
+            if field.default is not MISSING:
+                options[missing_key] = field.default
+            elif field.default_factory is not MISSING:
+                options[missing_key] = field.default_factory()
+            else:
+                options['required'] = True
+
+            if options.get(missing_key, ...) is None:
+                options['allow_none'] = True
+
+            if _is_optional(type_):
+                options.setdefault(missing_key, None)
+                options['allow_none'] = True
+                if len(type_.__args__) == 2:
+                    # Union[str, int, None] is optional too, but it has more than 1 typed field.
+                    type_ = [tp for tp in type_.__args__ if tp is not type(None)][0]
+
+            if metadata.letter_case is not None:
+                options['data_key'] = metadata.letter_case(field.name)
+
+            t = build_type(type_, options, mixin, field, cls)
+            if field.metadata.get('dataclasses_json', {}).get('decoder'):
+                # If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion
+                # logic.
+                # From Marshmallow's documentation for the _deserialize method:
+                # "Deserialize value. Concrete :class:`Field` classes should implement this method. "
+                # This is the method that Field implementations override to perform the actual deserialization logic.
+                # In this case we specifically override this method instead of `deserialize` to minimize potential
+                # side effects, and only cancel the actual value deserialization.
+                t._deserialize = lambda v, *_a, **_kw: v
+
+            # if type(t) is not fields.Field:  # If we use `isinstance` we would return nothing.
+            if field.type != typing.Optional[CatchAllVar]:
+                schema[field.name] = t
+
+    return schema
+
+
+def build_schema(cls: typing.Type[A],
+                 mixin,
+                 infer_missing,
+                 partial) -> typing.Type["SchemaType[A]"]:
+    Meta = type('Meta',
+                (),
+                {'fields': tuple(field.name for field in dc_fields(cls)  # type: ignore
+                                 if
+                                 field.name != 'dataclass_json_config' and field.type !=
+                                 typing.Optional[CatchAllVar]),
+                 # TODO #180
+                 # 'render_module': global_config.json_module
+                 })
+
+    @post_load
+    def make_instance(self, kvs, **kwargs):
+        return _decode_dataclass(cls, kvs, partial)
+
+    def dumps(self, *args, **kwargs):
+        if 'cls' not in kwargs:
+            kwargs['cls'] = _ExtendedEncoder
+
+        return Schema.dumps(self, *args, **kwargs)
+
+    def dump(self, obj, *, many=None):
+        many = self.many if many is None else bool(many)
+        dumped = Schema.dump(self, obj, many=many)
+        # TODO This is hacky, but the other option I can think of is to generate a different schema
+        #  depending on dump and load, which is even more hacky
+
+        # The only problem is the catch-all field, we can't statically create a schema for it,
+        # so we just update the dumped dict
+        if many:
+            for i, _obj in enumerate(obj):
+                dumped[i].update(
+                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
+                                                      usage="dump"))
+        else:
+            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
+                                                            usage="dump"))
+        return dumped
+
+    schema_ = schema(cls, mixin, infer_missing)
+    DataClassSchema: typing.Type["SchemaType[A]"] = type(
+        f'{cls.__name__.capitalize()}Schema',
+        (Schema,),
+        {'Meta': Meta,
+         f'make_{cls.__name__.lower()}': make_instance,
+         'dumps': dumps,
+         'dump': dump,
+         **schema_})
+
+    return DataClassSchema