"""Tree matcher based on Lark grammar"""
import re
from collections import defaultdict
from typing import Dict, List

from . import Lark, Token, Tree
from .common import ParserConf
from .exceptions import ConfigurationError
from .grammar import NonTerminal, Rule, Terminal
from .parsers import earley
def is_discarded_terminal(t):
    """Whether symbol *t* is a terminal that the parser filters out of the tree.

    Returns the same truthy/falsy value as ``t.is_term and t.filter_out``.
    """
    if not t.is_term:
        return t.is_term
    return t.filter_out
class _MakeTreeMatch:
|
|
def __init__(self, name, expansion):
|
|
self.name = name
|
|
self.expansion = expansion
|
|
|
|
def __call__(self, args):
|
|
t = Tree(self.name, args)
|
|
t.meta.match_tree = True
|
|
t.meta.orig_expansion = self.expansion
|
|
return t
|
|
|
|
|
|
def _best_from_group(seq, group_key, cmp_key):
|
|
d = {}
|
|
for item in seq:
|
|
key = group_key(item)
|
|
if key in d:
|
|
v1 = cmp_key(item)
|
|
v2 = cmp_key(d[key])
|
|
if v2 > v1:
|
|
d[key] = item
|
|
else:
|
|
d[key] = item
|
|
return list(d.values())
|
|
|
|
|
|
def _best_rules_from_group(rules: List[Rule]) -> List[Rule]:
    """Deduplicate *rules*, preferring longer expansions, and return them
    sorted from shortest to longest expansion."""
    # Negative length makes "longest expansion" the minimum, i.e. the winner.
    deduped = _best_from_group(rules, lambda r: r, lambda r: -len(r.expansion))
    return sorted(deduped, key=lambda r: len(r.expansion))
def _match(term, token):
    """Match a grammar symbol *term* against *token* (a Tree node or a Token).

    A Tree matches when its ``data`` equals the rule name parsed out of
    ``term.name``; a Token matches when its type names the same terminal.
    """
    if isinstance(token, Tree):
        # Template rule names ("rule{a, b}") match on the bare name part.
        rule_name, _ = parse_rulename(term.name)
        return token.data == rule_name
    if isinstance(token, Token):
        return term == Terminal(token.type)
    # Anything else is a programming error upstream.
    assert False, (term, token)
def make_recons_rule(origin, expansion, old_expansion):
    """Build a reconstruction rule whose callback produces a match tree.

    The callback remembers *old_expansion* (the original rule's expansion)
    so the caller can later recover the unconverted symbols.
    """
    callback = _MakeTreeMatch(origin.name, old_expansion)
    return Rule(origin, expansion, alias=callback)
def make_recons_rule_to_term(origin, term):
    """Build a reconstruction rule mapping *origin* to the single symbol *term*,
    wrapped as a Terminal."""
    new_expansion = [Terminal(term.name)]
    return make_recons_rule(origin, new_expansion, [term])
def parse_rulename(s):
    """Parse a rule name that may carry template syntax, e.g. ``rule{a, b}``.

    Returns ``(name, args)`` where ``args`` is a list of stripped argument
    strings, or ``None`` when no template braces are present.
    """
    name, args_str = re.match(r'(\w+)(?:{(.+)})?', s).groups()
    if args_str:
        args = [part.strip() for part in args_str.split(',')]
    else:
        # No braces: the optional group is None, and we pass that through.
        args = args_str
    return name, args
class ChildrenLexer:
    """Lexer stand-in that feeds a tree's children to the parser as "tokens"."""

    def __init__(self, children):
        self.children = children

    def lex(self, parser_state):
        # The parser state is irrelevant here; the "token stream" is just
        # the stored children.
        return self.children
