aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/lark/parsers/cyk.py
blob: ff0924f2415dc5e93af4501de66246d2f5958a78 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""This module implements a CYK parser."""

# Author: https://github.com/ehudt (2018)
#
# Adapted by Erez


from collections import defaultdict
import itertools

from ..exceptions import ParseError
from ..lexer import Token
from ..tree import Tree
from ..grammar import Terminal as T, NonTerminal as NT, Symbol

# Python 2/3 compatibility shim: when the name `xrange` is not defined
# (looking it up raises NameError), alias it to `range` so the loops below
# can use `xrange` unconditionally.
try:
    xrange
except NameError:
    xrange = range

def match(t, s):
    """Return whether the name of terminal 't' equals the ``type`` of 's'."""
    assert isinstance(t, T)
    terminal_name = t.name
    return terminal_name == s.type


class Rule(object):
    """A single context-free production ``lhs -> rhs``.

    Equality and hashing look only at ``lhs`` and ``rhs``; ``weight`` and
    ``alias`` are stored on the rule but do not take part in comparisons.
    """

    def __init__(self, lhs, rhs, weight, alias):
        super(Rule, self).__init__()
        assert isinstance(lhs, NT), lhs
        assert all(isinstance(sym, (NT, T)) for sym in rhs), rhs
        self.lhs, self.rhs = lhs, rhs
        self.weight, self.alias = weight, alias

    def __str__(self):
        head = str(self.lhs)
        body = ' '.join(map(str, self.rhs))
        return '%s -> %s' % (head, body)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        # rhs may be a list, so freeze it into a tuple to make it hashable.
        key = (self.lhs, tuple(self.rhs))
        return hash(key)

    def __eq__(self, other):
        same_lhs = self.lhs == other.lhs
        return same_lhs and self.rhs == other.rhs

    def __ne__(self, other):
        return not self == other


class Grammar(object):
    """A context-free grammar, held as a frozenset of rules."""

    def __init__(self, rules):
        self.rules = frozenset(rules)

    def __eq__(self, other):
        return self.rules == other.rules

    def __str__(self):
        # Sort the rule reprs so the rendering does not depend on set order.
        lines = sorted(map(repr, self.rules))
        return '\n%s\n' % '\n'.join(lines)

    def __repr__(self):
        return str(self)


# Parse tree data structures
class RuleNode(object):
    """Parse-tree node holding the rule that produced it, its children and a weight."""

    def __init__(self, rule, children, weight=0):
        self.rule, self.children, self.weight = rule, children, weight

    def __repr__(self):
        head = repr(self.rule.lhs)
        rendered_children = ', '.join(map(str, self.children))
        return 'RuleNode(%s, [%s])' % (head, rendered_children)



class Parser(object):
    """Front end that converts lark rules to a CNF grammar and runs the CYK parse."""

    def __init__(self, rules):
        super(Parser, self).__init__()
        # Identity map, used by _to_tree to look up the original rule by alias.
        self.orig_rules = dict((rule, rule) for rule in rules)
        converted = list(map(self._to_rule, rules))
        self.grammar = to_cnf(Grammar(converted))

    def _to_rule(self, lark_rule):
        """Builds a Rule from a lark rule's origin, expansion and priority.

        A falsy priority (such as None) becomes weight 0. The lark rule itself
        is stored as the alias.
        """
        assert isinstance(lark_rule.origin, NT)
        assert all(isinstance(x, Symbol) for x in lark_rule.expansion)
        priority = lark_rule.options.priority
        return Rule(
            lark_rule.origin,
            lark_rule.expansion,
            weight=priority or 0,
            alias=lark_rule,
        )

    def parse(self, tokenized, start):  # pylint: disable=invalid-name
        """Parses 'tokenized' (a sized sequence of tokens) from rule name 'start'.

        Raises ParseError when no rule for 'start' covers the whole input.
        """
        assert start
        start = NT(start)

        table, trees = _parse(tokenized, self.grammar)
        whole_input = (0, len(tokenized) - 1)
        # The parse succeeded only if 'start' derives the full span.
        if not any(r.lhs == start for r in table[whole_input]):
            raise ParseError('Parsing failed.')
        best = trees[whole_input][start]
        return self._to_tree(revert_cnf(best))

    def _to_tree(self, rule_node):
        """Converts a RuleNode parse tree to a lark Tree."""
        orig_rule = self.orig_rules[rule_node.rule.alias]
        children = []
        for child in rule_node.children:
            if not isinstance(child, RuleNode):
                # Leaf: the token is carried in the leaf's ``name`` attribute.
                assert isinstance(child.name, Token)
                children.append(child.name)
                continue
            children.append(self._to_tree(child))
        tree = Tree(orig_rule.origin, children)
        tree.rule = orig_rule
        return tree


def print_parse(node, indent=0):
    """Print a parse tree, one node per line, indented two spaces per level.

    Args:
        node: A RuleNode, or a leaf. Leaves are the terminal wrappers that
            `_parse` builds as ``T(token)``, so the token is in ``.name``.
        indent: Nesting depth of 'node'; each level adds two spaces.
    """
    prefix = ' ' * (indent * 2)
    if isinstance(node, RuleNode):
        print(prefix + str(node.rule.lhs))
        for child in node.children:
            print_parse(child, indent + 1)
    else:
        # Leaves have no `.s` attribute; the wrapped token is stored under
        # `.name` (the same attribute Parser._to_tree reads).
        print(prefix + str(node.name))


def _parse(s, g):
    """Runs the CYK algorithm over sentence 's' using CNF grammar 'g'.

    Returns a ``(table, trees)`` pair, both keyed by an inclusive
    ``(start pos, end pos)`` span:
      * ``table[span]`` is the set of rules found to derive that span.
      * ``trees[span]`` maps a rule's lhs to the lightest RuleNode found for it.
    """
    table = defaultdict(set)
    trees = defaultdict(dict)

    # Base case: spans of length one, filled from the terminal rules.
    for pos, token in enumerate(s):
        cell = (pos, pos)
        for terminal, rules in g.terminal_rules.items():
            if not match(terminal, token):
                continue
            for rule in rules:
                table[cell].add(rule)
                best = trees[cell].get(rule.lhs)
                if best is None or rule.weight < best.weight:
                    trees[cell][rule.lhs] = RuleNode(rule, [T(token)], weight=rule.weight)

    total_len = len(s)
    # Longer spans, shortest first, so both halves are already filled in.
    for length in xrange(2, total_len + 1):
        for begin in xrange(total_len - length + 1):
            end = begin + length - 1
            whole = (begin, end)
            # Try every split point inside the span.
            for split in xrange(begin + 1, begin + length):
                left = (begin, split - 1)
                right = (split, end)
                for r1, r2 in itertools.product(table[left], table[right]):
                    for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
                        table[whole].add(rule)
                        left_tree = trees[left][r1.lhs]
                        right_tree = trees[right][r2.lhs]
                        total_weight = rule.weight + left_tree.weight + right_tree.weight
                        current = trees[whole].get(rule.lhs)
                        # Strict '<' keeps the first tree found on ties.
                        if current is None or total_weight < current.weight:
                            trees[whole][rule.lhs] = RuleNode(
                                rule, [left_tree, right_tree], weight=total_weight)
    return table, trees


# This section implements context-free grammar converter to Chomsky normal form.
# It also implements a conversion of parse trees from its CNF to the original
# grammar.
# Overview:
# Applies the following operations in this order:
# * TERM: Eliminates non-solitary terminals from all rules
# * BIN: Eliminates rules with more than 2 symbols on their right-hand-side.
# * UNIT: Eliminates non-terminal unit rules
#
# The following grammar characteristics are not handled by this conversion:
# * Start symbol appears on RHS
# * Empty rules (epsilon rules)


class CnfWrapper(object):
    """CNF wrapper for a grammar.

    Checks that every rule is in CNF and builds two lookup tables:
    ``terminal_rules`` (terminal -> rules producing it) and
    ``nonterminal_rules`` (pair of non-terminals -> rules producing the pair).
    """

    def __init__(self, grammar):
        super(CnfWrapper, self).__init__()
        self.grammar = grammar
        self.rules = grammar.rules
        self.terminal_rules = defaultdict(list)
        self.nonterminal_rules = defaultdict(list)
        for r in self.rules:
            # Each rule must be either 'NT -> T' or 'NT -> NT NT'.
            assert isinstance(r.lhs, NT), r
            rhs_len = len(r.rhs)
            if rhs_len != 1 and rhs_len != 2:
                raise ParseError("CYK doesn't support empty rules")
            if rhs_len == 1 and isinstance(r.rhs[0], T):
                self.terminal_rules[r.rhs[0]].append(r)
            elif rhs_len == 2 and all(isinstance(sym, NT) for sym in r.rhs):
                self.nonterminal_rules[tuple(r.rhs)].append(r)
            else:
                assert False, r

    def __eq__(self, other):
        return self.grammar == other.grammar

    def __repr__(self):
        return repr(self.grammar)


class UnitSkipRule(Rule):
    """A rule that records the rules that were skipped during UNIT elimination.

    ``skipped_rules`` is the chain of rules bypassed when a unit rule was
    replaced; `revert_cnf` reads it to rebuild the original nesting.
    """

    def __init__(self, lhs, rhs, skipped_rules, weight, alias):
        super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
        self.skipped_rules = skipped_rules

    def __eq__(self, other):
        # Compare lhs/rhs as well as the skipped chain. Comparing only
        # 'skipped_rules' made rules with different left-hand sides equal
        # while their hashes (based on lhs/rhs) differed, and let the
        # 'x != rule' filter in _remove_unit_rule drop unrelated rules.
        return (isinstance(other, type(self))
                and Rule.__eq__(self, other)
                and self.skipped_rules == other.skipped_rules)

    # Defining __eq__ clears the inherited __hash__, so restore it explicitly.
    __hash__ = Rule.__hash__


def build_unit_skiprule(unit_rule, target_rule):
    """Builds the UnitSkipRule that replaces 'unit_rule' followed by 'target_rule'.

    The result keeps the lhs and alias of 'unit_rule', takes the rhs of
    'target_rule', and sums both weights. Its skipped chain is: whatever
    'unit_rule' already skipped, then 'target_rule', then whatever
    'target_rule' already skipped.
    """
    chain = []
    if isinstance(unit_rule, UnitSkipRule):
        chain.extend(unit_rule.skipped_rules)
    chain.append(target_rule)
    if isinstance(target_rule, UnitSkipRule):
        chain.extend(target_rule.skipped_rules)
    combined_weight = unit_rule.weight + target_rule.weight
    return UnitSkipRule(unit_rule.lhs, target_rule.rhs, chain,
                        weight=combined_weight, alias=unit_rule.alias)


def get_any_nt_unit_rule(g):
    """Returns a rule of 'g' whose rhs is a single non-terminal, or None if there is none."""
    unit_rules = (
        rule for rule in g.rules
        if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT)
    )
    return next(unit_rules, None)


def _remove_unit_rule(g, rule):
    """Removes 'rule' from 'g' without changing the language produced by 'g'.

    Every rule whose lhs is the non-terminal that 'rule' points to gets a
    skip-rule copy starting from the lhs of 'rule'.
    """
    kept = [candidate for candidate in g.rules if candidate != rule]
    bypasses = [
        build_unit_skiprule(rule, ref)
        for ref in g.rules
        if ref.lhs == rule.rhs[0]
    ]
    return Grammar(kept + bypasses)


def _split(rule):
    """Yields binary rules that together replace a rule whose len(rhs) > 2.

    The first yielded rule keeps the original lhs, weight and alias. The
    remaining ones use generated '__SP_..._<n>' non-terminals, weight 0 and
    alias 'Split'.
    """
    symbols = rule.rhs
    rule_str = str(rule.lhs) + '__' + '_'.join(map(str, symbols))
    # Template for helper names; the trailing '%d' is filled with the step index.
    name_template = '__SP_%s' % (rule_str) + '_%d'
    last_step = len(symbols) - 2

    yield Rule(rule.lhs, [symbols[0], NT(name_template % 1)],
               weight=rule.weight, alias=rule.alias)
    for step in xrange(1, last_step):
        yield Rule(NT(name_template % step),
                   [symbols[step], NT(name_template % (step + 1))],
                   weight=0, alias='Split')
    yield Rule(NT(name_template % last_step), symbols[-2:], weight=0, alias='Split')


def _term(g):
    """Applies the TERM rule on 'g' (see top comment).

    In every rule with more than one rhs symbol, each terminal is replaced by
    a '__T_...' non-terminal, and the rule producing that terminal is added.
    """
    terminals = set()
    for rule in g.rules:
        for sym in rule.rhs:
            if isinstance(sym, T):
                terminals.add(sym)

    t_rules = {}
    for t in terminals:
        t_rules[t] = Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term')

    new_rules = []
    for rule in g.rules:
        mixes_terminals = len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs)
        if not mixes_terminals:
            new_rules.append(rule)
            continue
        new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs]
        new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
        # Add the helper rules for the terminals this rule referenced.
        for terminal, helper in t_rules.items():
            if terminal in rule.rhs:
                new_rules.append(helper)
    return Grammar(new_rules)


def _bin(g):
    """Applies the BIN rule to 'g' (see top comment).

    Rules with more than two rhs symbols are replaced by the output of
    `_split`; all other rules are kept as they are.
    """
    new_rules = []
    for rule in g.rules:
        if len(rule.rhs) <= 2:
            new_rules.append(rule)
        else:
            new_rules.extend(_split(rule))
    return Grammar(new_rules)


def _unit(g):
    """Applies the UNIT rule to 'g' (see top comment).

    Removes non-terminal unit rules one at a time until none is left.
    """
    while True:
        unit_rule = get_any_nt_unit_rule(g)
        if not unit_rule:
            return g
        g = _remove_unit_rule(g, unit_rule)


def to_cnf(g):
    """Creates a CNF grammar from a general context-free grammar 'g'.

    Applies TERM, then BIN, then UNIT, and wraps the result in a CnfWrapper.
    """
    without_mixed_terminals = _term(g)
    binarized = _bin(without_mixed_terminals)
    without_unit_rules = _unit(binarized)
    return CnfWrapper(without_unit_rules)


def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
    """Rebuilds the chain of unit-rule nodes recorded in 'skipped_rules'.

    With no skipped rules left, returns a RuleNode for ``lhs -> orig_rhs``
    holding 'children'. Otherwise returns a RuleNode for a unit rule from
    'lhs' to the first skipped rule's lhs, whose only child is the result of
    unrolling the rest of the chain.
    """
    if not skipped_rules:
        return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)

    first, rest = skipped_rules[0], skipped_rules[1:]
    # This node's weight excludes the weight of the first skipped rule.
    weight = weight - first.weight
    unit_rule = Rule(lhs, [first.lhs], weight=weight, alias=alias)
    inner = unroll_unit_skiprule(first.lhs, orig_rhs, rest, children,
                                 first.weight, first.alias)
    return RuleNode(unit_rule, [inner], weight=weight)


def revert_cnf(node):
    """Reverts a parse tree (RuleNode) to its original non-CNF form.

    Terminal leaves are returned unchanged. Nodes produced by the TERM, BIN
    and UNIT transformations are undone in that order of checks.
    """
    if isinstance(node, T):
        return node

    rule = node.rule
    # Reverts TERM rule: a '__T_' helper node collapses to its single child.
    if rule.lhs.name.startswith('__T_'):
        return node.children[0]

    children = []
    for child in map(revert_cnf, node.children):
        # Reverts BIN rule: splice the children of '__SP_' helper nodes in place.
        is_split_helper = (isinstance(child, RuleNode)
                           and child.rule.lhs.name.startswith('__SP_'))
        if is_split_helper:
            children.extend(child.children)
        else:
            children.append(child)

    # Reverts UNIT rule.
    if not isinstance(rule, UnitSkipRule):
        return RuleNode(rule, children)
    return unroll_unit_skiprule(rule.lhs, rule.rhs,
                                rule.skipped_rules, children,
                                rule.weight, rule.alias)