import unicodedata
import os
from itertools import product
from collections import deque
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet

###{standalone
import sys, re
import logging
from dataclasses import dataclass
from typing import Generic, AnyStr

logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


NO_VALUE = object()

T = TypeVar("T")


def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of `seq` into a dict of lists, keyed by `key(item)`.

    If `key` (or `value`) is None, the item itself is used as the key (or value).
    """
    d: Dict[Any, Any] = {}
    for item in seq:
        k = key(item) if (key is not None) else item
        v = value(item) if (value is not None) else item
        try:
            d[k].append(v)
        except KeyError:
            d[k] = [v]
    return d


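# Example (illustrative): grouping words by length. Output assumes
# insertion-ordered dicts (Python 3.7+).
#
#   >>> classify(['abc', 'de', 'f', 'gh'], key=len)
#   {3: ['abc'], 2: ['de', 'gh'], 1: ['f']}
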
def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    if isinstance(data, dict):
        if '__type__' in data:  # Object
            class_ = namespace[data['__type__']]
            return class_.deserialize(data, memo)
        elif '@' in data:
            return memo[data['@']]
        return {key: _deserialize(value, namespace, memo) for key, value in data.items()}
    elif isinstance(data, list):
        return [_deserialize(value, namespace, memo) for value in data]
    return data


_T = TypeVar("_T", bound="Serialize")


class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
                                        Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        memo = SerializeMemoizer(types_to_memoize)
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo=None) -> Dict[str, Any]:
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        fields = getattr(self, '__serialize_fields__')
        res = {f: _serialize(getattr(self, f), memo) for f in fields}
        res['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        namespace = getattr(cls, '__serialize_namespace__', [])
        namespace = {c.__name__: c for c in namespace}

        fields = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for f in fields:
            try:
                setattr(inst, f, _deserialize(data[f], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst


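# A minimal sketch of the round-trip contract, using a hypothetical subclass
# (`Point` is not part of this module):
#
#   class Point(Serialize):
#       __serialize_fields__ = 'x', 'y'
#
#   >>> p = Point.__new__(Point); p.x, p.y = 1, 2
#   >>> p.serialize()
#   {'x': 1, 'y': 2, '__type__': 'Point'}
#   >>> q = Point.deserialize(p.serialize(), {})
#   >>> (q.x, q.y)
#   (1, 2)
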
class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    __serialize_fields__ = 'memoized',

    def __init__(self, types_to_memoize: List) -> None:
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        return _deserialize(data, namespace, memo)


try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    if _has_regex:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
        # match here below.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        else:
            # sre_parse does not support the new features in regex. To not completely fail in that case,
            # we manually test for the most important info (whether the empty string is matched)
            c = regex.compile(regexp_final)
            # Python 3.11.7 introduced sre_parse.MAXWIDTH, which is used instead of MAXREPEAT
            # See lark-parser/lark#1376 and python/cpython#109859
            MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
            if c.match('') is None:
                # MAXREPEAT is a non-picklable subclass of int, so it is converted to a plain int to enable caching
                return 1, int(MAXWIDTH)
            else:
                return 0, int(MAXWIDTH)


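# Illustrative behavior (hedged; relies on the stdlib's `sre_parse.getwidth()`):
#
#   >>> get_regexp_width('a{2,5}')
#   [2, 5]
#   >>> get_regexp_width('foo|barbaz')
#   [3, 6]
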
@dataclass(frozen=True)
class TextSlice(Generic[AnyStr]):
    """A view of a string or bytes object, between the start and end indices.

    Never creates a copy.

    Lark accepts instances of TextSlice as input (instead of a string),
    when the lexer is 'basic' or 'contextual'.

    Args:
        text (str or bytes): The text to slice.
        start (int): The start index. Negative indices are supported.
        end (int): The end index. Negative indices are supported.

    Raises:
        TypeError: If `text` is not a `str` or `bytes`.
        AssertionError: If `start` or `end` are out of bounds.

    Examples:
        >>> TextSlice("Hello, World!", 7, -1)
        TextSlice(text='Hello, World!', start=7, end=12)

        >>> TextSlice("Hello, World!", 7, None).count("o")
        1
    """
    text: AnyStr
    start: int
    end: int

    def __post_init__(self):
        if not isinstance(self.text, (str, bytes)):
            raise TypeError("text must be str or bytes")

        if self.start < 0:
            object.__setattr__(self, 'start', self.start + len(self.text))
        assert self.start >= 0

        if self.end is None:
            object.__setattr__(self, 'end', len(self.text))
        elif self.end < 0:
            object.__setattr__(self, 'end', self.end + len(self.text))
        assert self.end <= len(self.text)

    @classmethod
    def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]':
        if isinstance(text, TextSlice):
            return text

        return cls(text, 0, len(text))

    def is_complete_text(self):
        return self.start == 0 and self.end == len(self.text)

    def __len__(self):
        return self.end - self.start

    def count(self, substr: AnyStr):
        return self.text.count(substr, self.start, self.end)

    def rindex(self, substr: AnyStr):
        return self.text.rindex(substr, self.start, self.end)


TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]']

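# Illustrative usage (hedged): `cast_from` wraps a plain string in a full-range view.
#
#   >>> TextSlice.cast_from("hello").is_complete_text()
#   True
#   >>> len(TextSlice("hello", 1, -1))
#   3
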
###}


_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    if len(s) != 1:
        return all(_test_unicode_category(char, categories) for char in s)
    return s == '_' or unicodedata.category(s) in categories

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, Indian vowels,
    non-Latin numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, Indian vowels,
    non-Latin letters, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
    """
    return _test_unicode_category(s, _ID_START)


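# Illustrative checks (hedged):
#
#   >>> is_id_start('λx')
#   True
#   >>> is_id_continue('x1_')
#   True
#   >>> is_id_start('1x')
#   False
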
def dedup_list(l: Iterable[T]) -> List[T]:
    """Given an iterable (l), returns a list with duplicates removed,
    preserving the original order. Assumes that
    the entries are hashable."""
    return list(dict.fromkeys(l))


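# Example (illustrative): the first occurrence of each item is kept.
#
#   >>> dedup_list([3, 1, 3, 2, 1])
#   [3, 1, 2]
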
class Enumerator(Serialize):
    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        if item not in self.enums:
            self.enums[item] = len(self.enums)
        return self.enums[item]

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        r = {v: k for k, v in self.enums.items()}
        assert len(r) == len(self.enums)
        return r


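# Illustrative usage (hedged): items are assigned stable sequential ids.
#
#   >>> e = Enumerator()
#   >>> e.get('a'), e.get('b'), e.get('a')
#   (0, 1, 0)
#   >>> e.reversed()
#   {0: 'a', 1: 'b'}
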
def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Note that `itertools.product` yields tuples, not lists.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    assert all(l for l in lists), lists
    return list(product(*lists))

try:
    import atomicwrites
    _has_atomicwrites = True
except ImportError:
    _has_atomicwrites = False

class FS:
    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        if _has_atomicwrites and "w" in mode:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        else:
            return open(name, mode, **kwargs)


class fzset(frozenset):
    def __repr__(self):
        return '{%s}' % ', '.join(map(repr, self))


def classify_bool(seq: Iterable, pred: Callable) -> Any:
    false_elems = []
    # `list.append` returns None, so failing elements are collected into
    # `false_elems` as a side effect of the filter expression.
    true_elems = [elem for elem in seq if pred(elem) or false_elems.append(elem)]  # type: ignore[func-returns-value]
    return true_elems, false_elems


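# Example (illustrative): partitioning a list by a predicate.
#
#   >>> classify_bool([1, 2, 3, 4], lambda x: x % 2 == 0)
#   ([2, 4], [1, 3])
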
def bfs(initial: Iterable, expand: Callable) -> Iterator:
    open_q = deque(list(initial))
    visited = set(open_q)
    while open_q:
        node = open_q.popleft()
        yield node
        for next_node in expand(node):
            if next_node not in visited:
                visited.add(next_node)
                open_q.append(next_node)


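# Illustrative traversal (hedged) over a small hypothetical graph:
#
#   >>> graph = {1: [2, 3], 2: [4], 3: [4], 4: []}
#   >>> list(bfs([1], lambda n: graph[n]))
#   [1, 2, 3, 4]
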
def bfs_all_unique(initial, expand):
    "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
    open_q = deque(list(initial))
    while open_q:
        node = open_q.popleft()
        yield node
        open_q += expand(node)


def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    if isinstance(value, Serialize):
        return value.serialize(memo)
    elif isinstance(value, list):
        return [_serialize(elem, memo) for elem in value]
    elif isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    elif isinstance(value, dict):
        return {key: _serialize(elem, memo) for key, elem in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value


def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

    n = 1
    for a, b in values:
        n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    if n <= max_factor:
        return [(n, 0)]

    for a in range(max_factor, 1, -1):
        r, b = divmod(n, a)
        if a + b <= max_factor:
            return small_factors(r, max_factor) + [(a, b)]
    assert False, "Failed to factorize %s" % n


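# Worked example (illustrative): decomposing 100 with max_factor=10.
#
#   >>> small_factors(100, 10)
#   [(10, 0), (10, 0)]
#
# Re-applying n = n * a + b from n = 1 gives 1*10+0 = 10, then 10*10+0 = 100.
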
class OrderedSet(AbstractSet[T]):
    """A minimal OrderedSet implementation, using a dictionary.

    (relies on the dictionary being ordered)
    """
    def __init__(self, items: Iterable[T] = ()):
        self.d = dict.fromkeys(items)

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        del self.d[item]

    def __bool__(self):
        return bool(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        return f"{type(self).__name__}({', '.join(map(repr, self))})"
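

# Illustrative usage (hedged): insertion order is preserved.
#
#   >>> s = OrderedSet([3, 1, 2]); s.add(1); s.add(4)
#   >>> list(s)
#   [3, 1, 2, 4]
#   >>> repr(s)
#   'OrderedSet(3, 1, 2, 4)'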