You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			912 lines
		
	
	
		
			31 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			912 lines
		
	
	
		
			31 KiB
		
	
	
	
		
			Python
		
	
"""IPython extension to reload modules before executing user code.
 | 
						|
 | 
						|
``autoreload`` reloads modules automatically before entering the execution of
 | 
						|
code typed at the IPython prompt.
 | 
						|
 | 
						|
This makes for example the following workflow possible:
 | 
						|
 | 
						|
.. sourcecode:: ipython
 | 
						|
 | 
						|
   In [1]: %load_ext autoreload
 | 
						|
 | 
						|
   In [2]: %autoreload 2
 | 
						|
 | 
						|
   In [3]: from foo import some_function
 | 
						|
 | 
						|
   In [4]: some_function()
 | 
						|
   Out[4]: 42
 | 
						|
 | 
						|
   In [5]: # open foo.py in an editor and change some_function to return 43
 | 
						|
 | 
						|
   In [6]: some_function()
 | 
						|
   Out[6]: 43
 | 
						|
 | 
						|
The module was reloaded without reloading it explicitly, and the object
 | 
						|
imported with ``from foo import ...`` was also updated.
 | 
						|
 | 
						|
Usage
 | 
						|
=====
 | 
						|
 | 
						|
The following magic commands are provided:
 | 
						|
 | 
						|
``%autoreload``, ``%autoreload now``
 | 
						|
 | 
						|
    Reload all modules (except those excluded by ``%aimport``)
 | 
						|
    automatically now.
 | 
						|
 | 
						|
``%autoreload 0``, ``%autoreload off``
 | 
						|
 | 
						|
    Disable automatic reloading.
 | 
						|
 | 
						|
``%autoreload 1``, ``%autoreload explicit``
 | 
						|
 | 
						|
    Reload all modules imported with ``%aimport`` every time before
 | 
						|
    executing the Python code typed.
 | 
						|
 | 
						|
``%autoreload 2``, ``%autoreload all``
 | 
						|
 | 
						|
    Reload all modules (except those excluded by ``%aimport``) every
 | 
						|
    time before executing the Python code typed.
 | 
						|
 | 
						|
``%autoreload 3``, ``%autoreload complete``
 | 
						|
 | 
						|
    Same as 2/all, but also adds any new objects in the module. See
 | 
						|
    unit test at IPython/extensions/tests/test_autoreload.py::test_autoload_newly_added_objects
 | 
						|
 | 
						|
  Adding ``--print`` or ``-p`` to the ``%autoreload`` line will print autoreload activity to
 | 
						|
  standard out. ``--log`` or ``-l`` will do it to the log at INFO level; both can be used
 | 
						|
  simultaneously.
 | 
						|
 | 
						|
``%aimport``
 | 
						|
 | 
						|
    List modules which are to be automatically imported or not to be imported.
 | 
						|
 | 
						|
``%aimport foo``
 | 
						|
 | 
						|
    Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
 | 
						|
 | 
						|
``%aimport foo, bar``
 | 
						|
 | 
						|
    Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
 | 
						|
 | 
						|
``%aimport -foo``
 | 
						|
 | 
						|
    Mark module 'foo' to not be autoreloaded.
 | 
						|
 | 
						|
Import Conflict Resolution
 | 
						|
==========================
 | 
						|
 | 
						|
In ``%autoreload 3`` mode, the extension tracks ``from X import Y`` style imports
 | 
						|
and intelligently resolves conflicts when the same name is imported multiple ways.
 | 
						|
 | 
						|
Import tracking occurs after successful code execution, ensuring that only valid
 | 
						|
imports are tracked. This approach handles edge cases such as:
 | 
						|
 | 
						|
- Importing a name that doesn't initially exist in a module, then adding that name
 | 
						|
  to the module and importing it again
 | 
						|
- Conflicts between aliased imports (``from X import Y as Z``) and direct imports
 | 
						|
  (``from X import Z``)
 | 
						|
 | 
						|
When conflicts occur:
 | 
						|
 | 
						|
- If you first do ``from X import Y as Z`` then later ``from X import Z``,
 | 
						|
  the extension will switch to reloading ``Z`` instead of ``Y`` under the name ``Z``.
 | 
						|
 | 
						|
- Similarly, if you first do ``from X import Z`` then later ``from X import Y as Z``,
 | 
						|
  the extension will switch to reloading ``Y`` as ``Z`` instead of the original ``Z``.
 | 
						|
 | 
						|
- The most recent successful import always takes precedence in conflict resolution.
 | 
						|
 | 
						|
Caveats
 | 
						|
=======
 | 
						|
 | 
						|
Reloading Python modules in a reliable way is in general difficult,
 | 
						|
and unexpected things may occur. ``%autoreload`` tries to work around
 | 
						|
common pitfalls by replacing function code objects and parts of
 | 
						|
classes previously in the module with new versions. This makes the
 | 
						|
following things to work:
 | 
						|
 | 
						|
- Functions and classes imported via 'from xxx import foo' are upgraded
 | 
						|
  to new versions when 'xxx' is reloaded.
 | 
						|
 | 
						|
- Methods and properties of classes are upgraded on reload, so that
 | 
						|
  calling 'c.foo()' on an object 'c' created before the reload causes
 | 
						|
  the new code for 'foo' to be executed.
 | 
						|
 | 
						|
Some of the known remaining caveats are:
 | 
						|
 | 
						|
- Replacing code objects does not always succeed: changing a @property
 | 
						|
  in a class to an ordinary method or a method to a member variable
 | 
						|
  can cause problems (but in old objects only).
 | 
						|
 | 
						|
- Functions that are removed (eg. via monkey-patching) from a module
 | 
						|
  before it is reloaded are not upgraded.
 | 
						|
 | 
						|
- C extension modules cannot be reloaded, and so cannot be autoreloaded.
 | 
						|
 | 
						|
- While comparing Enum and Flag, the 'is' Identity Operator is used (even in the case '==' has been used (Similar to the 'None' keyword)).
 | 
						|
 | 
						|
- Reloading a module, or importing the same module by a different name, creates new Enums. These may look the same, but are not.
 | 
						|
"""
 | 
						|
 | 
						|
