You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			207 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			207 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
from collections import OrderedDict, deque
 | 
						|
from datetime import date, time, datetime
 | 
						|
from decimal import Decimal
 | 
						|
from fractions import Fraction
 | 
						|
import ast
 | 
						|
import enum
 | 
						|
import typing
 | 
						|
 | 
						|
 | 
						|
class CannotEval(Exception):
    """
    Signals that a value could not be evaluated.

    Both ``repr()`` and ``str()`` render as just the (sub)class name; any
    exception arguments are ignored in that rendering.
    """

    def __repr__(self):
        return type(self).__name__

    __str__ = __repr__
 | 
						|
 | 
						|
 | 
						|
def is_any(x, *args):
    """
    Return True if ``x`` is the very same object as any of ``args``.

    Only identity (``is``) is used, so no ``__eq__`` of ``x`` or of the
    candidates is ever invoked.
    """
    for candidate in args:
        if x is candidate:
            return True
    return False
 | 
						|
 | 
						|
 | 
						|
def of_type(x, *types):
    """
    Return ``x`` unchanged when its exact type is one of ``types``.

    The type is compared by identity, so instances of subclasses do not
    qualify.

    Raises:
        CannotEval: if ``type(x)`` is not one of ``types``.
    """
    if not is_any(type(x), *types):
        raise CannotEval
    return x
 | 
						|
 | 
						|
 | 
						|
def of_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return ``x`` unchanged if ``is_standard_types`` accepts it.

    ``check_dict_values`` and ``deep`` are forwarded to ``is_standard_types``
    unchanged.

    Raises:
        CannotEval: if ``is_standard_types`` rejects ``x``.
    """
    if not is_standard_types(x, check_dict_values=check_dict_values, deep=deep):
        raise CannotEval
    return x
 | 
						|
 | 
						|
 | 
						|
def is_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Report whether ``x`` (and, if ``deep``, its contents) has only allow-listed exact types.

    Args:
        x: the value to inspect.
        check_dict_values: when true, dict/OrderedDict values are inspected
            in addition to keys (only relevant when ``deep`` is true).
        deep: when true, container contents are inspected recursively;
            otherwise only the outer object's type is checked.

    Returns:
        bool. A structure nested so deeply that inspecting it raises
        RecursionError is reported as not standard.
    """
    try:
        verdict, _size = _is_standard_types_deep(x, check_dict_values, deep)
    except RecursionError:
        return False
    return verdict
 | 
						|
 | 
						|
 | 
						|
def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
    """
    Recursive worker behind ``is_standard_types``.

    Returns a pair ``(ok, size)``:

    * ``ok`` is True when ``type(x)`` is exactly one of the allow-listed leaf
      or container types below and, if ``deep`` is true, every element
      reached inside it is too.
    * ``size`` is ``len(x)`` for containers (0 for slices and leaves) plus
      the sizes reported for nested elements. Inspection gives up (returns
      ``ok=False``) once the running size exceeds 100000 before the next
      element is visited.
    """
    typ = type(x)

    # Leaf values: accepted outright and contribute nothing to the size.
    if is_any(
        typ,
        str, int, bool, float, bytes, complex,
        date, time, datetime,
        Fraction, Decimal,
        type(None), object,
    ):
        return True, 0

    if not is_any(typ, tuple, frozenset, list, set, dict, OrderedDict, deque, slice):
        return False, 0

    is_slice = typ is slice
    total = 0 if is_slice else len(x)
    assert isinstance(deep, bool)
    if not deep:
        return True, total

    # Decide which children to inspect: keys and values of a mapping when
    # requested, the three components of a slice, or plain iteration.
    if check_dict_values and typ in (dict, OrderedDict):
        children = (v for pair in x.items() for v in pair)
    elif is_slice:
        children = [x.start, x.stop, x.step]
    else:
        children = x

    for child in children:
        if total > 100000:
            return False, total
        child_ok, child_size = _is_standard_types_deep(child, check_dict_values, deep)
        if not child_ok:
            return False, total
        total += child_size
    return True, total
 | 
						|
 | 
						|
 | 
						|
class _E(enum.Enum):
 | 
						|
    pass
 | 
						|
 | 
						|
 | 
						|
class _C:
 | 
						|
    def foo(self): pass  # pragma: nocover
 | 
						|
 | 
						|
    def bar(self): pass  # pragma: nocover
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def cm(cls): pass  # pragma: nocover
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def sm(): pass  # pragma: nocover
 | 
						|
 | 
						|
 | 
						|
