Diffstat (limited to '.venv/lib/python3.12/site-packages/lark/parsers/cyk.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/lark/parsers/cyk.py | 345 |
1 file changed, 345 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/lark/parsers/cyk.py b/.venv/lib/python3.12/site-packages/lark/parsers/cyk.py
new file mode 100644
index 00000000..ff0924f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/lark/parsers/cyk.py
@@ -0,0 +1,345 @@
+"""This module implements a CYK parser."""
+
+# Author: https://github.com/ehudt (2018)
+#
+# Adapted by Erez
+
+
+from collections import defaultdict
+import itertools
+
+from ..exceptions import ParseError
+from ..lexer import Token
+from ..tree import Tree
+from ..grammar import Terminal as T, NonTerminal as NT, Symbol
+
+try:
+    xrange
+except NameError:
+    xrange = range
+
+def match(t, s):
+    assert isinstance(t, T)
+    return t.name == s.type
+
+
+class Rule(object):
+    """Context-free grammar rule."""
+
+    def __init__(self, lhs, rhs, weight, alias):
+        super(Rule, self).__init__()
+        assert isinstance(lhs, NT), lhs
+        assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs
+        self.lhs = lhs
+        self.rhs = rhs
+        self.weight = weight
+        self.alias = alias
+
+    def __str__(self):
+        return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs))
+
+    def __repr__(self):
+        return str(self)
+
+    def __hash__(self):
+        return hash((self.lhs, tuple(self.rhs)))
+
+    def __eq__(self, other):
+        return self.lhs == other.lhs and self.rhs == other.rhs
+
+    def __ne__(self, other):
+        return not (self == other)
+
+
+class Grammar(object):
+    """Context-free grammar."""
+
+    def __init__(self, rules):
+        self.rules = frozenset(rules)
+
+    def __eq__(self, other):
+        return self.rules == other.rules
+
+    def __str__(self):
+        return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n'
+
+    def __repr__(self):
+        return str(self)
+
+
+# Parse tree data structures
+class RuleNode(object):
+    """A node in the parse tree, which also contains the full rhs rule."""
+
+    def __init__(self, rule, children, weight=0):
+        self.rule = rule
+        self.children = children
+        self.weight = weight
+
+    def __repr__(self):
+        return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), ', '.join(str(x) for x in self.children))
+
+
+class Parser(object):
+    """Parser wrapper."""
+
+    def __init__(self, rules):
+        super(Parser, self).__init__()
+        self.orig_rules = {rule: rule for rule in rules}
+        rules = [self._to_rule(rule) for rule in rules]
+        self.grammar = to_cnf(Grammar(rules))
+
+    def _to_rule(self, lark_rule):
+        """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
+        assert isinstance(lark_rule.origin, NT)
+        assert all(isinstance(x, Symbol) for x in lark_rule.expansion)
+        return Rule(
+            lark_rule.origin, lark_rule.expansion,
+            weight=lark_rule.options.priority if lark_rule.options.priority else 0,
+            alias=lark_rule)
+
+    def parse(self, tokenized, start):  # pylint: disable=invalid-name
+        """Parses input, which is a list of tokens."""
+        assert start
+        start = NT(start)
+
+        table, trees = _parse(tokenized, self.grammar)
+        # Check if the parse succeeded.
+        if all(r.lhs != start for r in table[(0, len(tokenized) - 1)]):
+            raise ParseError('Parsing failed.')
+        parse = trees[(0, len(tokenized) - 1)][start]
+        return self._to_tree(revert_cnf(parse))
+
+    def _to_tree(self, rule_node):
+        """Converts a RuleNode parse tree to a lark Tree."""
+        orig_rule = self.orig_rules[rule_node.rule.alias]
+        children = []
+        for child in rule_node.children:
+            if isinstance(child, RuleNode):
+                children.append(self._to_tree(child))
+            else:
+                assert isinstance(child.name, Token)
+                children.append(child.name)
+        t = Tree(orig_rule.origin, children)
+        t.rule = orig_rule
+        return t
+
+
+def print_parse(node, indent=0):
+    if isinstance(node, RuleNode):
+        print(' ' * (indent * 2) + str(node.rule.lhs))
+        for child in node.children:
+            print_parse(child, indent + 1)
+    else:
+        print(' ' * (indent * 2) + str(node.s))
+
+
+def _parse(s, g):
+    """Parses sentence 's' using CNF grammar 'g'."""
+    # The CYK table. Indexed with a 2-tuple: (start pos, end pos)
+    table = defaultdict(set)
+    # Top-level structure is similar to the CYK table. Each cell is a dict from
+    # rule name to the best (lightest) tree for that rule.
+    trees = defaultdict(dict)
+    # Populate base case with existing terminal production rules
+    for i, w in enumerate(s):
+        for terminal, rules in g.terminal_rules.items():
+            if match(terminal, w):
+                for rule in rules:
+                    table[(i, i)].add(rule)
+                    if (rule.lhs not in trees[(i, i)] or
+                            rule.weight < trees[(i, i)][rule.lhs].weight):
+                        trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight)
+
+    # Iterate over lengths of sub-sentences
+    for l in xrange(2, len(s) + 1):
+        # Iterate over sub-sentences with the given length
+        for i in xrange(len(s) - l + 1):
+            # Choose partition of the sub-sentence in [1, l)
+            for p in xrange(i + 1, i + l):
+                span1 = (i, p - 1)
+                span2 = (p, i + l - 1)
+                for r1, r2 in itertools.product(table[span1], table[span2]):
+                    for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
+                        table[(i, i + l - 1)].add(rule)
+                        r1_tree = trees[span1][r1.lhs]
+                        r2_tree = trees[span2][r2.lhs]
+                        rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight
+                        if (rule.lhs not in trees[(i, i + l - 1)]
+                                or rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight):
+                            trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight)
+    return table, trees
+
+
+# This section implements a context-free grammar converter to Chomsky normal
+# form. It also implements a conversion of parse trees from their CNF form
+# back to the original grammar.
+# Overview:
+# Applies the following operations in this order:
+#   * TERM: Eliminates non-solitary terminals from all rules
+#   * BIN: Eliminates rules with more than 2 symbols on their right-hand-side.
+#   * UNIT: Eliminates non-terminal unit rules
+#
+# The following grammar characteristics aren't featured:
+#   * Start symbol appears on RHS
+#   * Empty rules (epsilon rules)
+
+
+class CnfWrapper(object):
+    """CNF wrapper for grammar.
+
+    Validates that the input grammar is CNF and provides helper data structures.
+    """
+
+    def __init__(self, grammar):
+        super(CnfWrapper, self).__init__()
+        self.grammar = grammar
+        self.rules = grammar.rules
+        self.terminal_rules = defaultdict(list)
+        self.nonterminal_rules = defaultdict(list)
+        for r in self.rules:
+            # Validate that the grammar is CNF and populate auxiliary data structures.
+            assert isinstance(r.lhs, NT), r
+            if len(r.rhs) not in [1, 2]:
+                raise ParseError("CYK doesn't support empty rules")
+            if len(r.rhs) == 1 and isinstance(r.rhs[0], T):
+                self.terminal_rules[r.rhs[0]].append(r)
+            elif len(r.rhs) == 2 and all(isinstance(x, NT) for x in r.rhs):
+                self.nonterminal_rules[tuple(r.rhs)].append(r)
+            else:
+                assert False, r
+
+    def __eq__(self, other):
+        return self.grammar == other.grammar
+
+    def __repr__(self):
+        return repr(self.grammar)
+
+
+class UnitSkipRule(Rule):
+    """A rule that records NTs that were skipped during transformation."""
+
+    def __init__(self, lhs, rhs, skipped_rules, weight, alias):
+        super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
+        self.skipped_rules = skipped_rules
+
+    def __eq__(self, other):
+        return isinstance(other, type(self)) and self.skipped_rules == other.skipped_rules
+
+    __hash__ = Rule.__hash__
+
+
+def build_unit_skiprule(unit_rule, target_rule):
+    skipped_rules = []
+    if isinstance(unit_rule, UnitSkipRule):
+        skipped_rules += unit_rule.skipped_rules
+    skipped_rules.append(target_rule)
+    if isinstance(target_rule, UnitSkipRule):
+        skipped_rules += target_rule.skipped_rules
+    return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules,
+                        weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias)
+
+
+def get_any_nt_unit_rule(g):
+    """Returns a non-terminal unit rule from 'g', or None if there is none."""
+    for rule in g.rules:
+        if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT):
+            return rule
+    return None
+
+
+def _remove_unit_rule(g, rule):
+    """Removes 'rule' from 'g' without changing the language produced by 'g'."""
+    new_rules = [x for x in g.rules if x != rule]
+    refs = [x for x in g.rules if x.lhs == rule.rhs[0]]
+    new_rules += [build_unit_skiprule(rule, ref) for ref in refs]
+    return Grammar(new_rules)
+
+
+def _split(rule):
+    """Splits a rule whose len(rhs) > 2 into shorter rules."""
+    rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs)
+    rule_name = '__SP_%s' % (rule_str) + '_%d'
+    yield Rule(rule.lhs, [rule.rhs[0], NT(rule_name % 1)], weight=rule.weight, alias=rule.alias)
+    for i in xrange(1, len(rule.rhs) - 2):
+        yield Rule(NT(rule_name % i), [rule.rhs[i], NT(rule_name % (i + 1))], weight=0, alias='Split')
+    yield Rule(NT(rule_name % (len(rule.rhs) - 2)), rule.rhs[-2:], weight=0, alias='Split')
+
+
+def _term(g):
+    """Applies the TERM rule on 'g' (see top comment)."""
+    all_t = {x for rule in g.rules for x in rule.rhs if isinstance(x, T)}
+    t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in all_t}
+    new_rules = []
+    for rule in g.rules:
+        if len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs):
+            new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs]
+            new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
+            new_rules.extend(v for k, v in t_rules.items() if k in rule.rhs)
+        else:
+            new_rules.append(rule)
+    return Grammar(new_rules)
+
+
+def _bin(g):
+    """Applies the BIN rule to 'g' (see top comment)."""
+    new_rules = []
+    for rule in g.rules:
+        if len(rule.rhs) > 2:
+            new_rules += _split(rule)
+        else:
+            new_rules.append(rule)
+    return Grammar(new_rules)
+
+
+def _unit(g):
+    """Applies the UNIT rule to 'g' (see top comment)."""
+    nt_unit_rule = get_any_nt_unit_rule(g)
+    while nt_unit_rule:
+        g = _remove_unit_rule(g, nt_unit_rule)
+        nt_unit_rule = get_any_nt_unit_rule(g)
+    return g
+
+
+def to_cnf(g):
+    """Creates a CNF grammar from a general context-free grammar 'g'."""
+    g = _unit(_bin(_term(g)))
+    return CnfWrapper(g)
+
+
+def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
+    if not skipped_rules:
+        return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)
+    else:
+        weight = weight - skipped_rules[0].weight
+        return RuleNode(
+            Rule(lhs, [skipped_rules[0].lhs], weight=weight, alias=alias), [
+                unroll_unit_skiprule(skipped_rules[0].lhs, orig_rhs,
+                                     skipped_rules[1:], children,
+                                     skipped_rules[0].weight, skipped_rules[0].alias)
+            ], weight=weight)
+
+
+def revert_cnf(node):
+    """Reverts a parse tree (RuleNode) to its original non-CNF form (Node)."""
+    if isinstance(node, T):
+        return node
+    # Reverts TERM rule.
+    if node.rule.lhs.name.startswith('__T_'):
+        return node.children[0]
+    else:
+        children = []
+        for child in map(revert_cnf, node.children):
+            # Reverts BIN rule.
+            if isinstance(child, RuleNode) and child.rule.lhs.name.startswith('__SP_'):
+                children += child.children
+            else:
+                children.append(child)
+        # Reverts UNIT rule.
+        if isinstance(node.rule, UnitSkipRule):
+            return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs,
+                                        node.rule.skipped_rules, children,
+                                        node.rule.weight, node.rule.alias)
+        else:
+            return RuleNode(node.rule, children)
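
For context, this backend is not normally imported directly; it is reached through the public Lark constructor. Below is a minimal usage sketch (not part of the diff above), assuming the standard Lark API in which parser='cyk' routes parsing through this module; the grammar and the input string are made-up examples.

    from lark import Lark

    # Hypothetical toy grammar; parser='cyk' is assumed to select the CYK
    # backend added by this diff.
    cyk_parser = Lark(r"""
        start: NAME "=" NUMBER

        %import common.CNAME -> NAME
        %import common.INT   -> NUMBER
        %ignore " "
    """, parser='cyk')

    print(cyk_parser.parse("answer = 42").pretty())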
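The TERM/BIN/UNIT pipeline can also be exercised on its own. The sketch below hand-builds a toy grammar from the Rule, Grammar and to_cnf names defined in the file above; the rule s -> A B C and the use of the internal import path lark.parsers.cyk are assumptions for illustration only, not part of the diff.

    from lark.grammar import Terminal as T, NonTerminal as NT
    from lark.parsers.cyk import Rule, Grammar, to_cnf

    # Hypothetical hand-built grammar with a single 3-symbol rule: s -> A B C.
    toy = Grammar([Rule(NT('s'), [T('A'), T('B'), T('C')], weight=0, alias=None)])

    # TERM wraps each terminal in a __T_* rule, BIN splits the long RHS into
    # binary __SP_* rules, UNIT removes unit chains; the result is a CnfWrapper
    # whose repr lists the binarized rules.
    print(to_cnf(toy))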