diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/lark/tree_matcher.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/lark/tree_matcher.py | 186 |
1 file changed, 186 insertions, 0 deletions
"""Tree matcher based on Lark grammar"""

import re
from collections import defaultdict

from . import Tree, Token
from .common import ParserConf
from .parsers import earley
from .grammar import Rule, Terminal, NonTerminal


def is_discarded_terminal(t):
    """Return True if symbol `t` is a terminal that is filtered out of the tree
    (i.e. it appears in the grammar but produces no child node)."""
    return t.is_term and t.filter_out


class _MakeTreeMatch:
    """Rule callback that builds a match-tree node.

    Instances are used as rule aliases; calling one with the matched children
    produces a Tree tagged with the original expansion for later inspection.
    """

    def __init__(self, name, expansion):
        self.name = name
        self.expansion = expansion

    def __call__(self, args):
        t = Tree(self.name, args)
        # Mark this node as synthesized by the matcher, and remember the
        # grammar expansion it was derived from.
        t.meta.match_tree = True
        t.meta.orig_expansion = self.expansion
        return t


def _best_from_group(seq, group_key, cmp_key):
    """Group `seq` by `group_key` and keep, per group, the item with the
    smallest `cmp_key` value. Returns the surviving items as a list."""
    d = {}
    for item in seq:
        key = group_key(item)
        if key in d:
            v1 = cmp_key(item)
            v2 = cmp_key(d[key])
            if v2 > v1:
                d[key] = item
        else:
            d[key] = item
    return list(d.values())


def _best_rules_from_group(rules):
    """Deduplicate `rules`, preferring the variant with the longest expansion,
    then return them sorted by ascending expansion length."""
    # cmp_key is negated length, so the *longest* expansion wins per group.
    rules = _best_from_group(rules, lambda r: r, lambda r: -len(r.expansion))
    rules.sort(key=lambda r: len(r.expansion))
    return rules


def _match(term, token):
    """Terminal-matching predicate for the Earley parser: match a grammar
    symbol `term` against either a subtree or a token from the input."""
    if isinstance(token, Tree):
        # Compare against the bare rule name, ignoring any template args.
        name, _args = parse_rulename(term.name)
        return token.data == name
    elif isinstance(token, Token):
        return term == Terminal(token.type)
    assert False, (term, token)


def make_recons_rule(origin, expansion, old_expansion):
    """Create a reconstruction rule whose callback records `old_expansion`."""
    return Rule(origin, expansion, alias=_MakeTreeMatch(origin.name, old_expansion))


def make_recons_rule_to_term(origin, term):
    """Create a reconstruction rule mapping `origin` to the single symbol `term`."""
    return make_recons_rule(origin, [Terminal(term.name)], [term])


def parse_rulename(s):
    "Parse rule names that may contain a template syntax (like rule{a, b, ...})"
    name, args_str = re.match(r'(\w+)(?:{(.+)})?', s).groups()
    # args is None/'' when there is no template part (short-circuit on args_str).
    args = args_str and [a.strip() for a in args_str.split(',')]

    return name, args


class ChildrenLexer:
    """Lexer stand-in that feeds a tree node's children to the parser
    instead of tokens from text."""

    def __init__(self, children):
        self.children = children

    def lex(self, parser_state):
        return self.children


class TreeMatcher:
    """Match the elements of a tree node, based on an ontology
    provided by a Lark grammar.

    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)

    Initialize with an instance of Lark.
    """

    def __init__(self, parser):
        # XXX TODO calling compile twice returns different results!
        assert parser.options.maybe_placeholders == False
        # XXX TODO: we just ignore the potential existence of a postlexer
        self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())

        # Maps rule name -> extra rules only valid at the root of a match.
        self.rules_for_root = defaultdict(list)

        self.rules = list(self._build_recons_rules(rules))
        self.rules.reverse()

        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
        self.rules = _best_rules_from_group(self.rules)

        self.parser = parser
        self._parser_cache = {}

    def _build_recons_rules(self, rules):
        "Convert tree-parsing/construction rules to tree-matching rules"
        expand1s = {r.origin for r in rules if r.options.expand1}

        aliases = defaultdict(list)
        for r in rules:
            if r.alias:
                aliases[r.origin].append(r.alias)

        rule_names = {r.origin for r in rules}
        # Symbols that remain nonterminals in the matching grammar: inlined
        # rules (leading '_'), expand1 rules, and aliased rules.
        nonterminals = {sym for sym in rule_names
                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}

        seen = set()
        for r in rules:
            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
                          for sym in r.expansion if not is_discarded_terminal(sym)]

            # Skip self-recursive constructs
            if recons_exp == [r.origin] and r.alias is None:
                continue

            sym = NonTerminal(r.alias) if r.alias else r.origin
            rule = make_recons_rule(sym, recons_exp, r.expansion)

            if sym in expand1s and len(recons_exp) != 1:
                self.rules_for_root[sym.name].append(rule)

                if sym.name not in seen:
                    yield make_recons_rule_to_term(sym, sym)
                    seen.add(sym.name)
            else:
                if sym.name.startswith('_') or sym in expand1s:
                    yield rule
                else:
                    self.rules_for_root[sym.name].append(rule)

        for origin, rule_aliases in aliases.items():
            for alias in rule_aliases:
                yield make_recons_rule_to_term(origin, NonTerminal(alias))
            yield make_recons_rule_to_term(origin, origin)

    def match_tree(self, tree, rulename):
        """Match the elements of `tree` to the symbols of rule `rulename`.

        Parameters:
            tree (Tree): the tree node to match
            rulename (str): The expected full rule name (including template args)

        Returns:
            Tree: an unreduced tree that matches `rulename`

        Raises:
            UnexpectedToken: If no match was found.

        Note:
            It's the callers' responsibility to match the tree recursively.
        """
        if rulename:
            # validate
            name, _args = parse_rulename(rulename)
            assert tree.data == name
        else:
            rulename = tree.data

        # TODO: ambiguity?
        try:
            parser = self._parser_cache[rulename]
        except KeyError:
            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])

            # TODO pass callbacks through dict, instead of alias?
            callbacks = {rule: rule.alias for rule in rules}
            conf = ParserConf(rules, callbacks, [rulename])
            parser = earley.Parser(conf, _match, resolve_ambiguity=True)
            self._parser_cache[rulename] = parser

        # find a full derivation
        unreduced_tree = parser.parse(ChildrenLexer(tree.children), rulename)
        assert unreduced_tree.data == rulename
        return unreduced_tree