from IPython.core import magic_arguments
 | 
						|
from IPython.core.magic import Magics, magics_class, line_magic
 | 
						|
from IPython.extensions.deduperreload.deduperreload import DeduperReloader
 | 
						|
 | 
						|
__skip_doctest__ = True
 | 
						|
 | 
						|
# -----------------------------------------------------------------------------
 | 
						|
#  Copyright (C) 2000 Thomas Heller
 | 
						|
#  Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
 | 
						|
#  Copyright (C) 2012  The IPython Development Team
 | 
						|
#
 | 
						|
#  Distributed under the terms of the BSD License.  The full license is in
 | 
						|
#  the file COPYING, distributed as part of this software.
 | 
						|
# -----------------------------------------------------------------------------
 | 
						|
#
 | 
						|
# This IPython module is written by Pauli Virtanen, based on the autoreload
 | 
						|
# code by Thomas Heller.
 | 
						|
 | 
						|
# -----------------------------------------------------------------------------
 | 
						|
# Imports
 | 
						|
# -----------------------------------------------------------------------------
 | 
						|
 | 
						|
import ast
 | 
						|
import os
 | 
						|
import sys
 | 
						|
import traceback
 | 
						|
import types
 | 
						|
import weakref
 | 
						|
import gc
 | 
						|
import logging
 | 
						|
from importlib import import_module, reload
 | 
						|
from importlib.util import source_from_cache
 | 
						|
 | 
						|
# ------------------------------------------------------------------------------
 | 
						|
# Autoreload functionality
 | 
						|
# ------------------------------------------------------------------------------
 | 
						|
 | 
						|
 | 
						|
