1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
|
"""A registry of :class:`Schema <marshmallow.Schema>` classes. This allows for string
lookup of schemas, which may be used with
class:`fields.Nested <marshmallow.fields.Nested>`.
.. warning::
This module is treated as private API.
Users should not need to use this module directly.
"""
# ruff: noqa: ERA001
from __future__ import annotations
import typing
from marshmallow.exceptions import RegistryError
if typing.TYPE_CHECKING:
    # Only needed for annotations; the guard keeps this import off the
    # runtime path.
    from marshmallow import Schema

    SchemaType = type[Schema]


# Lookup table from a registry key to the classes registered under it.
# Every registered class is stored under two keys:
#   "<class_name>"                 -> list of class objects (one per module)
#   "<module_path>.<class_name>"   -> list holding that single class object
_registry = {}  # type: dict[str, list[SchemaType]]
def register(classname: str, cls: SchemaType) -> None:
    """Add a class to the registry of serializer classes. When a class is
    registered, an entry for both its classname and its full, module-qualified
    path are added to the registry.

    If a class with the same name from the same module is already registered
    (for example because the class was redefined), the new class replaces it
    under both keys. A lookup by short name and a lookup by full path therefore
    always agree.

    Example: ::

        class MyClass:
            pass


        register("MyClass", MyClass)
        # Registry:
        # {
        #   'MyClass': [path.to.MyClass],
        #   'path.to.MyClass': [path.to.MyClass],
        # }

    :param classname: Name to register ``cls`` under.
    :param cls: The class to register. Its ``__module__`` is used to build the
        module-qualified key.
    """
    # Module where the class is located
    module = cls.__module__
    # Full module path to the class, e.g. user.schemas.UserSchema
    fullpath = f"{module}.{classname}"

    # The short name may be shared by classes from different modules, so it
    # maps to a list holding at most one class per module.
    same_name = _registry.setdefault(classname, [])
    for index, existing in enumerate(same_name):
        if existing.__module__ == module:
            # Re-registration from the same module: replace the stale entry
            # rather than keeping it, so it matches the full-path entry below.
            same_name[index] = cls
            break
    else:
        same_name.append(cls)

    # The module-qualified path identifies a single class; the most recent
    # registration wins.
    _registry[fullpath] = [cls]
@typing.overload
def get_class(classname: str, *, all: typing.Literal[False] = ...) -> SchemaType: ...


@typing.overload
def get_class(classname: str, *, all: typing.Literal[True]) -> list[SchemaType]: ...


@typing.overload
def get_class(classname: str, *, all: bool) -> list[SchemaType] | SchemaType: ...


def get_class(classname: str, *, all: bool = False) -> list[SchemaType] | SchemaType:  # noqa: A002
    """Retrieve a class from the registry.

    :param classname: A class name or a full, module-qualified path.
    :param all: If `True`, return a list of every class registered under
        ``classname`` (even when there is only one) instead of a single class.
    :return: The registered class, or a new list of classes when ``all`` is
        `True`.
    :raises: `marshmallow.exceptions.RegistryError` if the class cannot be found,
        or if there are multiple entries for the given class name and ``all``
        is `False`.
    """
    classes = _registry.get(classname)
    # Treat a missing key and an (unexpected) empty entry the same way, so
    # callers only ever see RegistryError.
    if not classes:
        raise RegistryError(
            f"Class with name {classname!r} was not found. You may need "
            "to import the class."
        )
    if all:
        # Return a copy so callers cannot mutate the registry's own list.
        return list(classes)
    if len(classes) > 1:
        raise RegistryError(
            f"Multiple classes with name {classname!r} "
            "were found. Please use the full, "
            "module-qualified path."
        )
    return classes[0]
|