You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			200 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			200 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Python
		
	
"""Tree matcher based on Lark grammar"""
 | 
						|
 | 
						|
import re
 | 
						|
from typing import List, Dict
 | 
						|
from collections import defaultdict
 | 
						|
 | 
						|
from . import Tree, Token, Lark
 | 
						|
from .common import ParserConf
 | 
						|
from .exceptions import ConfigurationError
 | 
						|
from .parsers import earley
 | 
						|
from .grammar import Rule, Terminal, NonTerminal
 | 
						|
 | 
						|
 | 
						|
def is_discarded_terminal(t):
    """Tell whether symbol *t* is a terminal that the parser filters out of the tree."""
    return t.filter_out if t.is_term else t.is_term
 | 
						|
 | 
						|
 | 
						|
class _MakeTreeMatch:
    """Callable rule-alias that rebuilds a match-tree node.

    Instances are used as parse callbacks: called with the matched
    children, they produce a ``Tree`` tagged so callers can recognize
    it as matcher output and recover the original rule expansion.
    """

    def __init__(self, name, expansion):
        self.name = name
        self.expansion = expansion

    def __call__(self, args):
        node = Tree(self.name, args)
        # Flag the node as synthesized by the matcher, and remember
        # which expansion it was derived from.
        node.meta.match_tree = True
        node.meta.orig_expansion = self.expansion
        return node
 | 
						|
 | 
						|
 | 
						|
def _best_from_group(seq, group_key, cmp_key):
 | 
						|
    d = {}
 | 
						|
    for item in seq:
 | 
						|
        key = group_key(item)
 | 
						|
        if key in d:
 | 
						|
            v1 = cmp_key(item)
 | 
						|
            v2 = cmp_key(d[key])
 | 
						|
            if v2 > v1:
 | 
						|
                d[key] = item
 | 
						|
        else:
 | 
						|
            d[key] = item
 | 
						|
    return list(d.values())
 | 
						|
 | 
						|
 | 
						|
def _best_rules_from_group(rules: List[Rule]) -> List[Rule]:
    """Deduplicate *rules*, preferring the variant with the longest expansion
    for each rule, and return them ordered shortest-expansion first."""
    deduped = _best_from_group(rules, lambda rule: rule, lambda rule: -len(rule.expansion))
    return sorted(deduped, key=lambda rule: len(rule.expansion))
 | 
						|
 | 
						|
 | 
						|
def _match(term, token):
    """Terminal-matching callback handed to the Earley parser.

    A ``Token`` child matches a terminal of the same type; a ``Tree``
    child matches a symbol whose rule name (template args stripped)
    equals the tree's ``data``.
    """
    if isinstance(token, Token):
        return term == Terminal(token.type)
    if isinstance(token, Tree):
        rule_name, _ = parse_rulename(term.name)
        return token.data == rule_name
    # Anything other than a Token or Tree is a bug in the caller.
    assert False, (term, token)
 | 
						|
 | 
						|
 | 
						|
def make_recons_rule(origin, expansion, old_expansion):
    """Create a reconstruction rule whose callback rebuilds a match tree
    carrying *old_expansion* (the pre-match rule layout)."""
    callback = _MakeTreeMatch(origin.name, old_expansion)
    return Rule(origin, expansion, alias=callback)
 | 
						|
 | 
						|
 | 
						|
def make_recons_rule_to_term(origin, term):
    """Reconstruction rule deriving *origin* from the single symbol *term*."""
    expansion = [Terminal(term.name)]
    return make_recons_rule(origin, expansion, [term])
 | 
						|
 | 
						|
 | 
						|
def parse_rulename(s):
    """Parse a rule name that may carry template syntax, e.g. ``rule{a, b}``.

    Returns ``(name, args)`` where ``args`` is a list of stripped template
    argument names, or ``None`` when the name has no template part.
    """
    m = re.match(r'(\w+)(?:{(.+)})?', s)
    name, raw_args = m.groups()
    if not raw_args:
        # No template braces: the second group is None.
        return name, raw_args
    return name, [part.strip() for part in raw_args.split(',')]
 | 
						|
 | 
						|
 | 
						|
 | 
						|
class ChildrenLexer:
    """Stand-in lexer that feeds a tree node's children straight to the parser."""

    def __init__(self, children):
        self.children = children

    def lex(self, parser_state):
        # The parser state is irrelevant here: the "token stream" is
        # simply the stored list of child nodes.
        return self.children
 | 
						|
 | 
						|