class ModuleReloader:
 | 
						|
    enabled = False
 | 
						|
    """Whether this reloader is enabled"""
 | 
						|
 | 
						|
    check_all = True
 | 
						|
    """Autoreload all modules, not just those listed in 'modules'"""
 | 
						|
 | 
						|
    autoload_obj = False
 | 
						|
    """Autoreload all modules AND autoload all new objects"""
 | 
						|
 | 
						|
    def __init__(self, shell=None):
 | 
						|
        # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
 | 
						|
        self.failed = {}
 | 
						|
        # Modules specially marked as autoreloadable.
 | 
						|
        self.modules = {}
 | 
						|
        # Modules specially marked as not autoreloadable.
 | 
						|
        self.skip_modules = {}
 | 
						|
        # (module-name, name) -> weakref, for replacing old code objects
 | 
						|
        self.old_objects = {}
 | 
						|
        # Module modification timestamps
 | 
						|
        self.modules_mtimes = {}
 | 
						|
        self.shell = shell
 | 
						|
 | 
						|
        # Reporting callable for verbosity
 | 
						|
        self._report = lambda msg: None  # by default, be quiet.
 | 
						|
 | 
						|
        # Deduper reloader
 | 
						|
        self.deduper_reloader = DeduperReloader()
 | 
						|
 | 
						|
        # Persistent import tracker for from-imports
 | 
						|
        self.import_from_tracker = ImportFromTracker({}, {})
 | 
						|
 | 
						|
        # Cache module modification times
 | 
						|
        self.check(check_all=True, do_reload=False)
 | 
						|
 | 
						|
        # To hide autoreload errors
 | 
						|
        self.hide_errors = False
 | 
						|
 | 
						|
    def mark_module_skipped(self, module_name):
 | 
						|
        """Skip reloading the named module in the future"""
 | 
						|
        try:
 | 
						|
            del self.modules[module_name]
 | 
						|
        except KeyError:
 | 
						|
            pass
 | 
						|
        self.skip_modules[module_name] = True
 | 
						|
 | 
						|
    def mark_module_reloadable(self, module_name):
 | 
						|
        """Reload the named module in the future (if it is imported)"""
 | 
						|
        try:
 | 
						|
            del self.skip_modules[module_name]
 | 
						|
        except KeyError:
 | 
						|
            pass
 | 
						|
        self.modules[module_name] = True
 | 
						|
 | 
						|
    def clear_import_tracker(self):
 | 
						|
        """Clear the persistent import tracker state"""
 | 
						|
        self.import_from_tracker = ImportFromTracker({}, {})
 | 
						|
 | 
						|
    def aimport_module(self, module_name):
 | 
						|
        """Import a module, and mark it reloadable
 | 
						|
 | 
						|
        Returns
 | 
						|
        -------
 | 
						|
        top_module : module
 | 
						|
            The imported module if it is top-level, or the top-level
 | 
						|
        top_name : module
 | 
						|
            Name of top_module
 | 
						|
 | 
						|
        """
 | 
						|
        self.mark_module_reloadable(module_name)
 | 
						|
 | 
						|
        import_module(module_name)
 | 
						|
        top_name = module_name.split(".")[0]
 | 
						|
        top_module = sys.modules[top_name]
 | 
						|
        return top_module, top_name
 | 
						|
 | 
						|
    def filename_and_mtime(self, module):
 | 
						|
        if not hasattr(module, "__file__") or module.__file__ is None:
 | 
						|
            return None, None
 | 
						|
 | 
						|
        if getattr(module, "__name__", None) in [None, "__mp_main__", "__main__"]:
 | 
						|
            # we cannot reload(__main__) or reload(__mp_main__)
 | 
						|
            return None, None
 | 
						|
 | 
						|
        filename = module.__file__
 | 
						|
        path, ext = os.path.splitext(filename)
 | 
						|
 | 
						|
        if ext.lower() == ".py":
 | 
						|
            py_filename = filename
 | 
						|
        else:
 | 
						|
            try:
 | 
						|
                py_filename = source_from_cache(filename)
 | 
						|
            except ValueError:
 | 
						|
                return None, None
 | 
						|
 | 
						|
        try:
 | 
						|
            pymtime = os.stat(py_filename).st_mtime
 | 
						|
        except OSError:
 | 
						|
            return None, None
 | 
						|
 | 
						|
        return py_filename, pymtime
 | 
						|
 | 
						|
    def check(self, check_all=False, do_reload=True, execution_info=None):
 | 
						|
        """Check whether some modules need to be reloaded."""
 | 
						|
 | 
						|
        if not self.enabled and not check_all:
 | 
						|
            return
 | 
						|
 | 
						|
        if check_all or self.check_all:
 | 
						|
            modules = list(sys.modules.keys())
 | 
						|
        else:
 | 
						|
            modules = list(self.modules.keys())
 | 
						|
 | 
						|
        # Use the persistent import_from_tracker
 | 
						|
        import_from_tracker = (
 | 
						|
            self.import_from_tracker if self.import_from_tracker.imports_froms else None
 | 
						|
        )
 | 
						|
        for modname in modules:
 | 
						|
            m = sys.modules.get(modname, None)
 | 
						|
 | 
						|
            if modname in self.skip_modules:
 | 
						|
                continue
 | 
						|
 | 
						|
            py_filename, pymtime = self.filename_and_mtime(m)
 | 
						|
            if py_filename is None:
 | 
						|
                continue
 | 
						|
 | 
						|
            try:
 | 
						|
                if pymtime <= self.modules_mtimes[modname]:
 | 
						|
                    continue
 | 
						|
            except KeyError:
 | 
						|
                self.modules_mtimes[modname] = pymtime
 | 
						|
                continue
 | 
						|
            else:
 | 
						|
                if self.failed.get(py_filename, None) == pymtime:
 | 
						|
                    continue
 | 
						|
 | 
						|
            self.modules_mtimes[modname] = pymtime
 | 
						|
 | 
						|
            # If we've reached this point, we should try to reload the module
 | 
						|
            if do_reload:
 | 
						|
                self._report(f"Reloading '{modname}'.")
 | 
						|
                try:
 | 
						|
                    if self.autoload_obj:
 | 
						|
                        superreload(
 | 
						|
                            m,
 | 
						|
                            reload,
 | 
						|
                            self.old_objects,
 | 
						|
                            self.shell,
 | 
						|
                            import_from_tracker=import_from_tracker,
 | 
						|
                        )
 | 
						|
                    # if not using autoload, check if deduperreload is viable for this module
 | 
						|
                    elif self.deduper_reloader.maybe_reload_module(m):
 | 
						|
                        pass
 | 
						|
                    else:
 | 
						|
                        superreload(m, reload, self.old_objects)
 | 
						|
                    if py_filename in self.failed:
 | 
						|
                        del self.failed[py_filename]
 | 
						|
                except:
 | 
						|
                    if not self.hide_errors:
 | 
						|
                        print(
 | 
						|
                            "[autoreload of {} failed: {}]".format(
 | 
						|
                                modname, traceback.format_exc(10)
 | 
						|
                            ),
 | 
						|
                            file=sys.stderr,
 | 
						|
                        )
 | 
						|
                    self.failed[py_filename] = pymtime
 | 
						|
        self.deduper_reloader.update_sources()
 | 
						|
 | 
						|
 | 
						|
# ------------------------------------------------------------------------------
 | 
						|
# superreload
 | 
						|
# ------------------------------------------------------------------------------
 | 
						|
 | 
						|
 | 
						|
func_attrs = [
 | 
						|
    "__code__",
 | 
						|
    "__defaults__",
 | 
						|
    "__doc__",
 | 
						|
    "__closure__",
 | 
						|
    "__globals__",
 | 
						|
    "__dict__",
 | 
						|
]
 | 
						|
 | 
						|
 | 
						|
