Diffstat (limited to '.venv/lib/python3.12/site-packages/lark/tree_matcher.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/lark/tree_matcher.py  186
1 file changed, 186 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/lark/tree_matcher.py b/.venv/lib/python3.12/site-packages/lark/tree_matcher.py
new file mode 100644
index 00000000..02f95885
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/lark/tree_matcher.py
@@ -0,0 +1,186 @@
+"""Tree matcher based on Lark grammar"""
+
+import re
+from collections import defaultdict
+
+from . import Tree, Token
+from .common import ParserConf
+from .parsers import earley
+from .grammar import Rule, Terminal, NonTerminal
+
+
+def is_discarded_terminal(t):
+    return t.is_term and t.filter_out
+
+
+class _MakeTreeMatch:
+    def __init__(self, name, expansion):
+        self.name = name
+        self.expansion = expansion
+
+    def __call__(self, args):
+        t = Tree(self.name, args)
+        t.meta.match_tree = True
+        t.meta.orig_expansion = self.expansion
+        return t
+
+
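+# Group items by group_key and, within each group, keep the single item with
+# the lowest cmp_key value (the first one wins on ties).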
+def _best_from_group(seq, group_key, cmp_key):
+    d = {}
+    for item in seq:
+        key = group_key(item)
+        if key in d:
+            v1 = cmp_key(item)
+            v2 = cmp_key(d[key])
+            if v2 > v1:
+                d[key] = item
+        else:
+            d[key] = item
+    return list(d.values())
+
+
+def _best_rules_from_group(rules):
+    rules = _best_from_group(rules, lambda r: r, lambda r: -len(r.expansion))
+    rules.sort(key=lambda r: len(r.expansion))
+    return rules
+
+
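+# Decide whether a child of the tree matches a symbol of the rule being tried:
+# a sub-Tree matches a nonterminal with the same rule name (template arguments
+# are ignored), and a Token matches the terminal of its type.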
+def _match(term, token):
+    if isinstance(token, Tree):
+        name, _args = parse_rulename(term.name)
+        return token.data == name
+    elif isinstance(token, Token):
+        return term == Terminal(token.type)
+    assert False, (term, token)
+
+
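+# Build a "reconstruction" rule whose callback produces a Tree tagged with the
+# original expansion it was matched against (see _MakeTreeMatch above).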
+def make_recons_rule(origin, expansion, old_expansion):
+    return Rule(origin, expansion, alias=_MakeTreeMatch(origin.name, old_expansion))
+
+
+def make_recons_rule_to_term(origin, term):
+    return make_recons_rule(origin, [Terminal(term.name)], [term])
+
+
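+# For example, parse_rulename('rule{a, b}') returns ('rule', ['a', 'b']),
+# while parse_rulename('rule') returns ('rule', None).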
+def parse_rulename(s):
+    "Parse rule names that may contain a template syntax (like rule{a, b, ...})"
+    name, args_str = re.match(r'(\w+)(?:{(.+)})?', s).groups()
+    args = args_str and [a.strip() for a in args_str.split(',')]
+    return name, args
+
+
+
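+# Stand-in lexer that feeds a node's children (Tokens and sub-Trees) to the
+# Earley parser unchanged, so the parser can re-derive how they fit the rule.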
+class ChildrenLexer:
+    def __init__(self, children):
+        self.children = children
+
+    def lex(self, parser_state):
+        return self.children
+
+class TreeMatcher:
+    """Match the elements of a tree node, based on an ontology
+    provided by a Lark grammar.
+
+    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)
+
+    Initialize with an instance of Lark.
+    """
+
+    def __init__(self, parser):
+        # XXX TODO calling compile twice returns different results!
+        assert parser.options.maybe_placeholders == False
+        # XXX TODO: we just ignore the potential existence of a postlexer
+        self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())
+
+        self.rules_for_root = defaultdict(list)
+
+        self.rules = list(self._build_recons_rules(rules))
+        self.rules.reverse()
+
+        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
+        self.rules = _best_rules_from_group(self.rules)
+
+        self.parser = parser
+        self._parser_cache = {}
+
+    def _build_recons_rules(self, rules):
+        "Convert tree-parsing/construction rules to tree-matching rules"
+        expand1s = {r.origin for r in rules if r.options.expand1}
+
+        aliases = defaultdict(list)
+        for r in rules:
+            if r.alias:
+                aliases[r.origin].append(r.alias)
+
+        rule_names = {r.origin for r in rules}
+        nonterminals = {sym for sym in rule_names
+                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}
+
+        seen = set()
+        for r in rules:
+            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
+                          for sym in r.expansion if not is_discarded_terminal(sym)]
+
+            # Skip self-recursive constructs
+            if recons_exp == [r.origin] and r.alias is None:
+                continue
+
+            sym = NonTerminal(r.alias) if r.alias else r.origin
+            rule = make_recons_rule(sym, recons_exp, r.expansion)
+
+            if sym in expand1s and len(recons_exp) != 1:
+                self.rules_for_root[sym.name].append(rule)
+
+                if sym.name not in seen:
+                    yield make_recons_rule_to_term(sym, sym)
+                    seen.add(sym.name)
+            else:
+                if sym.name.startswith('_') or sym in expand1s:
+                    yield rule
+                else:
+                    self.rules_for_root[sym.name].append(rule)
+
+        for origin, rule_aliases in aliases.items():
+            for alias in rule_aliases:
+                yield make_recons_rule_to_term(origin, NonTerminal(alias))
+            yield make_recons_rule_to_term(origin, origin)
+
+    def match_tree(self, tree, rulename):
+        """Match the elements of `tree` to the symbols of rule `rulename`.
+
+        Parameters:
+            tree (Tree): the tree node to match
+            rulename (str): The expected full rule name (including template args)
+
+        Returns:
+            Tree: an unreduced tree that matches `rulename`
+
+        Raises:
+            UnexpectedToken: If no match was found.
+
+        Note:
+            It's the caller's responsibility to match the tree recursively.
+        """
+        if rulename:
+            # validate
+            name, _args = parse_rulename(rulename)
+            assert tree.data == name
+        else:
+            rulename = tree.data
+
+        # TODO: ambiguity?
+        try:
+            parser = self._parser_cache[rulename]
+        except KeyError:
+            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])
+
+            # TODO pass callbacks through dict, instead of alias?
+            callbacks = {rule: rule.alias for rule in rules}
+            conf = ParserConf(rules, callbacks, [rulename])
+            parser = earley.Parser(conf, _match, resolve_ambiguity=True)
+            self._parser_cache[rulename] = parser
+
+        # find a full derivation
+        unreduced_tree = parser.parse(ChildrenLexer(tree.children), rulename)
+        assert unreduced_tree.data == rulename
+        return unreduced_tree
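Usage sketch (illustrative only; the grammar, input string, and variable names
below are assumptions, not part of the diff). A TreeMatcher is built from a
Lark instance created with maybe_placeholders=False, and match_tree()
re-derives how a node's children map onto a rule's expansion:

    from lark import Lark
    from lark.tree_matcher import TreeMatcher

    # Hypothetical toy grammar; TreeMatcher asserts maybe_placeholders == False.
    grammar = r"""
        start: pair+
        pair: WORD "=" WORD
        %import common.WORD
        %import common.WS
        %ignore WS
    """
    parser = Lark(grammar, maybe_placeholders=False)
    tree = parser.parse("a=b c=d")

    matcher = TreeMatcher(parser)
    # Match the root node's children against the 'start' rule.
    unreduced = matcher.match_tree(tree, "start")
    assert unreduced.data == "start"
    assert unreduced.meta.match_tree   # set by the _MakeTreeMatch callback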