You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			621 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			621 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Python
		
	
from __future__ import annotations
 | 
						|
import ast
 | 
						|
import builtins
 | 
						|
import contextlib
 | 
						|
import itertools
 | 
						|
import os
 | 
						|
import pickle
 | 
						|
import platform
 | 
						|
import sys
 | 
						|
import textwrap
 | 
						|
from types import ModuleType
 | 
						|
from typing import TYPE_CHECKING, Any, Generator, Iterable, NamedTuple, cast
 | 
						|
 | 
						|
from IPython.extensions.deduperreload.deduperreload_patching import (
 | 
						|
    DeduperReloaderPatchingMixin,
 | 
						|
)
 | 
						|
 | 
						|
if TYPE_CHECKING:
    # Static-typing-only alias for the AST statement kinds that are tracked as
    # reloadable "definitions": imports, assignments and function definitions.
    # It is never bound at runtime, so it may only appear inside annotations.
    TDefinitionAst = (
        ast.Import
        | ast.ImportFrom
        | ast.Assign
        | ast.AnnAssign
        | ast.FunctionDef
        | ast.AsyncFunctionDef
    )
 | 
						|
 | 
						|
 | 
						|
def get_module_file_name(module: ModuleType | str) -> str | None:
    """Return the file path of ``module``, or ``""`` when it is unavailable.

    ``module`` is either a module object or a module name, in which case it is
    looked up in ``sys.modules``. The empty string is returned when the name is
    not present in ``sys.modules`` (or maps to ``None``), and when the module
    has no ``__file__`` attribute or a falsy one.

    Note that, despite the ``str | None`` annotation, this function never
    returns ``None``; callers should test the result for truthiness.
    """
    resolved = sys.modules.get(module) if isinstance(module, str) else module
    if resolved is None:
        return ""
    file_name = getattr(resolved, "__file__", "")
    return file_name if file_name else ""
 | 
						|
 | 
						|
 | 
						|
def compare_ast(node1: ast.AST | list[ast.AST], node2: ast.AST | list[ast.AST]) -> bool:
    """Return whether node1 and node2 are structurally equal ASTs.

    Position attributes (``lineno``, ``end_lineno``, ``col_offset``,
    ``end_col_offset``) as well as ``ctx`` and ``parent`` are ignored.
    Lists are compared element-wise; any other value is compared with ``==``.
    Only the attributes present on ``node1`` are examined.
    """
    if type(node1) is not type(node2):
        return False

    if isinstance(node1, list) and isinstance(  # type:ignore [redundant-expr]
        node2, list
    ):
        if len(node1) != len(node2):
            return False
        return all(compare_ast(left, right) for left, right in zip(node1, node2))

    if not isinstance(node1, ast.AST):
        # leaf value: identifier string, constant, None, ...
        return node1 == node2

    ignored_fields = (
        "lineno",
        "end_lineno",
        "col_offset",
        "end_col_offset",
        "ctx",
        "parent",
    )
    return all(
        hasattr(node2, field) and compare_ast(value, getattr(node2, field))
        for field, value in node1.__dict__.items()
        if field not in ignored_fields
    )
 | 
						|
 | 
						|
 | 
						|
class DependencyNode(NamedTuple):
    """
    A function definition together with its location in the module.

    qualified_name: tuple of the enclosing class names followed by the
        function's own name, i.e. (namespace1, namespace2, ..., name)
    abstract_syntax_tree: the AST subtree for this function

    For example, foo() below is identified by ("A", "B", "foo"):

    class A:
        class B:
            def foo():
                pass
    """

    # enclosing namespaces followed by the function name
    qualified_name: tuple[str, ...]
    # AST node of the function definition
    abstract_syntax_tree: ast.AST
 | 
						|
 | 
						|
 | 
						|