class TreeMatcher:
    """Match the elements of a tree node, based on an ontology
    provided by a Lark grammar.

    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)

    Initialize with an instance of Lark.
    """

    # Reconstruction rules grouped by root rule name (only needed when that
    # rule is the start symbol of a match).
    rules_for_root: Dict[str, List[Rule]]
    # All reconstruction rules shared by every match.
    rules: List[Rule]
    # The Lark instance whose grammar defines the ontology.
    parser: Lark

    def __init__(self, parser: Lark):
        # XXX TODO calling compile twice returns different results!
        assert not parser.options.maybe_placeholders

        if parser.options.postlex and parser.options.postlex.always_accept:
            # If postlexer's always_accept is used, we need to recompile the grammar with empty terminals-to-keep
            if not hasattr(parser, 'grammar'):
                raise ConfigurationError('Source grammar not available from cached parser, use cache_grammar=True'
                                         if parser.options.cache else "Source grammar not available!")
            self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())
        else:
            # Reuse the already-compiled terminals and rules from the parser.
            self.tokens = list(parser.terminals)
            rules = list(parser.rules)

        self.rules_for_root = defaultdict(list)

        # Convert the grammar's rules into tree-matching rules; reversed so
        # later (NOTE(review): presumably higher-priority) rules come first
        # before deduplication below.
        self.rules = list(self._build_recons_rules(rules))
        self.rules.reverse()

        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
        self.rules = _best_rules_from_group(self.rules)

        self.parser = parser
        # One earley parser is built lazily per root rule name; see match_tree.
        self._parser_cache: Dict[str, earley.Parser] = {}

    def _build_recons_rules(self, rules: List[Rule]):
        "Convert tree-parsing/construction rules to tree-matching rules"
        # Origins of rules flagged expand1 (single-child inlining).
        expand1s = {r.origin for r in rules if r.options.expand1}

        # Map each origin to the aliases of its aliased alternatives.
        aliases = defaultdict(list)
        for r in rules:
            if r.alias:
                aliases[r.origin].append(r.alias)

        rule_names = {r.origin for r in rules}
        # Symbols that stay nonterminal in the matching grammar: inlined
        # rules (leading underscore), expand1 rules, and aliased rules.
        # Everything else becomes a Terminal in the converted expansions.
        nonterminals = {sym for sym in rule_names
                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}

        seen = set()
        for r in rules:
            # Rebuild the expansion: drop filtered-out terminals, keep
            # nonterminals as-is, and demote other symbols to Terminals.
            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
                          for sym in r.expansion if not is_discarded_terminal(sym)]

            # Skip self-recursive constructs
            if recons_exp == [r.origin] and r.alias is None:
                continue

            # Aliased alternatives get their own NonTerminal named after the alias.
            sym = NonTerminal(r.alias) if r.alias else r.origin
            rule = make_recons_rule(sym, recons_exp, r.expansion)

            if sym in expand1s and len(recons_exp) != 1:
                # Multi-symbol expansions of expand1 rules only apply at the
                # root of a match.
                self.rules_for_root[sym.name].append(rule)

                # NOTE(review): this dedup is nested inside the expand1
                # branch in the original source — the sym -> Terminal(sym)
                # bridge rule is only emitted from here; confirm intended.
                if sym.name not in seen:
                    yield make_recons_rule_to_term(sym, sym)
                    seen.add(sym.name)
            else:
                if sym.name.startswith('_') or sym in expand1s:
                    # Inlined/expand1 rules participate in every match.
                    yield rule
                else:
                    # Ordinary rules only match at the root.
                    self.rules_for_root[sym.name].append(rule)

        # For every aliased origin, bridge both the alias nonterminals and
        # the origin itself to their terminal forms.
        for origin, rule_aliases in aliases.items():
            for alias in rule_aliases:
                yield make_recons_rule_to_term(origin, NonTerminal(alias))
            yield make_recons_rule_to_term(origin, origin)

    def match_tree(self, tree: Tree, rulename: str) -> Tree:
        """Match the elements of `tree` to the symbols of rule `rulename`.

        Parameters:
            tree (Tree): the tree node to match
            rulename (str): The expected full rule name (including template args)

        Returns:
            Tree: an unreduced tree that matches `rulename`

        Raises:
            UnexpectedToken: If no match was found.

        Note:
            It's the callers' responsibility to match the tree recursively.
        """
        if rulename:
            # validate
            name, _args = parse_rulename(rulename)
            assert tree.data == name
        else:
            # No explicit rule requested: match against the tree's own rule.
            rulename = tree.data

        # TODO: ambiguity?
        try:
            # Parsers are cached per root rule name (built lazily below).
            parser = self._parser_cache[rulename]
        except KeyError:
            # Shared rules plus the best root-specific rules for this name.
            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])

            # TODO pass callbacks through dict, instead of alias?
            callbacks = {rule: rule.alias for rule in rules}
            conf = ParserConf(rules, callbacks, [rulename])  # type: ignore[arg-type]
            parser = earley.Parser(self.parser.lexer_conf, conf, _match, resolve_ambiguity=True)
            self._parser_cache[rulename] = parser

        # find a full derivation
        unreduced_tree: Tree = parser.parse(ChildrenLexer(tree.children), rulename)
        assert unreduced_tree.data == rulename
        return unreduced_tree