class TreeMatcher:
    """Match the elements of a tree node, based on an ontology
    provided by a Lark grammar.

    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)

    Initialize with an instance of Lark.
    """
    # Reconstruction rules grouped by the root rule name they derive.
    rules_for_root: Dict[str, List[Rule]]
    # Shared, deduplicated reconstruction rules used for every root.
    rules: List[Rule]
    parser: Lark

    def __init__(self, parser: Lark):
        """Build the tree-matching rule set from *parser*'s compiled grammar."""
        # XXX TODO calling compile twice returns different results!
        assert not parser.options.maybe_placeholders

        if parser.options.postlex and parser.options.postlex.always_accept:
            # If postlexer's always_accept is used, we need to recompile the grammar with empty terminals-to-keep
            if not hasattr(parser, 'grammar'):
                raise ConfigurationError('Source grammar not available from cached parser, use cache_grammar=True'
                                         if parser.options.cache else "Source grammar not available!")
            self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())
        else:
            # Normal case: reuse the terminals and rules already compiled by the parser.
            self.tokens = list(parser.terminals)
            rules = list(parser.rules)

        self.rules_for_root = defaultdict(list)

        # NOTE: _build_recons_rules also fills self.rules_for_root as a side effect.
        self.rules = list(self._build_recons_rules(rules))
        self.rules.reverse()

        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
        self.rules = _best_rules_from_group(self.rules)

        self.parser = parser
        # One Earley parser per root rule name, built lazily in match_tree().
        self._parser_cache: Dict[str, earley.Parser] = {}

    def _build_recons_rules(self, rules: List[Rule]):
        "Convert tree-parsing/construction rules to tree-matching rules"
        expand1s = {r.origin for r in rules if r.options.expand1}

        aliases = defaultdict(list)
        for r in rules:
            if r.alias:
                aliases[r.origin].append(r.alias)

        rule_names = {r.origin for r in rules}
        # Symbols that remain nonterminals in the match grammar: inlined rules
        # (leading underscore), expand1 rules, and aliased rules.  All other
        # rule symbols are matched as terminals standing for already-built subtrees.
        nonterminals = {sym for sym in rule_names
                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}

        seen = set()
        for r in rules:
            # Rebuild the expansion, dropping filtered-out terminals and
            # turning non-kept rule symbols into terminal placeholders.
            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
                          for sym in r.expansion if not is_discarded_terminal(sym)]

            # Skip self-recursive constructs
            if recons_exp == [r.origin] and r.alias is None:
                continue

            sym = NonTerminal(r.alias) if r.alias else r.origin
            rule = make_recons_rule(sym, recons_exp, r.expansion)

            if sym in expand1s and len(recons_exp) != 1:
                self.rules_for_root[sym.name].append(rule)

                # Emit the trivial sym -> Terminal(sym) bridge only once per name.
                if sym.name not in seen:
                    yield make_recons_rule_to_term(sym, sym)
                    seen.add(sym.name)
            else:
                if sym.name.startswith('_') or sym in expand1s:
                    yield rule
                else:
                    self.rules_for_root[sym.name].append(rule)

        # Aliased rules can match either through the alias name or directly.
        for origin, rule_aliases in aliases.items():
            for alias in rule_aliases:
                yield make_recons_rule_to_term(origin, NonTerminal(alias))
            yield make_recons_rule_to_term(origin, origin)

    def match_tree(self, tree: Tree, rulename: str) -> Tree:
        """Match the elements of `tree` to the symbols of rule `rulename`.

        Parameters:
            tree (Tree): the tree node to match
            rulename (str): The expected full rule name (including template args)

        Returns:
            Tree: an unreduced tree that matches `rulename`

        Raises:
            UnexpectedToken: If no match was found.

        Note:
            It's the callers' responsibility to match the tree recursively.
        """
        if rulename:
            # validate
            name, _args = parse_rulename(rulename)
            assert tree.data == name
        else:
            rulename = tree.data

        # TODO: ambiguity?
        # Lazily build (and cache) an Earley parser rooted at `rulename`.
        try:
            parser = self._parser_cache[rulename]
        except KeyError:
            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])

            # TODO pass callbacks through dict, instead of alias?
            callbacks = {rule: rule.alias for rule in rules}
            conf = ParserConf(rules, callbacks, [rulename]) # type: ignore[arg-type]
            parser = earley.Parser(self.parser.lexer_conf, conf, _match, resolve_ambiguity=True)
            self._parser_cache[rulename] = parser

        # find a full derivation
        unreduced_tree: Tree = parser.parse(ChildrenLexer(tree.children), rulename)
        assert unreduced_tree.data == rulename
        return unreduced_tree
 |