class GatherResult(NamedTuple):
    """
    Statements of one scope (module or class body), grouped by kind.

    import_defs / assign_defs / function_defs: lists of ``(names, ast_node)``
        pairs, where ``names`` is the tuple of names the statement binds
    classes: maps class name to its ``ast.ClassDef``
    unfixable: every other AST node found in the scope

    NOTE: the field defaults below are single mutable objects shared by every
    instance built without arguments. Use ``GatherResult.create()`` to get an
    instance with fresh, independent containers.
    """

    import_defs: list[tuple[tuple[str, ...], ast.Import | ast.ImportFrom]] = []
    assign_defs: list[tuple[tuple[str, ...], ast.Assign | ast.AnnAssign]] = []
    function_defs: list[
        tuple[tuple[str, ...], ast.FunctionDef | ast.AsyncFunctionDef]
    ] = []
    classes: dict[str, ast.ClassDef] = {}
    unfixable: list[ast.AST] = []

    @classmethod
    def create(cls) -> GatherResult:
        """Build an empty result whose containers are not shared with any other."""
        return cls(
            import_defs=[],
            assign_defs=[],
            function_defs=[],
            classes={},
            unfixable=[],
        )

    def all_defs(self) -> Iterable[tuple[tuple[str, ...], TDefinitionAst]]:
        """Iterate over import, then assignment, then function definitions."""
        groups = (self.import_defs, self.assign_defs, self.function_defs)
        return itertools.chain(*groups)

    def inplace_merge(self, other: GatherResult) -> None:
        """Append everything gathered in ``other`` to this result, mutating it."""
        list_pairs = (
            (self.import_defs, other.import_defs),
            (self.assign_defs, other.assign_defs),
            (self.function_defs, other.function_defs),
            (self.unfixable, other.unfixable),
        )
        for mine, theirs in list_pairs:
            mine.extend(theirs)
        # for classes, entries from ``other`` win on a name clash
        self.classes.update(other.classes)
 | 
						|
 | 
						|
 | 
						|
class ConstexprDetector(ast.NodeVisitor):
    """
    Callable visitor deciding whether an expression is a "constexpr".

    An expression stops being a constexpr as soon as it contains an
    ``ast.Name`` that is not acceptable: a bare name is accepted only when it
    is an attribute of the ``builtins`` module, and no name at all is accepted
    inside the value of an attribute access (``x.y``).
    """

    def __init__(self) -> None:
        # verdict for the expression currently being visited
        self.is_constexpr = True
        # whether names of builtins are currently tolerated
        self._allow_builtins_exceptions = True

    @contextlib.contextmanager
    def disallow_builtins_exceptions(self) -> Generator[None, None, None]:
        """Temporarily reject builtin names too; restores the previous setting on exit."""
        saved = self._allow_builtins_exceptions
        self._allow_builtins_exceptions = False
        try:
            yield
        finally:
            self._allow_builtins_exceptions = saved

    def visit_Attribute(self, node: ast.Attribute) -> None:
        # names appearing in the object part of ``obj.attr`` never qualify
        with self.disallow_builtins_exceptions():
            self.visit(node.value)

    def visit_Name(self, node: ast.Name) -> None:
        is_tolerated_builtin = self._allow_builtins_exceptions and hasattr(
            builtins, node.id
        )
        if not is_tolerated_builtin:
            self.is_constexpr = False

    def visit(self, node: ast.AST) -> None:
        # once the verdict is negative there is nothing left to learn
        if self.is_constexpr:
            super().visit(node)

    def __call__(self, node: ast.AST) -> bool:
        """Return whether ``node`` is a constexpr; the verdict is reset on every call."""
        self.is_constexpr = True
        self.visit(node)
        return self.is_constexpr
 | 
						|
 | 
						|
 | 
						|
class AutoreloadTree:
    """
    Recursive record of pending reload work; each instance covers one scope level.

    children: maps a class name in this scope to the tree for that class's scope
    defs_to_reload: ``(names, ast_node)`` pairs of definitions to re-execute in this scope
    defs_to_delete: names that should be removed from this scope
    new_nested_classes: maps the name of a class newly added to this scope to its AST
    """

    def __init__(self) -> None:
        self.children: dict[str, AutoreloadTree] = {}
        self.defs_to_reload: list[tuple[tuple[str, ...], ast.AST]] = []
        self.defs_to_delete: set[str] = set()
        self.new_nested_classes: dict[str, ast.AST] = {}

    def traverse_prefixes(self, prefixes: list[str]) -> AutoreloadTree:
        """
        Return the tree for the namespace reached by following ``prefixes``.

        Missing intermediate trees are created (and kept) along the way; an
        empty ``prefixes`` returns this tree itself.
        """
        node = self
        for prefix in prefixes:
            child = node.children.get(prefix)
            if child is None:
                child = AutoreloadTree()
                node.children[prefix] = child
            node = child
        return node
 | 
						|
 | 
						|
 | 
						|
