aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/docx/opc/rel.py
blob: 47e8860d8833bc69567fee56917266dcbe70d940 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Relationship-related objects."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, cast

from docx.opc.oxml import CT_Relationships

if TYPE_CHECKING:
    from docx.opc.part import Part


class Relationships(Dict[str, "_Relationship"]):
    """Collection of |_Relationship| instances, having dict semantics keyed by rId.

    In addition to the rId -> |_Relationship| mapping, the collection caches the target
    part of each internal (non-external) relationship; that cache is exposed as
    :attr:`related_parts` and is kept in sync when relationships are added, replaced,
    or removed with ``del``.
    """

    def __init__(self, baseURI: str):
        super(Relationships, self).__init__()
        # -- handed to each new |_Relationship|, which uses it to compute relative
        # -- target references for internal relationships --
        self._baseURI = baseURI
        # -- rId -> target part, internal relationships only --
        self._target_parts_by_rId: dict[str, Any] = {}

    def __delitem__(self, rId: str) -> None:
        """Remove the relationship keyed by `rId`, raising |KeyError| if not present.

        Also discards any cached target part for `rId` so :attr:`related_parts` does
        not keep reporting a relationship that is no longer in the collection. Note
        that other dict mutators (e.g. ``pop()``, ``clear()``) do not route through
        this method.
        """
        super(Relationships, self).__delitem__(rId)
        self._target_parts_by_rId.pop(rId, None)

    def add_relationship(
        self, reltype: str, target: Part | str, rId: str, is_external: bool = False
    ) -> "_Relationship":
        """Return a newly added |_Relationship| instance.

        Any relationship already stored under `rId` is replaced. `target` is a |Part|
        for an internal relationship or a reference string when `is_external` is True.
        """
        rel = _Relationship(rId, reltype, target, self._baseURI, is_external)
        self[rId] = rel
        if is_external:
            # -- an external rel may replace an internal one stored under the same rId;
            # -- drop any cached target part so `related_parts` does not go stale --
            self._target_parts_by_rId.pop(rId, None)
        else:
            self._target_parts_by_rId[rId] = target
        return rel

    def get_or_add(self, reltype: str, target_part: Part) -> _Relationship:
        """Return relationship of `reltype` to `target_part`, newly added if not already
        present in collection."""
        rel = self._get_matching(reltype, target_part)
        if rel is None:
            rId = self._next_rId
            rel = self.add_relationship(reltype, target_part, rId)
        return rel

    def get_or_add_ext_rel(self, reltype: str, target_ref: str) -> str:
        """Return rId of external relationship of `reltype` to `target_ref`, newly added
        if not already present in collection."""
        rel = self._get_matching(reltype, target_ref, is_external=True)
        if rel is None:
            rId = self._next_rId
            rel = self.add_relationship(reltype, target_ref, rId, is_external=True)
        return rel.rId

    def part_with_reltype(self, reltype: str) -> Part:
        """Return target part of rel with matching `reltype`, raising |KeyError| if not
        found and |ValueError| if more than one matching relationship is found."""
        rel = self._get_rel_of_type(reltype)
        return rel.target_part

    @property
    def related_parts(self) -> dict[str, Any]:
        """Dict mapping rIds to target parts for all the internal relationships in the
        collection.

        This is the live internal cache, not a copy; callers should treat it as
        read-only.
        """
        return self._target_parts_by_rId

    @property
    def xml(self) -> str:
        """Serialize this relationship collection into XML suitable for storage as a
        .rels file in an OPC package.

        Relationships are emitted in the collection's iteration (insertion) order.
        """
        rels_elm = CT_Relationships.new()
        for rel in self.values():
            rels_elm.add_rel(rel.rId, rel.reltype, rel.target_ref, rel.is_external)
        return rels_elm.xml

    def _get_matching(
        self, reltype: str, target: Part | str, is_external: bool = False
    ) -> _Relationship | None:
        """Return relationship of matching `reltype`, `target`, and `is_external` from
        collection, or None if not found.

        For an external relationship `target` is compared against the rel's reference
        string; for an internal one it is compared against the rel's target part.
        """

        def matches(rel: _Relationship, reltype: str, target: Part | str, is_external: bool):
            if rel.reltype != reltype:
                return False
            if rel.is_external != is_external:
                return False
            # -- `target_part` raises on external rels, so pick the comparable attribute --
            rel_target = rel.target_ref if rel.is_external else rel.target_part
            if rel_target != target:
                return False
            return True

        for rel in self.values():
            if matches(rel, reltype, target, is_external):
                return rel
        return None

    def _get_rel_of_type(self, reltype: str) -> _Relationship:
        """Return single relationship of type `reltype` from the collection.

        Raises |KeyError| if no matching relationship is found. Raises |ValueError| if
        more than one matching relationship is found.
        """
        matching = [rel for rel in self.values() if rel.reltype == reltype]
        if not matching:
            tmpl = "no relationship of type '%s' in collection"
            raise KeyError(tmpl % reltype)
        if len(matching) > 1:
            tmpl = "multiple relationships of type '%s' in collection"
            raise ValueError(tmpl % reltype)
        return matching[0]

    @property
    def _next_rId(self) -> str:
        """Next available rId in collection, starting from 'rId1' and making use of any
        gaps in numbering, e.g. 'rId2' for rIds ['rId1', 'rId3']."""
        # -- there are len(self) + 1 distinct candidates but only len(self) keys, so at
        # -- least one candidate is always free and the loop always returns --
        for n in range(1, len(self) + 2):
            rId_candidate = "rId%d" % n  # like 'rId19'
            if rId_candidate not in self:
                return rId_candidate
        raise AssertionError("unreachable: no free rId found")  # pragma: no cover


class _Relationship:
    """Value object for relationship to part."""

    def __init__(
        self, rId: str, reltype: str, target: Part | str, baseURI: str, external: bool = False
    ):
        super(_Relationship, self).__init__()
        self._rId = rId
        self._reltype = reltype
        self._target = target
        self._baseURI = baseURI
        self._is_external = bool(external)

    @property
    def is_external(self) -> bool:
        return self._is_external

    @property
    def reltype(self) -> str:
        return self._reltype

    @property
    def rId(self) -> str:
        return self._rId

    @property
    def target_part(self) -> Part:
        if self._is_external:
            raise ValueError(
                "target_part property on _Relationship is undef" "ined when target mode is External"
            )
        return cast("Part", self._target)

    @property
    def target_ref(self) -> str:
        if self._is_external:
            return cast(str, self._target)
        else:
            target = cast("Part", self._target)
            return target.partname.relative_ref(self._baseURI)