def update_function(old, new):
 | 
						|
    """Upgrade the code object of a function"""
 | 
						|
    for name in func_attrs:
 | 
						|
        try:
 | 
						|
            setattr(old, name, getattr(new, name))
 | 
						|
        except (AttributeError, TypeError):
 | 
						|
            pass
 | 
						|
 | 
						|
 | 
						|
def update_instances(old, new):
 | 
						|
    """Use garbage collector to find all instances that refer to the old
 | 
						|
    class definition and update their __class__ to point to the new class
 | 
						|
    definition"""
 | 
						|
 | 
						|
    refs = gc.get_referrers(old)
 | 
						|
 | 
						|
    for ref in refs:
 | 
						|
        if type(ref) is old:
 | 
						|
            object.__setattr__(ref, "__class__", new)
 | 
						|
 | 
						|
 | 
						|
def update_class(old, new):
 | 
						|
    """Replace stuff in the __dict__ of a class, and upgrade
 | 
						|
    method code objects, and add new methods, if any"""
 | 
						|
    for key in list(old.__dict__.keys()):
 | 
						|
        old_obj = getattr(old, key)
 | 
						|
        try:
 | 
						|
            new_obj = getattr(new, key)
 | 
						|
            # explicitly checking that comparison returns True to handle
 | 
						|
            # cases where `==` doesn't return a boolean.
 | 
						|
            if (old_obj == new_obj) is True:
 | 
						|
                continue
 | 
						|
        except AttributeError:
 | 
						|
            # obsolete attribute: remove it
 | 
						|
            try:
 | 
						|
                delattr(old, key)
 | 
						|
            except (AttributeError, TypeError):
 | 
						|
                pass
 | 
						|
            continue
 | 
						|
        except ValueError:
 | 
						|
            # can't compare nested structures containing
 | 
						|
            # numpy arrays using `==`
 | 
						|
            pass
 | 
						|
 | 
						|
        if update_generic(old_obj, new_obj):
 | 
						|
            continue
 | 
						|
 | 
						|
        try:
 | 
						|
            setattr(old, key, getattr(new, key))
 | 
						|
        except (AttributeError, TypeError):
 | 
						|
            pass  # skip non-writable attributes
 | 
						|
 | 
						|
    for key in list(new.__dict__.keys()):
 | 
						|
        if key not in list(old.__dict__.keys()):
 | 
						|
            try:
 | 
						|
                setattr(old, key, getattr(new, key))
 | 
						|
            except (AttributeError, TypeError):
 | 
						|
                pass  # skip non-writable attributes
 | 
						|
 | 
						|
    # update all instances of class
 | 
						|
    update_instances(old, new)
 | 
						|
 | 
						|
 | 
						|
def update_property(old, new):
 | 
						|
    """Replace get/set/del functions of a property"""
 | 
						|
    update_generic(old.fdel, new.fdel)
 | 
						|
    update_generic(old.fget, new.fget)
 | 
						|
    update_generic(old.fset, new.fset)
 | 
						|
 | 
						|
 | 
						|
def isinstance2(a, b, typ):
 | 
						|
    return isinstance(a, typ) and isinstance(b, typ)
 | 
						|
 | 
						|
 | 
						|
UPDATE_RULES = [
 | 
						|
    (lambda a, b: isinstance2(a, b, type), update_class),
 | 
						|
    (lambda a, b: isinstance2(a, b, types.FunctionType), update_function),
 | 
						|
    (lambda a, b: isinstance2(a, b, property), update_property),
 | 
						|
]
 | 
						|