class DeduperReloader(DeduperReloaderPatchingMixin):
 | 
						|
    """
 | 
						|
    This version of autoreload detects when we can leverage targeted recompilation of a subset of a module and patching
 | 
						|
    existing function/method objects to reflect these changes.
 | 
						|
 | 
						|
    Detects what functions/methods can be reloaded by recursively comparing the old/new AST of module-level classes,
 | 
						|
    module-level classes' methods, recursing through nested classes' methods. If other changes are made, original
 | 
						|
    autoreload algorithm is called directly.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(self) -> None:
        """Start enabled, with no cached sources, dependency edges or pending work."""
        # backing field of the ``enabled`` property
        self._enabled = True
        # module name -> last source text seen for that module ("" when unavailable)
        self.source_by_modname: dict[str, str] = {}
        # decorator qualified name -> functions decorated with it
        self.dependency_graph: dict[tuple[str, ...], list[DependencyNode]] = {}
        # work collected for the reload currently in progress
        self._to_autoreload: AutoreloadTree = AutoreloadTree()
 | 
						|
 | 
						|
    @property
    def enabled(self) -> bool:
        """Whether targeted reloading is active: switched on and running on CPython."""
        running_on_cpython = platform.python_implementation() == "CPython"
        return self._enabled and running_on_cpython

    @enabled.setter
    def enabled(self, value: bool) -> None:
        """Switch targeted reloading on or off; the CPython requirement still applies."""
        self._enabled = value
 | 
						|
 | 
						|
    def update_sources(self) -> None:
        """
        Update dictionary source_by_modname with current modules' source codes.

        Only modules in ``sys.modules`` that are not yet recorded are examined.
        A module is recorded with the empty string when it has no file, when its
        path contains "site-packages" or "dist-packages", or when the file cannot
        be read. Does nothing when the reloader is disabled.
        """
        if not self.enabled:
            return
        for new_modname in sys.modules.keys() - self.source_by_modname.keys():
            new_module = sys.modules[new_modname]
            # get_module_file_name() reports "no file" as "" (never None)
            fname = get_module_file_name(new_module)
            source = ""
            if (
                fname
                and "site-packages" not in fname
                and "dist-packages" not in fname
                and os.access(fname, os.R_OK)
            ):
                # open() is inside the try as well: os.access() also succeeds for
                # directories, and the file may vanish between the check and the read.
                try:
                    with open(fname, "r") as f:
                        source = f.read()
                except Exception:
                    source = ""
            self.source_by_modname[new_modname] = source
 | 
						|
 | 
						|
    # Single detector instance shared through the class; is_constexpr_assign()
    # invokes it as cls.constexpr_detector(node.value).
    constexpr_detector = ConstexprDetector()
 | 
						|
 | 
						|
    @staticmethod
    def is_enum_subclass(node: ast.Module | ast.ClassDef) -> bool:
        """
        Return whether ``node`` is a class that lists ``Enum`` or ``enum.Enum`` as a base.

        The check is purely syntactic: only those two spellings are recognised,
        and a module node is never an enum.
        """
        if isinstance(node, ast.Module):
            return False

        def spells_enum(base: ast.expr) -> bool:
            if isinstance(base, ast.Name):
                return base.id == "Enum"
            return (
                isinstance(base, ast.Attribute)
                and base.attr == "Enum"
                and isinstance(base.value, ast.Name)
                and base.value.id == "enum"
            )

        return any(spells_enum(base) for base in node.bases)
 | 
						|
 | 
						|
    @classmethod
    def is_constexpr_assign(
        cls, node: ast.AST, parent_node: ast.Module | ast.ClassDef
    ) -> bool:
        """
        Return whether ``node`` is an assignment that can be re-executed on its own.

        That requires an ``Assign``/``AnnAssign`` with a value, located outside
        an enum class body, whose targets are all plain names and whose value is
        accepted by ``constexpr_detector``.
        """
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign):
            targets = [node.target]
        else:
            return False
        if node.value is None:
            # bare annotation such as ``x: int``
            return False
        if cls.is_enum_subclass(parent_node):
            return False
        if not all(isinstance(target, ast.Name) for target in targets):
            return False
        return cls.constexpr_detector(node.value)
 | 
						|
 | 
						|
    @classmethod
    def _gather_children(
        cls, body: list[ast.stmt], parent_node: ast.Module | ast.ClassDef
    ) -> GatherResult:
        """
        Sort the statements of ``body`` into a GatherResult:

        1. function definitions, imports and constexpr assignments, each stored
           with the tuple of names it binds.
        2. class definitions, keyed by class name.
        3. everything else, collected as "unfixable".

        The bodies of ``if``, ``with``/``async with`` and ``try`` statements are
        gathered recursively into the same result, while their tests, context
        items and exception types are recorded as unfixable. Bare constants
        (e.g. docstrings) and ``pass`` are ignored.
        """
        gathered = GatherResult.create()

        def merge_nested(statements: list[ast.stmt]) -> None:
            gathered.inplace_merge(cls._gather_children(statements, parent_node))

        for statement in body:
            element: ast.expr | ast.stmt = statement
            # look through expression statements at the expression itself
            while isinstance(element, ast.Expr):
                element = element.value

            if isinstance(element, (ast.FunctionDef, ast.AsyncFunctionDef)):
                gathered.function_defs.append(((element.name,), element))
            elif isinstance(element, (ast.Import, ast.ImportFrom)):
                bound_names = tuple(
                    alias.asname or alias.name for alias in element.names
                )
                gathered.import_defs.append((bound_names, element))
            elif isinstance(element, ast.ClassDef):
                gathered.classes[element.name] = element
            elif isinstance(element, ast.If):
                gathered.unfixable.append(element.test)
                merge_nested(element.body)
                merge_nested(element.orelse)
            elif isinstance(element, (ast.AsyncWith, ast.With)):
                gathered.unfixable.extend(element.items)
                merge_nested(element.body)
            elif isinstance(element, ast.Try):
                for nested_body in (element.body, element.orelse, element.finalbody):
                    merge_nested(nested_body)
                for handler in element.handlers:
                    if handler.type is not None:
                        gathered.unfixable.append(handler.type)
                    merge_nested(handler.body)
            elif isinstance(element, (ast.Constant, ast.Pass)):
                continue
            elif cls.is_constexpr_assign(element, parent_node):
                assert isinstance(element, (ast.Assign, ast.AnnAssign))
                if isinstance(element, ast.Assign):
                    targets = element.targets
                else:
                    targets = [element.target]
                target_names = tuple(cast(ast.Name, target).id for target in targets)
                gathered.assign_defs.append((target_names, element))
            else:
                gathered.unfixable.append(element)
        return gathered
 | 
						|
 | 
						|
    def detect_autoreload(
        self,
        old_node: ast.Module | ast.ClassDef,
        new_node: ast.Module | ast.ClassDef,
        prefixes: list[str] | None = None,
    ) -> bool:
        """
        Compare the old and new version of a scope and record what has to be reloaded.

        Changed or added definitions, removed names and newly added classes are
        stored in the AutoreloadTree node addressed by ``prefixes``; classes
        present in both versions but different are compared recursively.

        Returns
        -------
        `True` if we can run our targeted autoreload algorithm safely.
        `False` if we should instead use IPython's original autoreload implementation.
        """
        if not self.enabled:
            return False
        prefixes = prefixes or []

        old_result = self._gather_children(old_node.body, old_node)
        new_result = self._gather_children(new_node.body, new_node)

        def index_by_name(result: GatherResult) -> dict[str, ast.AST]:
            # when a name is bound several times, the last definition wins
            return {
                name: definition
                for names, definition in result.all_defs()
                for name in names
            }

        old_defs_by_name = index_by_name(old_result)
        new_defs_by_name = index_by_name(new_result)

        # any difference in code we cannot patch rules out the targeted reload
        if not compare_ast(old_result.unfixable, new_result.unfixable):
            return False

        scope = self._to_autoreload.traverse_prefixes(prefixes)
        for names, new_def in new_result.all_defs():
            # consider only names for which this statement is the final binding
            needs_reload = any(
                new_defs_by_name[name] is new_def
                and (
                    name not in old_defs_by_name
                    or not compare_ast(new_def, old_defs_by_name[name])
                )
                for name in names
            )
            if needs_reload:
                scope.defs_to_reload.append((tuple(names), new_def))
        scope.defs_to_delete |= old_defs_by_name.keys() - new_defs_by_name.keys()

        for class_name, new_class in new_result.classes.items():
            old_class = old_result.classes.get(class_name)
            if old_class is None:
                scope.new_nested_classes[class_name] = new_class
                continue
            if compare_ast(new_class, old_class):
                continue
            if not self.detect_autoreload(old_class, new_class, prefixes + [class_name]):
                return False
        return True
 | 
						|
 | 
						|
    def _check_dependents(self) -> bool:
        """
        Schedule the functions decorated by a definition that is about to be reloaded.

        The dependents are collected from the dependency graph for the pending
        AutoreloadTree and then added to that tree. Always returns `True`.
        """
        dependents = self._check_dependents_inner()
        for dependent in dependents:
            self._add_node_to_autoreload_tree(dependent)
        return True
 | 
						|
 | 
						|
    def _add_node_to_autoreload_tree(self, node: DependencyNode) -> None:
        """
        Record a dependency-graph node as a definition to reload in its enclosing scope.

        Nodes with an empty qualified name are ignored. The scope's tree node is
        created if needed even when the node carries no AST.
        """
        if not node.qualified_name:
            return
        *scope_path, leaf_name = node.qualified_name
        scope = self._to_autoreload.traverse_prefixes(scope_path)
        if node.abstract_syntax_tree is None:
            return
        scope.defs_to_reload.append(((leaf_name,), node.abstract_syntax_tree))
 | 
						|
 | 
						|
    def _check_dependents_inner(
        self, prefixes: list[str] | None = None
    ) -> list[DependencyNode]:
        """
        Collect the dependents of every definition pending reload in the scope at ``prefixes``.

        Each pending definition is looked up in the dependency graph under the
        first name it binds, qualified by ``prefixes``. The scopes of newly
        added nested classes are then visited recursively.
        """
        prefixes = prefixes or []
        scope = self._to_autoreload.traverse_prefixes(prefixes)
        dependents: list[DependencyNode] = []
        for names, _definition in scope.defs_to_reload:
            first_name, *_others = names
            dependents.extend(self._gen_dependents((*prefixes, first_name)))
        for class_name in scope.new_nested_classes:
            dependents.extend(self._check_dependents_inner(prefixes + [class_name]))
        return dependents
 | 
						|
 | 
						|
    def _gen_dependents(self, qualname: tuple[str, ...]) -> list[DependencyNode]:
        """
        Return every node that depends, directly or transitively, on ``qualname``.

        The dependency graph is walked depth-first; the dependents of a node are
        listed before the node itself. A node reachable along several paths is
        listed once per path. An empty list is returned when ``qualname`` has no
        entry in the graph.

        A name that is already being expanded further up the current path is not
        expanded again, so cyclic graphs (for example a function decorated by an
        earlier binding of its own name) terminate instead of recursing forever.
        """
        collected: list[DependencyNode] = []
        on_current_path: set[tuple[str, ...]] = set()

        def expand(name: tuple[str, ...]) -> None:
            if name in on_current_path or name not in self.dependency_graph:
                return
            on_current_path.add(name)
            for dependent in self.dependency_graph[name]:
                expand(dependent.qualified_name)
                collected.append(dependent)
            on_current_path.discard(name)

        expand(qualname)
        return collected
 | 
						|
 | 
						|
    def _patch_namespace_inner(
        self, ns: ModuleType | type, prefixes: list[str] | None = None
    ) -> bool:
        """
        Apply the pending changes recorded for the scope at ``prefixes`` inside ``ns``.

        For a function definition whose name already exists in the scope, the new
        source is compiled and executed with ``ns``'s globals, and the existing
        function object (or the accessors of an existing property) is patched
        from the freshly created one. Any other pending definition is executed
        and its names are set on the scope. Afterwards names marked for deletion
        are removed, newly added classes are created, and child scopes are
        patched through ``_patch_namespace``.

        Returns `False` as soon as a child scope fails to patch, `True` otherwise.
        Raises ValueError when a definition gains or loses its property decoration.
        """
        prefixes = prefixes or []
        scope = self._to_autoreload.traverse_prefixes(prefixes)

        # walk from the module down to the (possibly nested) class being patched
        target_ns = ns
        for prefix in prefixes:
            target_ns = target_ns.__dict__[prefix]

        is_method = len(prefixes) > 0
        handled_names: set[str] = set()
        for names, new_def in scope.defs_to_reload:
            # a single-name definition is applied at most once per scope
            if len(names) == 1 and names[0] in handled_names:
                continue
            handled_names.update(names)
            exec_locals: dict[str, Any] = {}

            patch_in_place = (
                isinstance(new_def, (ast.FunctionDef, ast.AsyncFunctionDef))
                and names[0] in target_ns.__dict__
            )
            if not patch_in_place:
                # new function, import or assignment: run it and bind the results
                exec(
                    ast.unparse(new_def),
                    ns.__dict__ | target_ns.__dict__,
                    exec_locals,
                )
                for name in names:
                    setattr(target_ns, name, exec_locals[name])
                continue

            assert len(names) == 1
            name = names[0]
            existing = target_ns.__dict__[name]
            if isinstance(existing, (staticmethod, classmethod)):
                existing = existing.__func__

            # exec new source code using old function's (obj) globals environment.
            source = textwrap.dedent(ast.unparse(new_def))
            if is_method:
                # methods are defined inside a throwaway class
                source = "class __autoreload_class__:\n" + textwrap.indent(
                    source, "    "
                )
            exec_globals = ns.__dict__
            if not isinstance(exec_globals, dict):
                exec_globals = dict(exec_globals)

            # Compile with correct filename to preserve in traceback
            filename = "<string>"
            existing_code = getattr(existing, "__code__", None)
            if existing_code and existing.__code__.co_filename:
                filename = existing.__code__.co_filename

            decorated_module = ast.parse(source)
            modules_to_run = [decorated_module]
            first_stmt = cast(ast.FunctionDef, decorated_module.body[0])
            if first_stmt.decorator_list:
                # run an undecorated copy first, then the decorated original
                undecorated_module = pickle.loads(pickle.dumps(decorated_module))
                cast(
                    ast.FunctionDef, undecorated_module.body[0]
                ).decorator_list = []
                modules_to_run.insert(0, undecorated_module)

            for module_ast in modules_to_run:
                code = compile(module_ast, filename, mode="exec", dont_inherit=True)
                exec(code, exec_globals, exec_locals)  # type: ignore[arg-type]
                # exec_locals now holds the object built from the new source
                if is_method:
                    replacement = getattr(exec_locals["__autoreload_class__"], name)
                else:
                    replacement = exec_locals[name]
                if isinstance(replacement, (staticmethod, classmethod)):
                    replacement = replacement.__func__

                existing_is_property = isinstance(existing, property)
                replacement_is_property = isinstance(replacement, property)
                if existing_is_property != replacement_is_property:
                    raise ValueError(
                        "adding or removing property decorations not supported"
                    )
                if not existing_is_property:
                    self.patch_function(existing, replacement, is_method)
                    continue
                for accessor in ("fget", "fset", "fdel"):
                    old_accessor = getattr(existing, accessor)
                    new_accessor = getattr(replacement, accessor)
                    if old_accessor is None or new_accessor is None:
                        self.try_patch_attr(existing, replacement, accessor)
                    else:
                        self.patch_function(old_accessor, new_accessor, is_method)
        scope.defs_to_reload.clear()

        for name in scope.defs_to_delete:
            # give up on deleting the attribute, let the stale one dangle
            with contextlib.suppress(AttributeError, TypeError, ValueError):
                delattr(target_ns, name)
        scope.defs_to_delete.clear()

        for class_name, class_def in scope.new_nested_classes.items():
            class_locals: dict[str, Any] = {}
            exec(
                ast.unparse(class_def),
                ns.__dict__ | target_ns.__dict__,
                class_locals,
            )
            setattr(target_ns, class_name, class_locals[class_name])
        scope.new_nested_classes.clear()

        for class_name in scope.children:
            if not self._patch_namespace(ns, prefixes + [class_name]):
                return False
        scope.children.clear()
        return True
 | 
						|
 | 
						|
    def _patch_namespace(
        self, ns: ModuleType | type, prefixes: list[str] | None = None
    ) -> bool:
        """
        Patch the scope at ``prefixes`` inside ``ns`` via ``_patch_namespace_inner``.

        Any exception raised while patching is swallowed and reported as failure.
        Returns `true` if patching was successful, and `false` if unsuccessful.
        """
        try:
            patched = self._patch_namespace_inner(ns, prefixes=prefixes)
        except Exception:
            patched = False
        return patched
 | 
						|
 | 
						|
    def maybe_reload_module(self, module: ModuleType) -> bool:
        """
        Uses Deduperreload to try to update a module.

        The module's file is re-read and, when a previous non-empty source is
        cached for it, the old and new ASTs are compared and the changes are
        patched in place. On a completed attempt (successful or not) the cached
        source is replaced by the new one.

        Returns `true` on success and `false` on failure. `false` is also
        returned when the reloader is disabled, when the module has no name or
        file, or when the file cannot be read or parsed.
        """
        if not self.enabled:
            return False
        if not (modname := getattr(module, "__name__", None)):
            return False
        # get_module_file_name() signals "no file" with "" rather than None
        if not (fname := get_module_file_name(module)):
            return False
        try:
            with open(fname, "r") as f:
                new_source_code = f.read()
        except (OSError, ValueError):
            # unreadable or undecodable file: leave it to the regular autoreload
            return False

        patched_flag = False
        try:
            if old_source_code := self.source_by_modname.get(modname):
                # get old/new module ast
                try:
                    old_module_ast = ast.parse(old_source_code)
                    new_module_ast = ast.parse(new_source_code)
                except Exception:
                    return False
                # Rebuild the graph from the current source only; otherwise nodes
                # from earlier versions of this module and from other modules
                # pile up and stale function ASTs end up being patched in.
                self.dependency_graph = {}
                self._build_dependency_graph(new_module_ast)
                # detect if we are able to use our autoreload algorithm
                if (
                    self.detect_autoreload(old_module_ast, new_module_ast)
                    and self._check_dependents()
                    and self._patch_namespace(module)
                ):
                    patched_flag = True
            self.source_by_modname[modname] = new_source_code
        finally:
            # never carry pending work over to the next module, even on error
            self._to_autoreload = AutoreloadTree()
        return patched_flag
 | 
						|
 | 
						|
    def _separate_name(
        self,
        decorator: ast.Attribute | ast.Name | ast.Call | ast.expr,
        accept_calls: bool,
    ) -> list[str] | None:
        """
        Split a decorator expression into the parts of its dotted name.

        ``a.b.c`` yields ``["a", "b", "c"]``. With ``accept_calls`` set, a call
        such as ``a.b(...)`` is resolved through the called expression; calls
        anywhere else, and any expression that is not a plain dotted name, give
        ``None``.
        """
        node = decorator
        if isinstance(node, ast.Call):
            if not accept_calls:
                return None
            node = node.func

        # peel attribute accesses from the right: a.b.c -> c, b, then the root a
        trailing_attrs: list[str] = []
        while isinstance(node, ast.Attribute):
            trailing_attrs.append(node.attr)
            node = node.value

        if not isinstance(node, ast.Name):
            return None
        return [node.id, *reversed(trailing_attrs)]
 | 
						|
 | 
						|
    def _gather_dependents(
        self, body: list[ast.stmt], body_prefixes: list[str] | None = None
    ) -> bool:
        """
        Register every decorated function in ``body`` under each of its decorators.

        Functions are identified by ``body_prefixes`` plus their own name; class
        bodies are processed recursively with the class name appended. Only
        direct children of ``body`` are examined, and decorators that are not
        plain (optionally called) dotted names are skipped. Always returns `True`.
        """
        scope_prefixes = body_prefixes or []
        for statement in body:
            if isinstance(statement, ast.ClassDef):
                self._gather_dependents(
                    statement.body, scope_prefixes + [statement.name]
                )
            elif isinstance(statement, (ast.FunctionDef, ast.AsyncFunctionDef)):
                dependent = DependencyNode(
                    (*scope_prefixes, statement.name), statement
                )
                for decorator in statement.decorator_list:
                    if decorator_path := self._separate_name(decorator, True):
                        self.dependency_graph.setdefault(
                            tuple(decorator_path), []
                        ).append(dependent)
        return True
 | 
						|
 | 
						|
    def _build_dependency_graph(self, new_ast: ast.Module | ast.ClassDef) -> bool:
        """
        Populate the dependency graph from the body of ``new_ast``.

        Returns whatever ``_gather_dependents`` reports; at present that is
        always `true`, since a failure to build the graph is not treated as
        blocking.
        """
        top_level_statements = new_ast.body
        return self._gather_dependents(top_level_statements)
 |