# pyright: reportImportCycles=false """Enabling declarative definition of lxml custom element classes.""" from __future__ import annotations import re from typing import ( TYPE_CHECKING, Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, ) from lxml import etree from lxml.etree import ElementBase, _Element # pyright: ignore[reportPrivateUsage] from docx.oxml.exceptions import InvalidXmlError from docx.oxml.ns import NamespacePrefixedTag, nsmap, qn from docx.shared import lazyproperty if TYPE_CHECKING: from docx.enum.base import BaseXmlEnum from docx.oxml.simpletypes import BaseSimpleType def serialize_for_reading(element: ElementBase): """Serialize `element` to human-readable XML suitable for tests. No XML declaration. """ xml = etree.tostring(element, encoding="unicode", pretty_print=True) return XmlString(xml) class XmlString(str): """Provides string comparison override suitable for serialized XML that is useful for tests.""" # ' <w:xyz xmlns:a="http://ns/decl/a" attr_name="val">text</w:xyz>' # | | || | # +----------+------------------------------------------++-----------+ # front attrs | text # close _xml_elm_line_patt = re.compile(r"( *</?[\w:]+)(.*?)(/?>)([^<]*</[\w:]+>)?$") def __eq__(self, other: object) -> bool: if not isinstance(other, str): return False lines = self.splitlines() lines_other = other.splitlines() if len(lines) != len(lines_other): return False for line, line_other in zip(lines, lines_other): if not self._eq_elm_strs(line, line_other): return False return True def __ne__(self, other: object) -> bool: return not self.__eq__(other) def _attr_seq(self, attrs: str) -> List[str]: """Return a sequence of attribute strings parsed from `attrs`. Each attribute string is stripped of whitespace on both ends. """ attrs = attrs.strip() attr_lst = attrs.split() return sorted(attr_lst) def _eq_elm_strs(self, line: str, line_2: str): """Return True if the element in `line_2` is XML equivalent to the element in `line`.""" front, attrs, close, text = self._parse_line(line) front_2, attrs_2, close_2, text_2 = self._parse_line(line_2) if front != front_2: return False if self._attr_seq(attrs) != self._attr_seq(attrs_2): return False if close != close_2: return False if text != text_2: return False return True @classmethod def _parse_line(cls, line: str) -> Tuple[str, str, str, str]: """(front, attrs, close, text) 4-tuple result of parsing XML element `line`.""" match = cls._xml_elm_line_patt.match(line) if match is None: return "", "", "", "" front, attrs, close, text = [match.group(n) for n in range(1, 5)] return front, attrs, close, text _T = TypeVar("_T") class MetaOxmlElement(type): """Metaclass for BaseOxmlElement.""" def __init__(cls, clsname: str, bases: Tuple[type, ...], namespace: Dict[str, Any]): dispatchable = ( OneAndOnlyOne, OneOrMore, OptionalAttribute, RequiredAttribute, ZeroOrMore, ZeroOrOne, ZeroOrOneChoice, ) for key, value in namespace.items(): if isinstance(value, dispatchable): value.populate_class_members(cls, key) class BaseAttribute: """Base class for OptionalAttribute and RequiredAttribute. Provides common methods. """ def __init__(self, attr_name: str, simple_type: Type[BaseXmlEnum] | Type[BaseSimpleType]): super(BaseAttribute, self).__init__() self._attr_name = attr_name self._simple_type = simple_type def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" self._element_cls = element_cls self._prop_name = prop_name self._add_attr_property() def _add_attr_property(self): """Add a read/write `.{prop_name}` property to the element class. The property returns the interpreted value of this attribute on access and changes the attribute value to its ST_* counterpart on assignment. """ property_ = property(self._getter, self._setter, None) # -- assign unconditionally to overwrite element name definition -- setattr(self._element_cls, self._prop_name, property_) @property def _clark_name(self): if ":" in self._attr_name: return qn(self._attr_name) return self._attr_name @property def _getter(self) -> Callable[[BaseOxmlElement], Any | None]: ... @property def _setter( self, ) -> Callable[[BaseOxmlElement, Any | None], None]: ... class OptionalAttribute(BaseAttribute): """Defines an optional attribute on a custom element class. An optional attribute returns a default value when not present for reading. When assigned |None|, the attribute is removed, but still returns the default value when one is specified. """ def __init__( self, attr_name: str, simple_type: Type[BaseXmlEnum] | Type[BaseSimpleType], default: BaseXmlEnum | BaseSimpleType | str | bool | None = None, ): super(OptionalAttribute, self).__init__(attr_name, simple_type) self._default = default @property def _docstring(self): """String to use as `__doc__` attribute of attribute property.""" return ( f"{self._simple_type.__name__} type-converted value of" f" ``{self._attr_name}`` attribute, or |None| (or specified default" f" value) if not present. Assigning the default value causes the" f" attribute to be removed from the element." ) @property def _getter( self, ) -> Callable[[BaseOxmlElement], Any | None]: """Function suitable for `__get__()` method on attribute property descriptor.""" def get_attr_value( obj: BaseOxmlElement, ) -> Any | None: attr_str_value = obj.get(self._clark_name) if attr_str_value is None: return self._default return self._simple_type.from_xml(attr_str_value) get_attr_value.__doc__ = self._docstring return get_attr_value @property def _setter(self) -> Callable[[BaseOxmlElement, Any], None]: """Function suitable for `__set__()` method on attribute property descriptor.""" def set_attr_value(obj: BaseOxmlElement, value: Any | None): if value is None or value == self._default: if self._clark_name in obj.attrib: del obj.attrib[self._clark_name] return str_value = self._simple_type.to_xml(value) if str_value is None: if self._clark_name in obj.attrib: del obj.attrib[self._clark_name] return obj.set(self._clark_name, str_value) return set_attr_value class RequiredAttribute(BaseAttribute): """Defines a required attribute on a custom element class. A required attribute is assumed to be present for reading, so does not have a default value; its actual value is always used. If missing on read, an |InvalidXmlError| is raised. It also does not remove the attribute if |None| is assigned. Assigning |None| raises |TypeError| or |ValueError|, depending on the simple type of the attribute. """ @property def _docstring(self): """Return the string to use as the ``__doc__`` attribute of the property for this attribute.""" return "%s type-converted value of ``%s`` attribute." % ( self._simple_type.__name__, self._attr_name, ) @property def _getter(self) -> Callable[[BaseOxmlElement], Any]: """function object suitable for "get" side of attr property descriptor.""" def get_attr_value(obj: BaseOxmlElement) -> Any | None: attr_str_value = obj.get(self._clark_name) if attr_str_value is None: raise InvalidXmlError( "required '%s' attribute not present on element %s" % (self._attr_name, obj.tag) ) return self._simple_type.from_xml(attr_str_value) get_attr_value.__doc__ = self._docstring return get_attr_value @property def _setter(self) -> Callable[[BaseOxmlElement, Any], None]: """function object suitable for "set" side of attribute property descriptor.""" def set_attr_value(obj: BaseOxmlElement, value: Any): str_value = self._simple_type.to_xml(value) if str_value is None: raise ValueError(f"cannot assign {value} to this required attribute") obj.set(self._clark_name, str_value) return set_attr_value class _BaseChildElement: """Base class for the child-element classes. The child-element sub-classes correspond to varying cardinalities, such as ZeroOrOne and ZeroOrMore. """ def __init__(self, nsptagname: str, successors: Tuple[str, ...] = ()): super(_BaseChildElement, self).__init__() self._nsptagname = nsptagname self._successors = successors def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Baseline behavior for adding the appropriate methods to `element_cls`.""" self._element_cls = element_cls self._prop_name = prop_name def _add_adder(self): """Add an ``_add_x()`` method to the element class for this child element.""" def _add_child(obj: BaseOxmlElement, **attrs: Any): new_method = getattr(obj, self._new_method_name) child = new_method() for key, value in attrs.items(): setattr(child, key, value) insert_method = getattr(obj, self._insert_method_name) insert_method(child) return child _add_child.__doc__ = ( "Add a new ``<%s>`` child element unconditionally, inserted in t" "he correct sequence." % self._nsptagname ) self._add_to_class(self._add_method_name, _add_child) def _add_creator(self): """Add a ``_new_{prop_name}()`` method to the element class that creates a new, empty element of the correct type, having no attributes.""" creator = self._creator creator.__doc__ = ( 'Return a "loose", newly created ``<%s>`` element having no attri' "butes, text, or children." % self._nsptagname ) self._add_to_class(self._new_method_name, creator) def _add_getter(self): """Add a read-only ``{prop_name}`` property to the element class for this child element.""" property_ = property(self._getter, None, None) # -- assign unconditionally to overwrite element name definition -- setattr(self._element_cls, self._prop_name, property_) def _add_inserter(self): """Add an ``_insert_x()`` method to the element class for this child element.""" def _insert_child(obj: BaseOxmlElement, child: BaseOxmlElement): obj.insert_element_before(child, *self._successors) return child _insert_child.__doc__ = ( "Return the passed ``<%s>`` element after inserting it as a chil" "d in the correct sequence." % self._nsptagname ) self._add_to_class(self._insert_method_name, _insert_child) def _add_list_getter(self): """Add a read-only ``{prop_name}_lst`` property to the element class to retrieve a list of child elements matching this type.""" prop_name = "%s_lst" % self._prop_name property_ = property(self._list_getter, None, None) setattr(self._element_cls, prop_name, property_) @lazyproperty def _add_method_name(self): return "_add_%s" % self._prop_name def _add_public_adder(self): """Add a public ``add_x()`` method to the parent element class.""" def add_child(obj: BaseOxmlElement): private_add_method = getattr(obj, self._add_method_name) child = private_add_method() return child add_child.__doc__ = ( "Add a new ``<%s>`` child element unconditionally, inserted in t" "he correct sequence." % self._nsptagname ) self._add_to_class(self._public_add_method_name, add_child) def _add_to_class(self, name: str, method: Callable[..., Any]): """Add `method` to the target class as `name`, unless `name` is already defined on the class.""" if hasattr(self._element_cls, name): return setattr(self._element_cls, name, method) @property def _creator(self) -> Callable[[BaseOxmlElement], BaseOxmlElement]: """Callable that creates an empty element of the right type, with no attrs.""" from docx.oxml.parser import OxmlElement def new_child_element(obj: BaseOxmlElement): return OxmlElement(self._nsptagname) return new_child_element @property def _getter(self): """Return a function object suitable for the "get" side of the property descriptor. This default getter returns the child element with matching tag name or |None| if not present. """ def get_child_element(obj: BaseOxmlElement): return obj.find(qn(self._nsptagname)) get_child_element.__doc__ = ( "``<%s>`` child element or |None| if not present." % self._nsptagname ) return get_child_element @lazyproperty def _insert_method_name(self): return "_insert_%s" % self._prop_name @property def _list_getter(self): """Return a function object suitable for the "get" side of a list property descriptor.""" def get_child_element_list(obj: BaseOxmlElement): return obj.findall(qn(self._nsptagname)) get_child_element_list.__doc__ = ( "A list containing each of the ``<%s>`` child elements, in the o" "rder they appear." % self._nsptagname ) return get_child_element_list @lazyproperty def _public_add_method_name(self): """add_childElement() is public API for a repeating element, allowing new elements to be added to the sequence. May be overridden to provide a friendlier API to clients having domain appropriate parameter names for required attributes. """ return "add_%s" % self._prop_name @lazyproperty def _remove_method_name(self): return "_remove_%s" % self._prop_name @lazyproperty def _new_method_name(self): return "_new_%s" % self._prop_name class Choice(_BaseChildElement): """Defines a child element belonging to a group, only one of which may appear as a child.""" @property def nsptagname(self): return self._nsptagname def populate_class_members( # pyright: ignore[reportIncompatibleMethodOverride] self, element_cls: MetaOxmlElement, group_prop_name: str, successors: Tuple[str, ...], ) -> None: """Add the appropriate methods to `element_cls`.""" self._element_cls = element_cls self._group_prop_name = group_prop_name self._successors = successors self._add_getter() self._add_creator() self._add_inserter() self._add_adder() self._add_get_or_change_to_method() def _add_get_or_change_to_method(self): """Add a ``get_or_change_to_x()`` method to the element class for this child element.""" def get_or_change_to_child(obj: BaseOxmlElement): child = getattr(obj, self._prop_name) if child is not None: return child remove_group_method = getattr(obj, self._remove_group_method_name) remove_group_method() add_method = getattr(obj, self._add_method_name) child = add_method() return child get_or_change_to_child.__doc__ = ( "Return the ``<%s>`` child, replacing any other group element if" " found." ) % self._nsptagname self._add_to_class(self._get_or_change_to_method_name, get_or_change_to_child) @property def _prop_name(self): """Property name computed from tag name, e.g. a:schemeClr -> schemeClr.""" start = self._nsptagname.index(":") + 1 if ":" in self._nsptagname else 0 return self._nsptagname[start:] @lazyproperty def _get_or_change_to_method_name(self): return "get_or_change_to_%s" % self._prop_name @lazyproperty def _remove_group_method_name(self): return "_remove_%s" % self._group_prop_name class OneAndOnlyOne(_BaseChildElement): """Defines a required child element for MetaOxmlElement.""" def __init__(self, nsptagname: str): super(OneAndOnlyOne, self).__init__(nsptagname, ()) def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(OneAndOnlyOne, self).populate_class_members(element_cls, prop_name) self._add_getter() @property def _getter(self): """Return a function object suitable for the "get" side of the property descriptor.""" def get_child_element(obj: BaseOxmlElement): child = obj.find(qn(self._nsptagname)) if child is None: raise InvalidXmlError( "required ``<%s>`` child element not present" % self._nsptagname ) return child get_child_element.__doc__ = "Required ``<%s>`` child element." % self._nsptagname return get_child_element class OneOrMore(_BaseChildElement): """Defines a repeating child element for MetaOxmlElement that must appear at least once.""" def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(OneOrMore, self).populate_class_members(element_cls, prop_name) self._add_list_getter() self._add_creator() self._add_inserter() self._add_adder() self._add_public_adder() delattr(element_cls, prop_name) class ZeroOrMore(_BaseChildElement): """Defines an optional repeating child element for MetaOxmlElement.""" def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(ZeroOrMore, self).populate_class_members(element_cls, prop_name) self._add_list_getter() self._add_creator() self._add_inserter() self._add_adder() self._add_public_adder() delattr(element_cls, prop_name) class ZeroOrOne(_BaseChildElement): """Defines an optional child element for MetaOxmlElement.""" def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(ZeroOrOne, self).populate_class_members(element_cls, prop_name) self._add_getter() self._add_creator() self._add_inserter() self._add_adder() self._add_get_or_adder() self._add_remover() def _add_get_or_adder(self): """Add a ``get_or_add_x()`` method to the element class for this child element.""" def get_or_add_child(obj: BaseOxmlElement): child = getattr(obj, self._prop_name) if child is None: add_method = getattr(obj, self._add_method_name) child = add_method() return child get_or_add_child.__doc__ = ( "Return the ``<%s>`` child element, newly added if not present." ) % self._nsptagname self._add_to_class(self._get_or_add_method_name, get_or_add_child) def _add_remover(self): """Add a ``_remove_x()`` method to the element class for this child element.""" def _remove_child(obj: BaseOxmlElement): obj.remove_all(self._nsptagname) _remove_child.__doc__ = ("Remove all ``<%s>`` child elements.") % self._nsptagname self._add_to_class(self._remove_method_name, _remove_child) @lazyproperty def _get_or_add_method_name(self): return "get_or_add_%s" % self._prop_name class ZeroOrOneChoice(_BaseChildElement): """Correspondes to an ``EG_*`` element group where at most one of its members may appear as a child.""" def __init__(self, choices: Sequence[Choice], successors: Tuple[str, ...] = ()): self._choices = choices self._successors = successors def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(ZeroOrOneChoice, self).populate_class_members(element_cls, prop_name) self._add_choice_getter() for choice in self._choices: choice.populate_class_members(element_cls, self._prop_name, self._successors) self._add_group_remover() def _add_choice_getter(self): """Add a read-only ``{prop_name}`` property to the element class that returns the present member of this group, or |None| if none are present.""" property_ = property(self._choice_getter, None, None) # assign unconditionally to overwrite element name definition setattr(self._element_cls, self._prop_name, property_) def _add_group_remover(self): """Add a ``_remove_eg_x()`` method to the element class for this choice group.""" def _remove_choice_group(obj: BaseOxmlElement): for tagname in self._member_nsptagnames: obj.remove_all(tagname) _remove_choice_group.__doc__ = "Remove the current choice group child element if present." self._add_to_class(self._remove_choice_group_method_name, _remove_choice_group) @property def _choice_getter(self): """Return a function object suitable for the "get" side of the property descriptor.""" def get_group_member_element(obj: BaseOxmlElement): return obj.first_child_found_in(*self._member_nsptagnames) get_group_member_element.__doc__ = ( "Return the child element belonging to this element group, or " "|None| if no member child is present." ) return get_group_member_element @lazyproperty def _member_nsptagnames(self): """Sequence of namespace-prefixed tagnames, one for each of the member elements of this choice group.""" return [choice.nsptagname for choice in self._choices] @lazyproperty def _remove_choice_group_method_name(self): return "_remove_%s" % self._prop_name # -- lxml typing isn't quite right here, just ignore this error on _Element -- class BaseOxmlElement(etree.ElementBase, metaclass=MetaOxmlElement): """Effective base class for all custom element classes. Adds standardized behavior to all classes in one place. """ def __repr__(self): return "<%s '<%s>' at 0x%0x>" % ( self.__class__.__name__, self._nsptag, id(self), ) def first_child_found_in(self, *tagnames: str) -> _Element | None: """First child with tag in `tagnames`, or None if not found.""" for tagname in tagnames: child = self.find(qn(tagname)) if child is not None: return child return None def insert_element_before(self, elm: ElementBase, *tagnames: str): successor = self.first_child_found_in(*tagnames) if successor is not None: successor.addprevious(elm) else: self.append(elm) return elm def remove_all(self, *tagnames: str) -> None: """Remove child elements with tagname (e.g. "a:p") in `tagnames`.""" for tagname in tagnames: matching = self.findall(qn(tagname)) for child in matching: self.remove(child) @property def xml(self) -> str: """XML string for this element, suitable for testing purposes. Pretty printed for readability and without an XML declaration at the top. """ return serialize_for_reading(self) def xpath(self, xpath_str: str) -> Any: # pyright: ignore[reportIncompatibleMethodOverride] """Override of `lxml` _Element.xpath() method. Provides standard Open XML namespace mapping (`nsmap`) in centralized location. """ return super().xpath(xpath_str, namespaces=nsmap) @property def _nsptag(self) -> str: return NamespacePrefixedTag.from_clark_name(self.tag)