UPDATE_RULES.extend(
 | 
						|
    [
 | 
						|
        (
 | 
						|
            lambda a, b: isinstance2(a, b, types.MethodType),
 | 
						|
            lambda a, b: update_function(a.__func__, b.__func__),
 | 
						|
        ),
 | 
						|
    ]
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
def update_generic(a, b):
 | 
						|
    for type_check, update in UPDATE_RULES:
 | 
						|
        if type_check(a, b):
 | 
						|
            update(a, b)
 | 
						|
            return True
 | 
						|
    return False
 | 
						|
 | 
						|
 | 
						|
class StrongRef:
 | 
						|
    def __init__(self, obj):
 | 
						|
        self.obj = obj
 | 
						|
 | 
						|
    def __call__(self):
 | 
						|
        return self.obj
 | 
						|
 | 
						|
 | 
						|
mod_attrs = [
 | 
						|
    "__name__",
 | 
						|
    "__doc__",
 | 
						|
    "__package__",
 | 
						|
    "__loader__",
 | 
						|
    "__spec__",
 | 
						|
    "__file__",
 | 
						|
    "__cached__",
 | 
						|
    "__builtins__",
 | 
						|
]
 | 
						|
 | 
						|
 | 
						|
class ImportFromTracker:
 | 
						|
    def __init__(self, imports_froms: dict, symbol_map: dict):
 | 
						|
        self.imports_froms = imports_froms
 | 
						|
        # symbol_map maps original_name -> list of resolved_names
 | 
						|
        self.symbol_map = {}
 | 
						|
        if symbol_map:
 | 
						|
            for module_name, mappings in symbol_map.items():
 | 
						|
                self.symbol_map[module_name] = {}
 | 
						|
                for original_name, resolved_names in mappings.items():
 | 
						|
                    if isinstance(resolved_names, list):
 | 
						|
                        self.symbol_map[module_name][original_name] = resolved_names[:]
 | 
						|
                    else:
 | 
						|
                        self.symbol_map[module_name][original_name] = [resolved_names]
 | 
						|
        else:
 | 
						|
            self.symbol_map = symbol_map or {}
 | 
						|
 | 
						|
    def add_import(
 | 
						|
        self, module_name: str, original_name: str, resolved_name: str
 | 
						|
    ) -> None:
 | 
						|
        """Add an import, handling conflicts with existing imports.
 | 
						|
 | 
						|
        This method is called after successful code execution, so we know the import is valid.
 | 
						|
        """
 | 
						|
        if module_name not in self.imports_froms:
 | 
						|
            self.imports_froms[module_name] = []
 | 
						|
        if module_name not in self.symbol_map:
 | 
						|
            self.symbol_map[module_name] = {}
 | 
						|
 | 
						|
        # Check if there's already a different mapping for the same resolved_name from a different original_name
 | 
						|
        # We need to remove any conflicting mappings
 | 
						|
        for orig_name, res_names in list(self.symbol_map[module_name].items()):
 | 
						|
            if resolved_name in res_names and orig_name != original_name:
 | 
						|
                # Remove the conflicting resolved_name from the other original_name's list
 | 
						|
                res_names.remove(resolved_name)
 | 
						|
                if (
 | 
						|
                    not res_names
 | 
						|
                ):  # If the list is now empty, remove the original_name entirely
 | 
						|
                    if orig_name in self.imports_froms[module_name]:
 | 
						|
                        self.imports_froms[module_name].remove(orig_name)
 | 
						|
                    del self.symbol_map[module_name][orig_name]
 | 
						|
 | 
						|
        # Add the new mapping
 | 
						|
        if original_name not in self.imports_froms[module_name]:
 | 
						|
            self.imports_froms[module_name].append(original_name)
 | 
						|
 | 
						|
        if original_name not in self.symbol_map[module_name]:
 | 
						|
            self.symbol_map[module_name][original_name] = []
 | 
						|
 | 
						|
        # Add the resolved_name if it's not already in the list
 | 
						|
        if resolved_name not in self.symbol_map[module_name][original_name]:
 | 
						|
            self.symbol_map[module_name][original_name].append(resolved_name)
 | 
						|
 | 
						|
 | 
						|
def append_obj(module, d, name, obj, autoload=False):
 | 
						|
    in_module = hasattr(obj, "__module__") and obj.__module__ == module.__name__
 | 
						|
    if autoload:
 | 
						|
        # check needed for module global built-ins
 | 
						|
        if not in_module and name in mod_attrs:
 | 
						|
            return False
 | 
						|
    else:
 | 
						|
        if not in_module:
 | 
						|
            return False
 | 
						|
 | 
						|
    key = (module.__name__, name)
 | 
						|
    try:
 | 
						|
        d.setdefault(key, []).append(weakref.ref(obj))
 | 
						|
    except TypeError:
 | 
						|
        pass
 | 
						|
    return True
 | 
						|
 | 
						|
 | 
						|
def superreload(
 | 
						|
    module, reload=reload, old_objects=None, shell=None, import_from_tracker=None
 | 
						|
):
 | 
						|
    """Enhanced version of the builtin reload function.
 | 
						|
 | 
						|
    superreload remembers objects previously in the module, and
 | 
						|
 | 
						|
    - upgrades the class dictionary of every old class in the module
 | 
						|
    - upgrades the code object of every old function and method
 | 
						|
    - clears the module's namespace before reloading
 | 
						|
 | 
						|
    """
 | 
						|
    if old_objects is None:
 | 
						|
        old_objects = {}
 | 
						|
 | 
						|
    # collect old objects in the module
 | 
						|
    for name, obj in list(module.__dict__.items()):
 | 
						|
        if not append_obj(module, old_objects, name, obj):
 | 
						|
            continue
 | 
						|
        key = (module.__name__, name)
 | 
						|
        try:
 | 
						|
            old_objects.setdefault(key, []).append(weakref.ref(obj))
 | 
						|
        except TypeError:
 | 
						|
            pass
 | 
						|
 | 
						|
    # reload module
 | 
						|
    try:
 | 
						|
        # clear namespace first from old cruft
 | 
						|
        old_dict = module.__dict__.copy()
 | 
						|
        old_name = module.__name__
 | 
						|
        module.__dict__.clear()
 | 
						|
        module.__dict__["__name__"] = old_name
 | 
						|
        module.__dict__["__loader__"] = old_dict["__loader__"]
 | 
						|
    except (TypeError, AttributeError, KeyError):
 | 
						|
        pass
 | 
						|
 | 
						|
    try:
 | 
						|
        module = reload(module)
 | 
						|
    except:
 | 
						|
        # restore module dictionary on failed reload
 | 
						|
        module.__dict__.update(old_dict)
 | 
						|
        raise
 | 
						|
 | 
						|
    for name, new_obj in list(module.__dict__.items()):
 | 
						|
        key = (module.__name__, name)
 | 
						|
        if key not in old_objects:
 | 
						|
            # here 'shell' acts both as a flag and as an output var
 | 
						|
            imports_froms = (
 | 
						|
                import_from_tracker.imports_froms if import_from_tracker else None
 | 
						|
            )
 | 
						|
            symbol_map = import_from_tracker.symbol_map if import_from_tracker else None
 | 
						|
            if (
 | 
						|
                shell is None
 | 
						|
                or name == "Enum"
 | 
						|
                or not append_obj(module, old_objects, name, new_obj, True)
 | 
						|
                or (
 | 
						|
                    imports_froms
 | 
						|
                    and module.__name__ in imports_froms
 | 
						|
                    and "*" not in imports_froms[module.__name__]
 | 
						|
                    and name not in imports_froms[module.__name__]
 | 
						|
                )
 | 
						|
            ):
 | 
						|
                continue
 | 
						|
 | 
						|
            # Handle symbol mapping - now supporting multiple resolved names per original name
 | 
						|
            if symbol_map and name in symbol_map.get(module.__name__, {}):
 | 
						|
                resolved_names = symbol_map.get(module.__name__, {})[name]
 | 
						|
                for resolved_name in resolved_names:
 | 
						|
                    shell.user_ns[resolved_name] = new_obj
 | 
						|
            else:
 | 
						|
                shell.user_ns[name] = new_obj
 | 
						|
 | 
						|
        new_refs = []
 | 
						|
        for old_ref in old_objects[key]:
 | 
						|
            old_obj = old_ref()
 | 
						|
            if old_obj is None:
 | 
						|
                continue
 | 
						|
            new_refs.append(old_ref)
 | 
						|
            update_generic(old_obj, new_obj)
 | 
						|
 | 
						|
        if new_refs:
 | 
						|
            old_objects[key] = new_refs
 | 
						|
        else:
 | 
						|
            del old_objects[key]
 | 
						|
 | 
						|
    return module
 | 
						|
 | 
						|
 | 
						|
# ------------------------------------------------------------------------------
 | 
						|
# IPython connectivity
 | 
						|
# ------------------------------------------------------------------------------
 | 
						|
 | 
						|
 | 
						|
@magics_class
 | 
						|
class AutoreloadMagics(Magics):
 | 
						|
    def __init__(self, *a, **kw):
 | 
						|
        super().__init__(*a, **kw)
 | 
						|
        self._reloader = ModuleReloader(self.shell)
 | 
						|
        self._reloader.check_all = False
 | 
						|
        self._reloader.autoload_obj = False
 | 
						|
        self.loaded_modules = set(sys.modules)
 | 
						|
 | 
						|
    @line_magic
 | 
						|
    @magic_arguments.magic_arguments()
 | 
						|
    @magic_arguments.argument(
 | 
						|
        "mode",
 | 
						|
        type=str,
 | 
						|
        default="now",
 | 
						|
        nargs="?",
 | 
						|
        help="""blank or 'now' - Reload all modules (except those excluded by %%aimport)
 | 
						|
             automatically now.
 | 
						|
 | 
						|
             '0' or 'off' - Disable automatic reloading.
 | 
						|
 | 
						|
             '1' or 'explicit' - Reload only modules imported with %%aimport every
 | 
						|
             time before executing the Python code typed.
 | 
						|
 | 
						|
             '2' or 'all' - Reload all modules (except those excluded by %%aimport)
 | 
						|
             every time before executing the Python code typed.
 | 
						|
 | 
						|
             '3' or 'complete' - Same as 2/all, but also adds any new
 | 
						|
             objects in the module.
 | 
						|
             
 | 
						|
             By default, a newer autoreload algorithm that diffs the module's source code
 | 
						|
             with the previous version and only reloads changed parts is applied for modes
 | 
						|
             2 and below. To use the original algorithm, add the `-` suffix to the mode,
 | 
						|
             e.g. '%autoreload 2-', or pass in --full.
 | 
						|
             """,
 | 
						|
    )
 | 
						|
    @magic_arguments.argument(
 | 
						|
        "-p",
 | 
						|
        "--print",
 | 
						|
        action="store_true",
 | 
						|
        default=False,
 | 
						|
        help="Show autoreload activity using `print` statements",
 | 
						|
    )
 | 
						|
    @magic_arguments.argument(
 | 
						|
        "-l",
 | 
						|
        "--log",
 | 
						|
        action="store_true",
 | 
						|
        default=False,
 | 
						|
        help="Show autoreload activity using the logger",
 | 
						|
    )
 | 
						|
    @magic_arguments.argument(
 | 
						|
        "--hide-errors",
 | 
						|
        action="store_true",
 | 
						|
        default=False,
 | 
						|
        help="Hide autoreload errors",
 | 
						|
    )
 | 
						|
    @magic_arguments.argument(
 | 
						|
        "--full",
 | 
						|
        action="store_true",
 | 
						|
        default=False,
 | 
						|
        help="Don't ever use new diffing algorithm",
 | 
						|
    )
 | 
						|
    def autoreload(self, line=""):
 | 
						|
        r"""%autoreload => Reload modules automatically
 | 
						|
 | 
						|
        %autoreload or %autoreload now
 | 
						|
        Reload all modules (except those excluded by %aimport) automatically
 | 
						|
        now.
 | 
						|
 | 
						|
        %autoreload 0 or %autoreload off
 | 
						|
        Disable automatic reloading.
 | 
						|
 | 
						|
        %autoreload 1 or %autoreload explicit
 | 
						|
        Reload only modules imported with %aimport every time before executing
 | 
						|
        the Python code typed.
 | 
						|
 | 
						|
        %autoreload 2 or %autoreload all
 | 
						|
        Reload all modules (except those excluded by %aimport) every time
 | 
						|
        before executing the Python code typed.
 | 
						|
 | 
						|
        %autoreload 3 or %autoreload complete
 | 
						|
        Same as 2/all, but also but also adds any new objects in the module. See
 | 
						|
        unit test at IPython/extensions/tests/test_autoreload.py::test_autoload_newly_added_objects
 | 
						|
 | 
						|
        The optional arguments --print and --log control display of autoreload activity. The default
 | 
						|
        is to act silently; --print (or -p) will print out the names of modules that are being
 | 
						|
        reloaded, and --log (or -l) outputs them to the log at INFO level.
 | 
						|
 | 
						|
        The optional argument --hide-errors hides any errors that can happen when trying to
 | 
						|
        reload code.
 | 
						|
 | 
						|
        Reloading Python modules in a reliable way is in general
 | 
						|
        difficult, and unexpected things may occur. %autoreload tries to
 | 
						|
        work around common pitfalls by replacing function code objects and
 | 
						|
        parts of classes previously in the module with new versions. This
 | 
						|
        makes the following things to work:
 | 
						|
 | 
						|
        - Functions and classes imported via 'from xxx import foo' are upgraded
 | 
						|
          to new versions when 'xxx' is reloaded.
 | 
						|
 | 
						|
        - Methods and properties of classes are upgraded on reload, so that
 | 
						|
          calling 'c.foo()' on an object 'c' created before the reload causes
 | 
						|
          the new code for 'foo' to be executed.
 | 
						|
 | 
						|
        Some of the known remaining caveats are:
 | 
						|
 | 
						|
        - Replacing code objects does not always succeed: changing a @property
 | 
						|
          in a class to an ordinary method or a method to a member variable
 | 
						|
          can cause problems (but in old objects only).
 | 
						|
 | 
						|
        - Functions that are removed (eg. via monkey-patching) from a module
 | 
						|
          before it is reloaded are not upgraded.
 | 
						|
 | 
						|
        - C extension modules cannot be reloaded, and so cannot be
 | 
						|
          autoreloaded.
 | 
						|
 | 
						|
        """
 | 
						|
        args = magic_arguments.parse_argstring(self.autoreload, line)
 | 
						|
        mode = args.mode.lower()
 | 
						|
 | 
						|
        enable_deduperreload = not args.full
 | 
						|
        if mode.endswith("-"):
 | 
						|
            enable_deduperreload = False
 | 
						|
            mode = mode[:-1]
 | 
						|
        self._reloader.deduper_reloader.enabled = enable_deduperreload
 | 
						|
 | 
						|
        p = print
 | 
						|
 | 
						|
        logger = logging.getLogger("autoreload")
 | 
						|
 | 
						|
        l = logger.info
 | 
						|
 | 
						|
        def pl(msg):
 | 
						|
            p(msg)
 | 
						|
            l(msg)
 | 
						|
 | 
						|
        if args.print is False and args.log is False:
 | 
						|
            self._reloader._report = lambda msg: None
 | 
						|
        elif args.print is True:
 | 
						|
            if args.log is True:
 | 
						|
                self._reloader._report = pl
 | 
						|
            else:
 | 
						|
                self._reloader._report = p
 | 
						|
        elif args.log is True:
 | 
						|
            self._reloader._report = l
 | 
						|
 | 
						|
        self._reloader.hide_errors = args.hide_errors
 | 
						|
 | 
						|
        if mode == "" or mode == "now":
 | 
						|
            self._reloader.check(True)
 | 
						|
        elif mode == "0" or mode == "off":
 | 
						|
            self._reloader.enabled = False
 | 
						|
        elif mode == "1" or mode == "explicit":
 | 
						|
            self._reloader.enabled = True
 | 
						|
            self._reloader.check_all = False
 | 
						|
            self._reloader.autoload_obj = False
 | 
						|
        elif mode == "2" or mode == "all":
 | 
						|
            self._reloader.enabled = True
 | 
						|
            self._reloader.check_all = True
 | 
						|
            self._reloader.autoload_obj = False
 | 
						|
        elif mode == "3" or mode == "complete":
 | 
						|
            self._reloader.enabled = True
 | 
						|
            self._reloader.check_all = True
 | 
						|
            self._reloader.autoload_obj = True
 | 
						|
        else:
 | 
						|
            raise ValueError(f'Unrecognized autoreload mode "{mode}".')
 | 
						|
 | 
						|
    @line_magic
 | 
						|
    def aimport(self, parameter_s="", stream=None):
 | 
						|
        """%aimport => Import modules for automatic reloading.
 | 
						|
 | 
						|
        %aimport
 | 
						|
        List modules to automatically import and not to import.
 | 
						|
 | 
						|
        %aimport foo
 | 
						|
        Import module 'foo' and mark it to be autoreloaded for %autoreload explicit
 | 
						|
 | 
						|
        %aimport foo, bar
 | 
						|
        Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload explicit
 | 
						|
 | 
						|
        %aimport -foo, bar
 | 
						|
        Mark module 'foo' to not be autoreloaded for %autoreload explicit, all, or complete, and 'bar'
 | 
						|
        to be autoreloaded for mode explicit.
 | 
						|
        """
 | 
						|
        modname = parameter_s
 | 
						|
        if not modname:
 | 
						|
            to_reload = sorted(self._reloader.modules.keys())
 | 
						|
            to_skip = sorted(self._reloader.skip_modules.keys())
 | 
						|
            if stream is None:
 | 
						|
                stream = sys.stdout
 | 
						|
            if self._reloader.check_all:
 | 
						|
                stream.write("Modules to reload:\nall-except-skipped\n")
 | 
						|
            else:
 | 
						|
                stream.write("Modules to reload:\n%s\n" % " ".join(to_reload))
 | 
						|
            stream.write("\nModules to skip:\n%s\n" % " ".join(to_skip))
 | 
						|
        else:
 | 
						|
            for _module in [_.strip() for _ in modname.split(",")]:
 | 
						|
                if _module.startswith("-"):
 | 
						|
                    _module = _module[1:].strip()
 | 
						|
                    self._reloader.mark_module_skipped(_module)
 | 
						|
                else:
 | 
						|
                    top_module, top_name = self._reloader.aimport_module(_module)
 | 
						|
 | 
						|
                    # Inject module to user namespace
 | 
						|
                    self.shell.push({top_name: top_module})
 | 
						|
 | 
						|
    def pre_run_cell(self, info):
 | 
						|
        # Store the execution info for later use in post_execute_hook
 | 
						|
        self._last_execution_info = info
 | 
						|
 | 
						|
        if self._reloader.enabled:
 | 
						|
            try:
 | 
						|
                self._reloader.check()
 | 
						|
            except:
 | 
						|
                pass
 | 
						|
 | 
						|
    def post_execute_hook(self):
 | 
						|
        """Cache the modification times of any modules imported in this execution and track imports"""
 | 
						|
 | 
						|
        # Track imports from the recently executed code if autoreload 3 is enabled
 | 
						|
        if self._reloader.enabled and self._reloader.autoload_obj:
 | 
						|
            # Use the stored execution info
 | 
						|
            if (
 | 
						|
                hasattr(self, "_last_execution_info")
 | 
						|
                and self._last_execution_info
 | 
						|
                and self._last_execution_info.transformed_cell
 | 
						|
            ):
 | 
						|
                self._track_imports_from_code(
 | 
						|
                    self._last_execution_info.transformed_cell
 | 
						|
                )
 | 
						|
 | 
						|
        newly_loaded_modules = set(sys.modules) - self.loaded_modules
 | 
						|
        for modname in newly_loaded_modules:
 | 
						|
            _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
 | 
						|
            if pymtime is not None:
 | 
						|
                self._reloader.modules_mtimes[modname] = pymtime
 | 
						|
 | 
						|
        self.loaded_modules.update(newly_loaded_modules)
 | 
						|
 | 
						|
    def _track_imports_from_code(self, code: str) -> None:
 | 
						|
        """Track import statements from executed code"""
 | 
						|
        try:
 | 
						|
            tree = ast.parse(code)
 | 
						|
 | 
						|
            for node in ast.walk(tree):
 | 
						|
                # Handle "from X import Y" style imports
 | 
						|
                if isinstance(node, ast.ImportFrom):
 | 
						|
                    mod = node.module
 | 
						|
 | 
						|
                    # Skip relative imports that don't have a module name
 | 
						|
                    if mod is None:
 | 
						|
                        continue
 | 
						|
 | 
						|
                    for name in node.names:
 | 
						|
                        # name.name is going to be actual name that we want to import from module
 | 
						|
                        # name.asname is Z in the case of from X import Y as Z
 | 
						|
                        # we should update Z in the shell in this situation, so track it too.
 | 
						|
                        original_name = name.name
 | 
						|
                        resolved_name = name.asname if name.asname else name.name
 | 
						|
 | 
						|
                        # Since the code executed successfully, we know this import is valid
 | 
						|
                        self._reloader.import_from_tracker.add_import(
 | 
						|
                            mod, original_name, resolved_name
 | 
						|
                        )
 | 
						|
        except (SyntaxError, ValueError):
 | 
						|
            # If there's a syntax error, skip import tracking
 | 
						|
            # (though this shouldn't happen since the code already executed successfully)
 | 
						|
            pass
 | 
						|
 | 
						|
 | 
						|
def load_ipython_extension(ip):
 | 
						|
    """Load the extension in IPython."""
 | 
						|
    auto_reload = AutoreloadMagics(ip)
 | 
						|
    ip.register_magics(auto_reload)
 | 
						|
    ip.events.register("pre_run_cell", auto_reload.pre_run_cell)
 | 
						|
    ip.events.register("post_execute", auto_reload.post_execute_hook)
 |