about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/lark/tree_matcher.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/lark/tree_matcher.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/lark/tree_matcher.py')
-rw-r--r--.venv/lib/python3.12/site-packages/lark/tree_matcher.py186
1 files changed, 186 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/lark/tree_matcher.py b/.venv/lib/python3.12/site-packages/lark/tree_matcher.py
new file mode 100644
index 00000000..02f95885
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/lark/tree_matcher.py
@@ -0,0 +1,186 @@
+"""Tree matcher based on Lark grammar"""
+
+import re
+from collections import defaultdict
+
+from . import Tree, Token
+from .common import ParserConf
+from .parsers import earley
+from .grammar import Rule, Terminal, NonTerminal
+
+
+def is_discarded_terminal(t):
+    return t.is_term and t.filter_out
+
+
+class _MakeTreeMatch:
+    def __init__(self, name, expansion):
+        self.name = name
+        self.expansion = expansion
+
+    def __call__(self, args):
+        t = Tree(self.name, args)
+        t.meta.match_tree = True
+        t.meta.orig_expansion = self.expansion
+        return t
+
+
+def _best_from_group(seq, group_key, cmp_key):
+    d = {}
+    for item in seq:
+        key = group_key(item)
+        if key in d:
+            v1 = cmp_key(item)
+            v2 = cmp_key(d[key])
+            if v2 > v1:
+                d[key] = item
+        else:
+            d[key] = item
+    return list(d.values())
+
+
+def _best_rules_from_group(rules):
+    rules = _best_from_group(rules, lambda r: r, lambda r: -len(r.expansion))
+    rules.sort(key=lambda r: len(r.expansion))
+    return rules
+
+
+def _match(term, token):
+    if isinstance(token, Tree):
+        name, _args = parse_rulename(term.name)
+        return token.data == name
+    elif isinstance(token, Token):
+        return term == Terminal(token.type)
+    assert False, (term, token)
+
+
+def make_recons_rule(origin, expansion, old_expansion):
+    return Rule(origin, expansion, alias=_MakeTreeMatch(origin.name, old_expansion))
+
+
+def make_recons_rule_to_term(origin, term):
+    return make_recons_rule(origin, [Terminal(term.name)], [term])
+
+
+def parse_rulename(s):
+    "Parse rule names that may contain a template syntax (like rule{a, b, ...})"
+    name, args_str = re.match(r'(\w+)(?:{(.+)})?', s).groups()
+    args = args_str and [a.strip() for a in args_str.split(',')]
+    return name, args
+
+
+
+class ChildrenLexer:
+    def __init__(self, children):
+        self.children = children
+
+    def lex(self, parser_state):
+        return self.children
+
+class TreeMatcher:
+    """Match the elements of a tree node, based on an ontology
+    provided by a Lark grammar.
+
+    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)
+
+    Initiialize with an instance of Lark.
+    """
+
+    def __init__(self, parser):
+        # XXX TODO calling compile twice returns different results!
+        assert parser.options.maybe_placeholders == False
+        # XXX TODO: we just ignore the potential existence of a postlexer
+        self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())
+
+        self.rules_for_root = defaultdict(list)
+
+        self.rules = list(self._build_recons_rules(rules))
+        self.rules.reverse()
+
+        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
+        self.rules = _best_rules_from_group(self.rules)
+
+        self.parser = parser
+        self._parser_cache = {}
+
+    def _build_recons_rules(self, rules):
+        "Convert tree-parsing/construction rules to tree-matching rules"
+        expand1s = {r.origin for r in rules if r.options.expand1}
+
+        aliases = defaultdict(list)
+        for r in rules:
+            if r.alias:
+                aliases[r.origin].append(r.alias)
+
+        rule_names = {r.origin for r in rules}
+        nonterminals = {sym for sym in rule_names
+                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}
+
+        seen = set()
+        for r in rules:
+            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
+                          for sym in r.expansion if not is_discarded_terminal(sym)]
+
+            # Skip self-recursive constructs
+            if recons_exp == [r.origin] and r.alias is None:
+                continue
+
+            sym = NonTerminal(r.alias) if r.alias else r.origin
+            rule = make_recons_rule(sym, recons_exp, r.expansion)
+
+            if sym in expand1s and len(recons_exp) != 1:
+                self.rules_for_root[sym.name].append(rule)
+
+                if sym.name not in seen:
+                    yield make_recons_rule_to_term(sym, sym)
+                    seen.add(sym.name)
+            else:
+                if sym.name.startswith('_') or sym in expand1s:
+                    yield rule
+                else:
+                    self.rules_for_root[sym.name].append(rule)
+
+        for origin, rule_aliases in aliases.items():
+            for alias in rule_aliases:
+                yield make_recons_rule_to_term(origin, NonTerminal(alias))
+            yield make_recons_rule_to_term(origin, origin)
+
+    def match_tree(self, tree, rulename):
+        """Match the elements of `tree` to the symbols of rule `rulename`.
+
+        Parameters:
+            tree (Tree): the tree node to match
+            rulename (str): The expected full rule name (including template args)
+
+        Returns:
+            Tree: an unreduced tree that matches `rulename`
+
+        Raises:
+            UnexpectedToken: If no match was found.
+
+        Note:
+            It's the callers' responsibility match the tree recursively.
+        """
+        if rulename:
+            # validate
+            name, _args = parse_rulename(rulename)
+            assert tree.data == name
+        else:
+            rulename = tree.data
+
+        # TODO: ambiguity?
+        try:
+            parser = self._parser_cache[rulename]
+        except KeyError:
+            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])
+
+            # TODO pass callbacks through dict, instead of alias?
+            callbacks = {rule: rule.alias for rule in rules}
+            conf = ParserConf(rules, callbacks, [rulename])
+            parser = earley.Parser(conf, _match, resolve_ambiguity=True)
+            self._parser_cache[rulename] = parser
+
+        # find a full derivation
+        unreduced_tree = parser.parse(ChildrenLexer(tree.children), rulename)
+        assert unreduced_tree.data == rulename
+        return unreduced_tree