# Representative objects of each kind whose ``__name__`` may be read by
# ``safe_name``; every entry is keyed by that object's own ``__name__``.
# Within this module only the *types* of these samples are used.
safe_name_samples = dict(
    len=len,                                # builtin function
    append=list.append,                     # method descriptor
    __add__=list.__add__,                   # slot wrapper
    insert=[].insert,                       # builtin method bound to an instance
    __mul__=[].__mul__,                     # method-wrapper
    fromkeys=dict.__dict__['fromkeys'],     # classmethod descriptor
    is_any=is_any,                          # plain Python function
    __repr__=CannotEval.__repr__,           # function defined in a class body
    foo=_C().foo,                           # bound method
    bar=_C.bar,                             # function looked up on a class
    cm=_C.cm,                               # classmethod bound to its class
    sm=_C.sm,                               # staticmethod (plain function)
    ast=ast,                                # module
    CannotEval=CannotEval,                  # class
    _E=_E,                                  # Enum class (contributes its metaclass)
)

# Bare ``typing`` aliases, keyed by their attribute name on ``typing``.
typing_annotation_samples = {
    alias: getattr(typing, alias)
    for alias in ("List", "Dict", "Tuple", "Set", "Callable", "Mapping")
}

# De-duplicated runtime types of the samples above, stored as tuples so they
# can be splatted into ``is_any``.
safe_name_types = tuple(set(map(type, safe_name_samples.values())))

typing_annotation_types = tuple(set(map(type, typing_annotation_samples.values())))
 | 
						|
 | 
						|
 | 
						|
def eq_checking_types(a, b):
    """
    Compare ``a`` and ``b`` with ``==`` only when their exact types match.

    If the types differ, False is returned without evaluating ``==``
    (so ``1`` equals neither ``1.0`` nor ``True`` here).
    """
    if type(a) is not type(b):
        return False
    return a == b
 | 
						|
 | 
						|
 | 
						|
def ast_name(node):
    """
    Return the identifier an AST node names, or None.

    ``ast.Name`` nodes give their ``id``; ``ast.Attribute`` nodes give only
    the final attribute (``a.b.c`` -> ``"c"``). Anything else gives None.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None
 | 
						|
 | 
						|
 | 
						|
def safe_name(value):
    """
    Return a name for ``value``, or None.

    Attributes are read only from objects whose exact type is one of the
    collected sample types (``safe_name_types`` / ``typing_annotation_types``);
    ``typing.Optional`` and ``typing.Union`` are recognised by identity.
    Any other value yields None.
    """
    value_type = type(value)

    if is_any(value_type, *safe_name_types):
        return value.__name__

    if value is typing.Optional:
        return "Optional"

    if value is typing.Union:
        return "Union"

    if is_any(value_type, *typing_annotation_types):
        # Fall back to ``_name`` when ``__name__`` is missing or falsy.
        return getattr(value, "__name__", None) or getattr(value, "_name", None)

    return None
 | 
						|
 | 
						|
 | 
						|
def has_ast_name(value, node):
    """
    Return True when the identifier in ``node`` equals ``value``'s safe name.

    The name comes from ``safe_name(value)``; values without a ``str`` name
    never match. The comparison with ``ast_name(node)`` is type-strict.
    """
    own_name = safe_name(value)
    return type(own_name) is str and eq_checking_types(ast_name(node), own_name)
 | 
						|
 | 
						|
 | 
						|
def copy_ast_without_context(x):
    """
    Recursively copy an AST, dropping every ``ctx`` field.

    Only the node's ``_fields`` that are actually set are copied (not
    ``_attributes`` such as ``lineno``). Lists are copied element-wise; any
    other value (str, int, None, ...) is returned as-is.
    """
    if isinstance(x, ast.AST):
        node_type = type(x)
        kwargs = {}
        for field in x._fields:
            if field == 'ctx' or not hasattr(x, field):
                continue
            kwargs[field] = copy_ast_without_context(getattr(x, field))
        clone = node_type(**kwargs)
        # Python 3.13.0b2+ defaults to Load when we don't pass ctx
        # https://github.com/python/cpython/pull/118871
        if hasattr(clone, 'ctx'):
            del clone.ctx
        return clone
    if isinstance(x, list):
        return [copy_ast_without_context(item) for item in x]
    return x
 | 
						|
 | 
						|
 | 
						|
def ensure_dict(x):
    """
    Return ``dict(x)``, or an empty dict if that conversion raises.

    Handles invalid non-dict inputs: any ``Exception`` raised by the
    conversion is deliberately swallowed.
    """
    try:
        converted = dict(x)
    except Exception:
        return {}
    return converted
 |