You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
341 lines
12 KiB
Python
341 lines
12 KiB
Python
"""This module implements a CYK parser."""
|
|
|
|
# Author: https://github.com/ehudt (2018)
|
|
#
|
|
# Adapted by Erez
|
|
|
|
|
|
from collections import defaultdict
|
|
import itertools
|
|
|
|
from ..exceptions import ParseError
|
|
from ..lexer import Token
|
|
from ..tree import Tree
|
|
from ..grammar import Terminal as T, NonTerminal as NT, Symbol
|
|
|
|
def match(t, s):
    """Return True when terminal symbol 't' matches token 's' by type name."""
    assert isinstance(t, T)
    return s.type == t.name
|
|
|
|
|
|
class Rule:
    """A weighted context-free production: lhs -> rhs."""

    def __init__(self, lhs, rhs, weight, alias):
        super(Rule, self).__init__()
        assert isinstance(lhs, NT), lhs
        assert all(isinstance(sym, (NT, T)) for sym in rhs), rhs
        self.lhs = lhs
        self.rhs = rhs
        self.weight = weight
        self.alias = alias

    def __str__(self):
        rhs_text = ' '.join(str(sym) for sym in self.rhs)
        return '%s -> %s' % (str(self.lhs), rhs_text)

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        # Hash agrees with __eq__: identity is (lhs, rhs) only, not weight/alias.
        return hash((self.lhs, tuple(self.rhs)))

    def __eq__(self, other):
        return (self.lhs, self.rhs) == (other.lhs, other.rhs)

    def __ne__(self, other):
        return not self.__eq__(other)
|
|
|
|
|
|
class Grammar:
    """A context-free grammar: an immutable set of rules."""

    def __init__(self, rules):
        # frozenset deduplicates and makes the grammar hash/compare cheaply.
        self.rules = frozenset(rules)

    def __eq__(self, other):
        return self.rules == other.rules

    def __str__(self):
        lines = sorted(repr(rule) for rule in self.rules)
        return '\n%s\n' % '\n'.join(lines)

    def __repr__(self):
        return self.__str__()
|
|
|
|
|
|
# Parse tree data structures
|
|
class RuleNode:
    """A node in the parse tree, which also contains the full rhs rule."""

    def __init__(self, rule, children, weight=0):
        self.rule = rule
        self.children = children
        # Accumulated weight of this subtree; 0 unless the caller tracks weights.
        self.weight = weight

    def __repr__(self):
        kids = ', '.join(str(child) for child in self.children)
        return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), kids)
