aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/lark/tree_matcher.py
blob: 02f95885524a9b4ae6cc9b61e7d248574c9c0eb1 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""Tree matcher based on Lark grammar"""

import re
from collections import defaultdict

from . import Tree, Token
from .common import ParserConf
from .parsers import earley
from .grammar import Rule, Terminal, NonTerminal


def is_discarded_terminal(t):
    """Tell whether symbol ``t`` is a terminal that gets filtered out.

    Returns ``t.is_term`` itself when it is falsy (``t.filter_out`` is then
    never read), otherwise returns ``t.filter_out``.
    """
    is_term = t.is_term
    if not is_term:
        return is_term
    return t.filter_out


class _MakeTreeMatch:
    """Callback that wraps matched children in a ``Tree`` flagged as a match.

    ``name`` becomes the data of the produced tree; ``expansion`` is stored
    on the tree's meta as ``orig_expansion``.
    """

    def __init__(self, name, expansion):
        self.name, self.expansion = name, expansion

    def __call__(self, args):
        """Build the match tree for the children ``args``."""
        node = Tree(self.name, args)
        node.meta.match_tree = True
        node.meta.orig_expansion = self.expansion
        return node


def _best_from_group(seq, group_key, cmp_key):
    """Keep one item per group: the one with the smallest ``cmp_key``.

    Items are grouped by ``group_key(item)``. When two items of a group
    compare equal, the one seen first is kept. The result lists the
    winners in order of each group's first appearance.
    """
    missing = object()
    winners = {}
    for candidate in seq:
        group = group_key(candidate)
        incumbent = winners.get(group, missing)
        if incumbent is missing:
            # First item of this group: it wins by default.
            winners[group] = candidate
            continue
        candidate_score = cmp_key(candidate)
        if cmp_key(incumbent) > candidate_score:
            winners[group] = candidate
    return [*winners.values()]


def _best_rules_from_group(rules):
    """Drop duplicate rules and order the survivors by expansion length.

    Each rule is its own group key, so rules that compare equal fall into
    the same group; within a group the rule with the longest expansion is
    kept. The returned list is sorted with the shortest expansion first.
    """
    def expansion_size(rule):
        return len(rule.expansion)

    chosen = _best_from_group(rules, lambda rule: rule, lambda rule: -expansion_size(rule))
    chosen.sort(key=expansion_size)
    return chosen


def _match(term, token):
    """Tell whether grammar symbol ``term`` matches the child ``token``.

    A ``Tree`` child matches when its ``data`` equals the base name of
    ``term`` (any template arguments are stripped by ``parse_rulename``).
    A ``Token`` child matches when ``term`` equals the ``Terminal`` named
    after the token's type.

    Raises:
        AssertionError: if ``token`` is neither a ``Tree`` nor a ``Token``.
    """
    if isinstance(token, Tree):
        name, _args = parse_rulename(term.name)
        return token.data == name
    if isinstance(token, Token):
        return term == Terminal(token.type)
    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would make this function silently
    # return None (i.e. "no match") for an unsupported child.
    raise AssertionError((term, token))


def make_recons_rule(origin, expansion, old_expansion):
    """Build the matching rule ``origin -> expansion``.

    The rule's alias is a ``_MakeTreeMatch`` callback named after
    ``origin`` that remembers ``old_expansion``.
    """
    tree_builder = _MakeTreeMatch(origin.name, old_expansion)
    return Rule(origin, expansion, alias=tree_builder)


def make_recons_rule_to_term(origin, term):
    """Build a rule matching ``origin`` to one terminal named like ``term``.

    ``term`` itself is recorded as the rule's one-element original expansion.
    """
    as_terminal = Terminal(term.name)
    return make_recons_rule(origin, [as_terminal], [term])


def parse_rulename(s):
    """Parse rule names that may contain a template syntax (like rule{a, b, ...})

    Returns a ``(name, args)`` pair. ``args`` is the list of
    whitespace-stripped template arguments, or ``None`` when ``s`` has no
    ``{...}`` part.
    """
    found = re.match(r'(\w+)(?:{(.+)})?', s)
    name, args_str = found.groups()
    if not args_str:
        # No template part: hand back the (falsy) captured value unchanged.
        return name, args_str
    return name, [piece.strip() for piece in args_str.split(',')]



class ChildrenLexer:
    """Lexer stand-in that hands back an already-available list of children.

    `TreeMatcher.match_tree` passes an instance wrapping `tree.children` to
    the parser in place of a real lexer.
    """

    def __init__(self, children):
        # The items to return from `lex`, stored as given (no copy).
        self.children = children

    def lex(self, parser_state):
        """Return the stored children; `parser_state` is not used."""
        return self.children

class TreeMatcher:
    """Match the elements of a tree node, based on an ontology
    provided by a Lark grammar.

    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)

    Initialize with an instance of Lark.
    """

    def __init__(self, parser):
        """Compile the grammar of `parser` and derive the matching rules.

        Parameters:
            parser: the Lark instance whose grammar and options are used.
                Its `options.maybe_placeholders` must equal False (asserted).
        """
        # XXX TODO calling compile twice returns different results!
        assert parser.options.maybe_placeholders == False
        # XXX TODO: we just ignore the potential existence of a postlexer
        self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())

        # Rules keyed by rule name that `match_tree` adds to the parser only
        # when matching the rule of that name. Filled in as a side effect of
        # `_build_recons_rules` below.
        self.rules_for_root = defaultdict(list)

        # `_build_recons_rules` is a generator: consuming it fully here is
        # what populates `rules_for_root`.
        self.rules = list(self._build_recons_rules(rules))
        self.rules.reverse()

        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
        self.rules = _best_rules_from_group(self.rules)

        self.parser = parser
        # Parsers built by `match_tree`, keyed by the rule name they start from.
        self._parser_cache = {}

    def _build_recons_rules(self, rules):
        """Convert tree-parsing/construction rules to tree-matching rules.

        This is a generator. It yields the rules that end up in `self.rules`
        and, as a side effect, appends the remaining rules to
        `self.rules_for_root` under the name of their origin (or alias).
        """
        # Origins that have at least one rule with `options.expand1` set.
        expand1s = {r.origin for r in rules if r.options.expand1}

        # origin -> every alias used by one of its rules.
        aliases = defaultdict(list)
        for r in rules:
            if r.alias:
                aliases[r.origin].append(r.alias)

        # Rule origins that stay as symbols inside the converted expansions:
        # names starting with '_', expand1 origins, and aliased origins.
        # Every other symbol is turned into a Terminal of the same name below.
        rule_names = {r.origin for r in rules}
        nonterminals = {sym for sym in rule_names
                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}

        # Names for which the "symbol -> terminal of the same name" rule has
        # already been yielded, so it is emitted only once per name.
        seen = set()
        for r in rules:
            # Drop filtered-out terminals; replace symbols outside
            # `nonterminals` with a Terminal carrying the same name.
            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
                          for sym in r.expansion if not is_discarded_terminal(sym)]

            # Skip self-recursive constructs
            if recons_exp == [r.origin] and r.alias is None:
                continue

            # An aliased rule is registered under its alias, not its origin.
            sym = NonTerminal(r.alias) if r.alias else r.origin
            rule = make_recons_rule(sym, recons_exp, r.expansion)

            if sym in expand1s and len(recons_exp) != 1:
                # Kept for use only when `sym.name` is the rule being matched.
                self.rules_for_root[sym.name].append(rule)

                if sym.name not in seen:
                    yield make_recons_rule_to_term(sym, sym)
                    seen.add(sym.name)
            else:
                if sym.name.startswith('_') or sym in expand1s:
                    yield rule
                else:
                    self.rules_for_root[sym.name].append(rule)

        # For every aliased origin, yield one rule per alias mapping the
        # origin to a terminal named after that alias, plus one mapping the
        # origin to a terminal named after the origin itself.
        for origin, rule_aliases in aliases.items():
            for alias in rule_aliases:
                yield make_recons_rule_to_term(origin, NonTerminal(alias))
            yield make_recons_rule_to_term(origin, origin)

    def match_tree(self, tree, rulename):
        """Match the elements of `tree` to the symbols of rule `rulename`.

        Parameters:
            tree (Tree): the tree node to match
            rulename (str): The expected full rule name (including template args).
                If falsy, `tree.data` is used instead.

        Returns:
            Tree: an unreduced tree that matches `rulename`

        Raises:
            UnexpectedToken: If no match was found.
            AssertionError: If `tree.data` differs from the base name of
                `rulename`, or the result's data differs from `rulename`.

        Note:
            It's the caller's responsibility to match the tree recursively.
        """
        if rulename:
            # validate
            name, _args = parse_rulename(rulename)
            assert tree.data == name
        else:
            rulename = tree.data

        # TODO: ambiguity?
        try:
            parser = self._parser_cache[rulename]
        except KeyError:
            # `rules_for_root` is a defaultdict: an unknown `rulename` yields
            # (and stores) an empty list here rather than raising.
            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])

            # TODO pass callbacks through dict, instead of alias?
            callbacks = {rule: rule.alias for rule in rules}
            conf = ParserConf(rules, callbacks, [rulename])
            # NOTE(review): `_match` is presumably the parser's term matcher,
            # deciding whether a child fits a symbol - confirm against earley.Parser.
            parser = earley.Parser(conf, _match, resolve_ambiguity=True)
            self._parser_cache[rulename] = parser

        # find a full derivation
        unreduced_tree = parser.parse(ChildrenLexer(tree.children), rulename)
        assert unreduced_tree.data == rulename
        return unreduced_tree