about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py')
-rw-r--r--.venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py571
1 files changed, 571 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py b/.venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py
new file mode 100644
index 00000000..70307ec7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py
@@ -0,0 +1,571 @@
+# orm/clsregistry.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Routines to handle the string class registry used by declarative.
+
+This system allows specification of classes and expressions used in
+:func:`_orm.relationship` using strings.
+
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import MutableMapping
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+import weakref
+
+from . import attributes
+from . import interfaces
+from .descriptor_props import SynonymProperty
+from .properties import ColumnProperty
+from .util import class_mapper
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql.schema import _get_table_key
+from ..util.typing import CallableReference
+
+if TYPE_CHECKING:
+    from .relationships import RelationshipProperty
+    from ..sql.schema import MetaData
+    from ..sql.schema import Table
+
+_T = TypeVar("_T", bound=Any)
+
+_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
+
+# strong references to registries which we place in
+# the _decl_class_registry, which is usually weak referencing.
+# the internal registries here link to classes with weakrefs and remove
+# themselves when all references to contained classes are removed.
+_registries: Set[ClsRegistryToken] = set()
+
+
+def add_class(
+    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
+) -> None:
+    """Add a class to the _decl_class_registry associated with the
+    given declarative class.
+
+    """
+    if classname in decl_class_registry:
+        # class already exists.
+        existing = decl_class_registry[classname]
+        if not isinstance(existing, _MultipleClassMarker):
+            existing = decl_class_registry[classname] = _MultipleClassMarker(
+                [cls, cast("Type[Any]", existing)]
+            )
+    else:
+        decl_class_registry[classname] = cls
+
+    try:
+        root_module = cast(
+            _ModuleMarker, decl_class_registry["_sa_module_registry"]
+        )
+    except KeyError:
+        decl_class_registry["_sa_module_registry"] = root_module = (
+            _ModuleMarker("_sa_module_registry", None)
+        )
+
+    tokens = cls.__module__.split(".")
+
+    # build up a tree like this:
+    # modulename:  myapp.snacks.nuts
+    #
+    # myapp->snack->nuts->(classes)
+    # snack->nuts->(classes)
+    # nuts->(classes)
+    #
+    # this allows partial token paths to be used.
+    while tokens:
+        token = tokens.pop(0)
+        module = root_module.get_module(token)
+        for token in tokens:
+            module = module.get_module(token)
+
+        try:
+            module.add_class(classname, cls)
+        except AttributeError as ae:
+            if not isinstance(module, _ModuleMarker):
+                raise exc.InvalidRequestError(
+                    f'name "{classname}" matches both a '
+                    "class name and a module name"
+                ) from ae
+            else:
+                raise
+
+
+def remove_class(
+    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
+) -> None:
+    if classname in decl_class_registry:
+        existing = decl_class_registry[classname]
+        if isinstance(existing, _MultipleClassMarker):
+            existing.remove_item(cls)
+        else:
+            del decl_class_registry[classname]
+
+    try:
+        root_module = cast(
+            _ModuleMarker, decl_class_registry["_sa_module_registry"]
+        )
+    except KeyError:
+        return
+
+    tokens = cls.__module__.split(".")
+
+    while tokens:
+        token = tokens.pop(0)
+        module = root_module.get_module(token)
+        for token in tokens:
+            module = module.get_module(token)
+        try:
+            module.remove_class(classname, cls)
+        except AttributeError:
+            if not isinstance(module, _ModuleMarker):
+                pass
+            else:
+                raise
+
+
+def _key_is_empty(
+    key: str,
+    decl_class_registry: _ClsRegistryType,
+    test: Callable[[Any], bool],
+) -> bool:
+    """test if a key is empty of a certain object.
+
+    used for unit tests against the registry to see if garbage collection
+    is working.
+
+    "test" is a callable that will be passed an object should return True
+    if the given object is the one we were looking for.
+
+    We can't pass the actual object itself b.c. this is for testing garbage
+    collection; the caller will have to have removed references to the
+    object itself.
+
+    """
+    if key not in decl_class_registry:
+        return True
+
+    thing = decl_class_registry[key]
+    if isinstance(thing, _MultipleClassMarker):
+        for sub_thing in thing.contents:
+            if test(sub_thing):
+                return False
+        else:
+            raise NotImplementedError("unknown codepath")
+    else:
+        return not test(thing)
+
+
+class ClsRegistryToken:
+    """an object that can be in the registry._class_registry as a value."""
+
+    __slots__ = ()
+
+
+class _MultipleClassMarker(ClsRegistryToken):
+    """refers to multiple classes of the same name
+    within _decl_class_registry.
+
+    """
+
+    __slots__ = "on_remove", "contents", "__weakref__"
+
+    contents: Set[weakref.ref[Type[Any]]]
+    on_remove: CallableReference[Optional[Callable[[], None]]]
+
+    def __init__(
+        self,
+        classes: Iterable[Type[Any]],
+        on_remove: Optional[Callable[[], None]] = None,
+    ):
+        self.on_remove = on_remove
+        self.contents = {
+            weakref.ref(item, self._remove_item) for item in classes
+        }
+        _registries.add(self)
+
+    def remove_item(self, cls: Type[Any]) -> None:
+        self._remove_item(weakref.ref(cls))
+
+    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
+        return (ref() for ref in self.contents)
+
+    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
+        if len(self.contents) > 1:
+            raise exc.InvalidRequestError(
+                'Multiple classes found for path "%s" '
+                "in the registry of this declarative "
+                "base. Please use a fully module-qualified path."
+                % (".".join(path + [key]))
+            )
+        else:
+            ref = list(self.contents)[0]
+            cls = ref()
+            if cls is None:
+                raise NameError(key)
+            return cls
+
+    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
+        self.contents.discard(ref)
+        if not self.contents:
+            _registries.discard(self)
+            if self.on_remove:
+                self.on_remove()
+
+    def add_item(self, item: Type[Any]) -> None:
+        # protect against class registration race condition against
+        # asynchronous garbage collection calling _remove_item,
+        # [ticket:3208] and [ticket:10782]
+        modules = {
+            cls.__module__
+            for cls in [ref() for ref in list(self.contents)]
+            if cls is not None
+        }
+        if item.__module__ in modules:
+            util.warn(
+                "This declarative base already contains a class with the "
+                "same class name and module name as %s.%s, and will "
+                "be replaced in the string-lookup table."
+                % (item.__module__, item.__name__)
+            )
+        self.contents.add(weakref.ref(item, self._remove_item))
+
+
+class _ModuleMarker(ClsRegistryToken):
+    """Refers to a module name within
+    _decl_class_registry.
+
+    """
+
+    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
+
+    parent: Optional[_ModuleMarker]
+    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
+    mod_ns: _ModNS
+    path: List[str]
+
+    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
+        self.parent = parent
+        self.name = name
+        self.contents = {}
+        self.mod_ns = _ModNS(self)
+        if self.parent:
+            self.path = self.parent.path + [self.name]
+        else:
+            self.path = []
+        _registries.add(self)
+
+    def __contains__(self, name: str) -> bool:
+        return name in self.contents
+
+    def __getitem__(self, name: str) -> ClsRegistryToken:
+        return self.contents[name]
+
+    def _remove_item(self, name: str) -> None:
+        self.contents.pop(name, None)
+        if not self.contents:
+            if self.parent is not None:
+                self.parent._remove_item(self.name)
+            _registries.discard(self)
+
+    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
+        return self.mod_ns.__getattr__(key)
+
+    def get_module(self, name: str) -> _ModuleMarker:
+        if name not in self.contents:
+            marker = _ModuleMarker(name, self)
+            self.contents[name] = marker
+        else:
+            marker = cast(_ModuleMarker, self.contents[name])
+        return marker
+
+    def add_class(self, name: str, cls: Type[Any]) -> None:
+        if name in self.contents:
+            existing = cast(_MultipleClassMarker, self.contents[name])
+            try:
+                existing.add_item(cls)
+            except AttributeError as ae:
+                if not isinstance(existing, _MultipleClassMarker):
+                    raise exc.InvalidRequestError(
+                        f'name "{name}" matches both a '
+                        "class name and a module name"
+                    ) from ae
+                else:
+                    raise
+        else:
+            existing = self.contents[name] = _MultipleClassMarker(
+                [cls], on_remove=lambda: self._remove_item(name)
+            )
+
+    def remove_class(self, name: str, cls: Type[Any]) -> None:
+        if name in self.contents:
+            existing = cast(_MultipleClassMarker, self.contents[name])
+            existing.remove_item(cls)
+
+
+class _ModNS:
+    __slots__ = ("__parent",)
+
+    __parent: _ModuleMarker
+
+    def __init__(self, parent: _ModuleMarker):
+        self.__parent = parent
+
+    def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
+        try:
+            value = self.__parent.contents[key]
+        except KeyError:
+            pass
+        else:
+            if value is not None:
+                if isinstance(value, _ModuleMarker):
+                    return value.mod_ns
+                else:
+                    assert isinstance(value, _MultipleClassMarker)
+                    return value.attempt_get(self.__parent.path, key)
+        raise NameError(
+            "Module %r has no mapped classes "
+            "registered under the name %r" % (self.__parent.name, key)
+        )
+
+
+class _GetColumns:
+    __slots__ = ("cls",)
+
+    cls: Type[Any]
+
+    def __init__(self, cls: Type[Any]):
+        self.cls = cls
+
+    def __getattr__(self, key: str) -> Any:
+        mp = class_mapper(self.cls, configure=False)
+        if mp:
+            if key not in mp.all_orm_descriptors:
+                raise AttributeError(
+                    "Class %r does not have a mapped column named %r"
+                    % (self.cls, key)
+                )
+
+            desc = mp.all_orm_descriptors[key]
+            if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
+                assert isinstance(desc, attributes.QueryableAttribute)
+                prop = desc.property
+                if isinstance(prop, SynonymProperty):
+                    key = prop.name
+                elif not isinstance(prop, ColumnProperty):
+                    raise exc.InvalidRequestError(
+                        "Property %r is not an instance of"
+                        " ColumnProperty (i.e. does not correspond"
+                        " directly to a Column)." % key
+                    )
+        return getattr(self.cls, key)
+
+
+inspection._inspects(_GetColumns)(
+    lambda target: inspection.inspect(target.cls)
+)
+
+
+class _GetTable:
+    __slots__ = "key", "metadata"
+
+    key: str
+    metadata: MetaData
+
+    def __init__(self, key: str, metadata: MetaData):
+        self.key = key
+        self.metadata = metadata
+
+    def __getattr__(self, key: str) -> Table:
+        return self.metadata.tables[_get_table_key(key, self.key)]
+
+
+def _determine_container(key: str, value: Any) -> _GetColumns:
+    if isinstance(value, _MultipleClassMarker):
+        value = value.attempt_get([], key)
+    return _GetColumns(value)
+
+
+class _class_resolver:
+    __slots__ = (
+        "cls",
+        "prop",
+        "arg",
+        "fallback",
+        "_dict",
+        "_resolvers",
+        "favor_tables",
+    )
+
+    cls: Type[Any]
+    prop: RelationshipProperty[Any]
+    fallback: Mapping[str, Any]
+    arg: str
+    favor_tables: bool
+    _resolvers: Tuple[Callable[[str], Any], ...]
+
+    def __init__(
+        self,
+        cls: Type[Any],
+        prop: RelationshipProperty[Any],
+        fallback: Mapping[str, Any],
+        arg: str,
+        favor_tables: bool = False,
+    ):
+        self.cls = cls
+        self.prop = prop
+        self.arg = arg
+        self.fallback = fallback
+        self._dict = util.PopulateDict(self._access_cls)
+        self._resolvers = ()
+        self.favor_tables = favor_tables
+
+    def _access_cls(self, key: str) -> Any:
+        cls = self.cls
+
+        manager = attributes.manager_of_class(cls)
+        decl_base = manager.registry
+        assert decl_base is not None
+        decl_class_registry = decl_base._class_registry
+        metadata = decl_base.metadata
+
+        if self.favor_tables:
+            if key in metadata.tables:
+                return metadata.tables[key]
+            elif key in metadata._schemas:
+                return _GetTable(key, getattr(cls, "metadata", metadata))
+
+        if key in decl_class_registry:
+            return _determine_container(key, decl_class_registry[key])
+
+        if not self.favor_tables:
+            if key in metadata.tables:
+                return metadata.tables[key]
+            elif key in metadata._schemas:
+                return _GetTable(key, getattr(cls, "metadata", metadata))
+
+        if "_sa_module_registry" in decl_class_registry and key in cast(
+            _ModuleMarker, decl_class_registry["_sa_module_registry"]
+        ):
+            registry = cast(
+                _ModuleMarker, decl_class_registry["_sa_module_registry"]
+            )
+            return registry.resolve_attr(key)
+        elif self._resolvers:
+            for resolv in self._resolvers:
+                value = resolv(key)
+                if value is not None:
+                    return value
+
+        return self.fallback[key]
+
+    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
+        generic_match = re.match(r"(.+)\[(.+)\]", name)
+
+        if generic_match:
+            clsarg = generic_match.group(2).strip("'")
+            raise exc.InvalidRequestError(
+                f"When initializing mapper {self.prop.parent}, "
+                f'expression "relationship({self.arg!r})" seems to be '
+                "using a generic class as the argument to relationship(); "
+                "please state the generic argument "
+                "using an annotation, e.g. "
+                f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
+                f"['{clsarg}']] = relationship()\""
+            ) from err
+        else:
+            raise exc.InvalidRequestError(
+                "When initializing mapper %s, expression %r failed to "
+                "locate a name (%r). If this is a class name, consider "
+                "adding this relationship() to the %r class after "
+                "both dependent classes have been defined."
+                % (self.prop.parent, self.arg, name, self.cls)
+            ) from err
+
+    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
+        name = self.arg
+        d = self._dict
+        rval = None
+        try:
+            for token in name.split("."):
+                if rval is None:
+                    rval = d[token]
+                else:
+                    rval = getattr(rval, token)
+        except KeyError as err:
+            self._raise_for_name(name, err)
+        except NameError as n:
+            self._raise_for_name(n.args[0], n)
+        else:
+            if isinstance(rval, _GetColumns):
+                return rval.cls
+            else:
+                if TYPE_CHECKING:
+                    assert isinstance(rval, (type, Table, _ModNS))
+                return rval
+
+    def __call__(self) -> Any:
+        try:
+            x = eval(self.arg, globals(), self._dict)
+
+            if isinstance(x, _GetColumns):
+                return x.cls
+            else:
+                return x
+        except NameError as n:
+            self._raise_for_name(n.args[0], n)
+
+
+_fallback_dict: Mapping[str, Any] = None  # type: ignore
+
+
+def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[
+    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+    Callable[[str, bool], _class_resolver],
+]:
+    global _fallback_dict
+
+    if _fallback_dict is None:
+        import sqlalchemy
+        from . import foreign
+        from . import remote
+
+        _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
+            {"foreign": foreign, "remote": remote}
+        )
+
+    def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
+        return _class_resolver(
+            cls, prop, _fallback_dict, arg, favor_tables=favor_tables
+        )
+
+    def resolve_name(
+        arg: str,
+    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
+        return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
+
+    return resolve_name, resolve_arg