|
|
|
|
|
|
|
|
class Parser:
    """Parser wrapper: converts lark rules to CNF and drives the CYK parse."""

    def __init__(self, rules):
        super(Parser, self).__init__()
        # Keep the original lark rules so parse trees can be mapped back.
        self.orig_rules = {rule: rule for rule in rules}
        self.grammar = to_cnf(Grammar([self._to_rule(rule) for rule in rules]))

    def _to_rule(self, lark_rule):
        """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
        assert isinstance(lark_rule.origin, NT)
        assert all(isinstance(sym, Symbol) for sym in lark_rule.expansion)
        priority = lark_rule.options.priority
        return Rule(
            lark_rule.origin, lark_rule.expansion,
            weight=priority if priority else 0,
            alias=lark_rule)

    def parse(self, tokenized, start):  # pylint: disable=invalid-name
        """Parses input, which is a list of tokens."""
        assert start
        start = NT(start)

        table, trees = _parse(tokenized, self.grammar)
        # The parse succeeded iff the start symbol derives the whole input span.
        span = (0, len(tokenized) - 1)
        if not any(rule.lhs == start for rule in table[span]):
            raise ParseError('Parsing failed.')
        best = trees[span][start]
        return self._to_tree(revert_cnf(best))

    def _to_tree(self, rule_node):
        """Converts a RuleNode parse tree to a lark Tree."""
        orig_rule = self.orig_rules[rule_node.rule.alias]
        children = []
        for child in rule_node.children:
            if isinstance(child, RuleNode):
                children.append(self._to_tree(child))
            else:
                # Terminal leaves wrap the original Token in their .name.
                assert isinstance(child.name, Token)
                children.append(child.name)
        tree = Tree(orig_rule.origin, children)
        tree.rule = orig_rule
        return tree
|
|
|
|
|
|
def print_parse(node, indent=0):
    """Prints a parse tree for debugging, two spaces of indent per level.

    Args:
        node: a RuleNode, or a terminal (T) leaf as produced by _parse.
        indent: current tree depth.
    """
    if isinstance(node, RuleNode):
        print(' ' * (indent * 2) + str(node.rule.lhs))
        for child in node.children:
            print_parse(child, indent + 1)
    else:
        # Terminal leaves are T instances, whose text lives in `.name`
        # (cf. match() and Parser._to_tree). The previous `.s` attribute
        # does not exist and raised AttributeError.
        print(' ' * (indent * 2) + str(node.name))
|
|
|
|
|
|
def _parse(s, g):
    """Parses sentence 's' using CNF grammar 'g'.

    Returns (table, trees): 'table' maps each span (start, end) to the set of
    rules that derive it; 'trees' maps each span to a dict from lhs nonterminal
    to the lightest (lowest-weight) RuleNode deriving that span.
    """
    # The CYK table. Indexed with a 2-tuple: (start pos, end pos).
    table = defaultdict(set)
    # Per span, the best (lightest) tree found so far for each lhs.
    trees = defaultdict(dict)

    # Base case: length-1 spans, filled from the terminal production rules.
    for pos, word in enumerate(s):
        cell = (pos, pos)
        for terminal, rules in g.terminal_rules.items():
            if not match(terminal, word):
                continue
            for rule in rules:
                table[cell].add(rule)
                best = trees[cell].get(rule.lhs)
                if best is None or rule.weight < best.weight:
                    trees[cell][rule.lhs] = RuleNode(rule, [T(word)],
                                                     weight=rule.weight)

    # Inductive step: combine two adjacent sub-spans with binary rules,
    # iterating over span lengths so shorter spans are always ready first.
    for span_len in range(2, len(s) + 1):
        for start in range(len(s) - span_len + 1):
            end = start + span_len - 1
            cell = (start, end)
            # Every way to cut s[start..end] into two non-empty halves.
            for split in range(start + 1, end + 1):
                left_span = (start, split - 1)
                right_span = (split, end)
                for r1, r2 in itertools.product(table[left_span],
                                                table[right_span]):
                    for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
                        table[cell].add(rule)
                        left_tree = trees[left_span][r1.lhs]
                        right_tree = trees[right_span][r2.lhs]
                        total = rule.weight + left_tree.weight + right_tree.weight
                        best = trees[cell].get(rule.lhs)
                        if best is None or total < best.weight:
                            trees[cell][rule.lhs] = RuleNode(
                                rule, [left_tree, right_tree], weight=total)
    return table, trees
|
|
|
|
|
|
# This section implements context-free grammar converter to Chomsky normal form.
|
|
# It also implements a conversion of parse trees from its CNF to the original
|
|
# grammar.
|
|
# Overview:
|
|
# Applies the following operations in this order:
|
|
# * TERM: Eliminates non-solitary terminals from all rules
|
|
# * BIN: Eliminates rules with more than 2 symbols on their right-hand-side.
|
|
# * UNIT: Eliminates non-terminal unit rules
|
|
#
|
|
# The following grammar characteristics aren't featured:
|
|
# * Start symbol appears on RHS
|
|
# * Empty rules (epsilon rules)
|
|
|
|
|
|
class CnfWrapper:
    """CNF wrapper for grammar.

    Validates that the input grammar is CNF and provides helper data structures.
    """

    def __init__(self, grammar):
        super(CnfWrapper, self).__init__()
        self.grammar = grammar
        self.rules = grammar.rules
        # terminal_rules:    T -> [rules "NT -> T"]
        # nonterminal_rules: (NT, NT) -> [rules "NT -> NT NT"]
        self.terminal_rules = defaultdict(list)
        self.nonterminal_rules = defaultdict(list)
        for rule in self.rules:
            # Every rule must be one of the two CNF shapes.
            assert isinstance(rule.lhs, NT), rule
            rhs = rule.rhs
            if len(rhs) not in [1, 2]:
                raise ParseError("CYK doesn't support empty rules")
            if len(rhs) == 1 and isinstance(rhs[0], T):
                self.terminal_rules[rhs[0]].append(rule)
            elif len(rhs) == 2 and all(isinstance(sym, NT) for sym in rhs):
                self.nonterminal_rules[tuple(rhs)].append(rule)
            else:
                assert False, rule

    def __eq__(self, other):
        return self.grammar == other.grammar

    def __repr__(self):
        return repr(self.grammar)
|
|
|
|
|
|
class UnitSkipRule(Rule):
    """A rule that records NTs that were skipped during transformation."""

    def __init__(self, lhs, rhs, skipped_rules, weight, alias):
        super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
        self.skipped_rules = skipped_rules

    def __eq__(self, other):
        # NOTE(review): equality here compares only the skip chain, not
        # lhs/rhs as Rule.__eq__ does — confirm this asymmetry is intended.
        if not isinstance(other, type(self)):
            return False
        return self.skipped_rules == other.skipped_rules

    __hash__ = Rule.__hash__
|
|
|
|
|
|
def build_unit_skiprule(unit_rule, target_rule):
    """Fuses unit rule 'unit_rule' into 'target_rule', recording the skip chain."""
    skipped_rules = []
    # Flatten any chains already collapsed into either input rule.
    if isinstance(unit_rule, UnitSkipRule):
        skipped_rules.extend(unit_rule.skipped_rules)
    skipped_rules.append(target_rule)
    if isinstance(target_rule, UnitSkipRule):
        skipped_rules.extend(target_rule.skipped_rules)
    return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules,
                        weight=unit_rule.weight + target_rule.weight,
                        alias=unit_rule.alias)
|
|
|
|
|
|
def get_any_nt_unit_rule(g):
    """Returns a non-terminal unit rule from 'g', or None if there is none."""
    return next(
        (rule for rule in g.rules
         if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT)),
        None)
|
|
|
|
|
|
def _remove_unit_rule(g, rule):
    """Removes 'rule' from 'g' without changing the language produced by 'g'."""
    kept = [r for r in g.rules if r != rule]
    # Fold every rule that expands the unit rule's rhs NT into a skip rule.
    replacements = [build_unit_skiprule(rule, ref)
                    for ref in g.rules if ref.lhs == rule.rhs[0]]
    return Grammar(kept + replacements)
|
|
|
|
|
|
def _split(rule):
    """Splits a rule whose len(rhs) > 2 into a chain of binary rules."""
    # Synthetic nonterminal name template; '__SP_' marks it for revert_cnf.
    rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs)
    rule_name = '__SP_%s' % (rule_str) + '_%d'
    lhs, rhs = rule.lhs, rule.rhs
    # The head keeps the original lhs, weight and alias.
    yield Rule(lhs, [rhs[0], NT(rule_name % 1)],
               weight=rule.weight, alias=rule.alias)
    # Middle links each consume one symbol and chain to the next synthetic NT.
    for i in range(1, len(rhs) - 2):
        yield Rule(NT(rule_name % i), [rhs[i], NT(rule_name % (i + 1))],
                   weight=0, alias='Split')
    # The tail takes the last two symbols.
    yield Rule(NT(rule_name % (len(rhs) - 2)), rhs[-2:], weight=0, alias='Split')
|
|
|
|
|
|
def _term(g):
    """Applies the TERM rule on 'g' (see top comment)."""
    # One fresh '__T_' nonterminal (plus unit rule) per terminal seen in any rhs.
    terminals = {sym for rule in g.rules for sym in rule.rhs if isinstance(sym, T)}
    t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term')
               for t in terminals}
    new_rules = []
    for rule in g.rules:
        # Solitary terminals (len 1) are already CNF; leave them alone.
        if len(rule.rhs) <= 1 or not any(isinstance(sym, T) for sym in rule.rhs):
            new_rules.append(rule)
            continue
        # Replace each non-solitary terminal with its dedicated nonterminal.
        new_rhs = [t_rules[sym].lhs if isinstance(sym, T) else sym
                   for sym in rule.rhs]
        new_rules.append(Rule(rule.lhs, new_rhs,
                              weight=rule.weight, alias=rule.alias))
        new_rules.extend(unit for t, unit in t_rules.items() if t in rule.rhs)
    return Grammar(new_rules)
|
|
|
|
|
|
def _bin(g):
    """Applies the BIN rule to 'g' (see top comment)."""
    # Rules with rhs longer than 2 are split; the rest pass through unchanged.
    return Grammar(itertools.chain.from_iterable(
        _split(rule) if len(rule.rhs) > 2 else [rule]
        for rule in g.rules))
|
|
|
|
|
|
def _unit(g):
    """Applies the UNIT rule to 'g' (see top comment)."""
    # Removing one unit rule can create another, so iterate to a fixpoint.
    while True:
        nt_unit_rule = get_any_nt_unit_rule(g)
        if nt_unit_rule is None:
            return g
        g = _remove_unit_rule(g, nt_unit_rule)
|
|
|
|
|
|
def to_cnf(g):
    """Creates a CNF grammar from a general context-free grammar 'g'."""
    # TERM, then BIN, then UNIT — the order matters (see module comment above).
    return CnfWrapper(_unit(_bin(_term(g))))
|
|
|
|
|
|
def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
    """Expands a collapsed unit-rule chain back into nested unit RuleNodes."""
    if not skipped_rules:
        # Chain exhausted: emit the final rule with the original rhs.
        return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias),
                        children, weight=weight)
    skipped = skipped_rules[0]
    # Peel the first skipped rule's weight off before recursing.
    weight = weight - skipped.weight
    inner = unroll_unit_skiprule(skipped.lhs, orig_rhs, skipped_rules[1:],
                                 children, skipped.weight, skipped.alias)
    return RuleNode(Rule(lhs, [skipped.lhs], weight=weight, alias=alias),
                    [inner], weight=weight)
|
|
|
|
|
|
def revert_cnf(node):
    """Reverts a parse tree (RuleNode) to its original non-CNF form (Node)."""
    if isinstance(node, T):
        return node
    # Reverts TERM rule: drop the synthetic '__T_' wrapper, keep its terminal.
    if node.rule.lhs.name.startswith('__T_'):
        return node.children[0]
    children = []
    for child in map(revert_cnf, node.children):
        # Reverts BIN rule: splice '__SP_' split nodes' children back in place.
        if isinstance(child, RuleNode) and child.rule.lhs.name.startswith('__SP_'):
            children.extend(child.children)
        else:
            children.append(child)
    # Reverts UNIT rule by re-expanding the recorded skip chain.
    if isinstance(node.rule, UnitSkipRule):
        return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs,
                                    node.rule.skipped_rules, children,
                                    node.rule.weight, node.rule.alias)
    return RuleNode(node.rule, children)
|