about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/et_xmlfile/xmlfile.py
blob: 9b8ce82fe5eb98902379360205c03d0067782b61 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import absolute_import
# Copyright (c) 2010-2015 openpyxl

"""Implements the lxml.etree.xmlfile API using the standard library xml.etree"""


from contextlib import contextmanager

from xml.etree.ElementTree import (
    Element,
    _escape_cdata,
)

from . import incremental_tree


class LxmlSyntaxError(Exception):
    """Stand-in for lxml's ``LxmlSyntaxError``.

    Raised by the incremental writer when a requested write would yield
    malformed XML, e.g. character data outside of any element.
    """


class _IncrementalFileWriter(object):
    """Replacement for _IncrementalFileWriter of lxml"""
    def __init__(self, output_file):
        self._element_stack = []
        self._file = output_file
        self._have_root = False
        self.global_nsmap = incremental_tree.current_global_nsmap()
        self.is_html = False

    @contextmanager
    def element(self, tag, attrib=None, nsmap=None, **_extra):
        """Create a new xml element using a context manager."""
        if nsmap and None in nsmap:
            # Normalise None prefix (lxml's default namespace prefix) -> "", as
            # required for incremental_tree
            if "" in nsmap and nsmap[""] != nsmap[None]:
                raise ValueError(
                    'Found None and "" as default nsmap prefixes with different URIs'
                )
            nsmap = nsmap.copy()
            nsmap[""] = nsmap.pop(None)

        # __enter__ part
        self._have_root = True
        if attrib is None:
            attrib = {}
        elem = Element(tag, attrib=attrib, **_extra)
        elem.text = ''
        elem.tail = ''
        if self._element_stack:
            is_root = False
            (
                nsmap_scope,
                default_ns_attr_prefix,
                uri_to_prefix,
            ) = self._element_stack[-1]
        else:
            is_root = True
            nsmap_scope = {}
            default_ns_attr_prefix = None
            uri_to_prefix = {}
        (
            tag,
            nsmap_scope,
            default_ns_attr_prefix,
            uri_to_prefix,
            next_remains_root,
        ) = incremental_tree.write_elem_start(
            self._file,
            elem,
            nsmap_scope=nsmap_scope,
            global_nsmap=self.global_nsmap,
            short_empty_elements=False,
            is_html=self.is_html,
            is_root=is_root,
            uri_to_prefix=uri_to_prefix,
            default_ns_attr_prefix=default_ns_attr_prefix,
            new_nsmap=nsmap,
        )
        self._element_stack.append(
            (
                nsmap_scope,
                default_ns_attr_prefix,
                uri_to_prefix,
            )
        )
        yield

        # __exit__ part
        self._element_stack.pop()
        self._file(f"</{tag}>")
        if elem.tail:
            self._file(_escape_cdata(elem.tail))

    def write(self, arg):
        """Write a string or subelement."""

        if isinstance(arg, str):
            # it is not allowed to write a string outside of an element
            if not self._element_stack:
                raise LxmlSyntaxError()
            self._file(_escape_cdata(arg))

        else:
            if not self._element_stack and self._have_root:
                raise LxmlSyntaxError()

            if self._element_stack:
                is_root = False
                (
                    nsmap_scope,
                    default_ns_attr_prefix,
                    uri_to_prefix,
                ) = self._element_stack[-1]
            else:
                is_root = True
                nsmap_scope = {}
                default_ns_attr_prefix = None
                uri_to_prefix = {}
            incremental_tree._serialize_ns_xml(
                self._file,
                arg,
                nsmap_scope=nsmap_scope,
                global_nsmap=self.global_nsmap,
                short_empty_elements=True,
                is_html=self.is_html,
                is_root=is_root,
                uri_to_prefix=uri_to_prefix,
                default_ns_attr_prefix=default_ns_attr_prefix,
            )

    def __enter__(self):
        pass

    def __exit__(self, type, value, traceback):
        # without root the xml document is incomplete
        if not self._have_root:
            raise LxmlSyntaxError()


class xmlfile(object):
    """Context manager that can replace lxml.etree.xmlfile.

    :param output_file: destination handed to ``incremental_tree._get_writer``.
    :param buffered: accepted for lxml API compatibility; not used.
    :param encoding: output encoding handed to the writer.
    :param close: if true, ``output_file.close()`` is called on exit.
    """
    def __init__(self, output_file, buffered=False, encoding="utf-8", close=False):
        self._file = output_file
        self._close = close
        self.encoding = encoding
        self.writer_cm = None

    def __enter__(self):
        """Open the underlying writer and return an incremental element writer."""
        self.writer_cm = incremental_tree._get_writer(self._file, encoding=self.encoding)
        writer, _declared_encoding = self.writer_cm.__enter__()
        return _IncrementalFileWriter(writer)

    def __exit__(self, type, value, traceback):
        """Finalise the writer, then close ``output_file`` if requested.

        Exceptions from the ``with`` body are not suppressed.
        """
        try:
            if self.writer_cm is not None:
                self.writer_cm.__exit__(type, value, traceback)
        finally:
            # Honour close=True even if finalising the writer raised, so the
            # caller's file handle is never leaked on an error path.
            if self._close:
                self._file.close()