You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			469 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			469 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
"""The extension manager."""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import importlib
 | 
						|
import logging
 | 
						|
from itertools import starmap
 | 
						|
 | 
						|
from tornado.gen import multi
 | 
						|
from traitlets import Any, Bool, Dict, HasTraits, Instance, List, Unicode, default, observe
 | 
						|
from traitlets import validate as validate_trait
 | 
						|
from traitlets.config import LoggingConfigurable
 | 
						|
 | 
						|
from .config import ExtensionConfigManager
 | 
						|
from .utils import ExtensionMetadataError, ExtensionModuleNotFound, get_loader, get_metadata
 | 
						|
 | 
						|
 | 
						|
class ExtensionPoint(HasTraits):
 | 
						|
    """A simple API for connecting to a Jupyter Server extension
 | 
						|
    point defined by metadata and importable from a Python package.
 | 
						|
    """
 | 
						|
 | 
						|
    _linked = Bool(False)
 | 
						|
    _app = Any(None, allow_none=True)
 | 
						|
 | 
						|
    metadata = Dict()
 | 
						|
 | 
						|
    log = Instance(logging.Logger)
 | 
						|
 | 
						|
    @default("log")
 | 
						|
    def _default_log(self):
 | 
						|
        return logging.getLogger("ExtensionPoint")
 | 
						|
 | 
						|
    @validate_trait("metadata")
 | 
						|
    def _valid_metadata(self, proposed):
 | 
						|
        """Validate metadata."""
 | 
						|
        metadata = proposed["value"]
 | 
						|
        # Verify that the metadata has a "name" key.
 | 
						|
        try:
 | 
						|
            self._module_name = metadata["module"]
 | 
						|
        except KeyError:
 | 
						|
            msg = "There is no 'module' key in the extension's metadata packet."
 | 
						|
            raise ExtensionMetadataError(msg) from None
 | 
						|
 | 
						|
        try:
 | 
						|
            self._module = importlib.import_module(self._module_name)
 | 
						|
        except ImportError:
 | 
						|
            msg = (
 | 
						|
                f"The submodule '{self._module_name}' could not be found. Are you "
 | 
						|
                "sure the extension is installed?"
 | 
						|
            )
 | 
						|
            raise ExtensionModuleNotFound(msg) from None
 | 
						|
        # If the metadata includes an ExtensionApp, create an instance.
 | 
						|
        if "app" in metadata:
 | 
						|
            self._app = metadata["app"]()
 | 
						|
        return metadata
 | 
						|
 | 
						|
    @property
 | 
						|
    def linked(self):
 | 
						|
        """Has this extension point been linked to the server.
 | 
						|
 | 
						|
        Will pull from ExtensionApp's trait, if this point
 | 
						|
        is an instance of ExtensionApp.
 | 
						|
        """
 | 
						|
        if self.app:
 | 
						|
            return self.app._linked
 | 
						|
        return self._linked
 | 
						|
 | 
						|
    @property
 | 
						|
    def app(self):
 | 
						|
        """If the metadata includes an `app` field"""
 | 
						|
        return self._app
 | 
						|
 | 
						|
    @property
 | 
						|
    def config(self):
 | 
						|
        """Return any configuration provided by this extension point."""
 | 
						|
        if self.app:
 | 
						|
            return self.app._jupyter_server_config()
 | 
						|
        # At some point, we might want to add logic to load config from
 | 
						|
        # disk when extensions don't use ExtensionApp.
 | 
						|
        else:
 | 
						|
            return {}
 | 
						|
 | 
						|
    @property
 | 
						|
    def module_name(self):
 | 
						|
        """Name of the Python package module where the extension's
 | 
						|
        _load_jupyter_server_extension can be found.
 | 
						|
        """
 | 
						|
        return self._module_name
 | 
						|
 | 
						|
    @property
 | 
						|
    def name(self):
 | 
						|
        """Name of the extension.
 | 
						|
 | 
						|
        If it's not provided in the metadata, `name` is set
 | 
						|
        to the extensions' module name.
 | 
						|
        """
 | 
						|
        if self.app:
 | 
						|
            return self.app.name
 | 
						|
        return self.metadata.get("name", self.module_name)
 | 
						|
 | 
						|
    @property
 | 
						|
    def module(self):
 | 
						|
        """The imported module (using importlib.import_module)"""
 | 
						|
        return self._module
 | 
						|
 | 
						|
    def _get_linker(self):
 | 
						|
        """Get a linker."""
 | 
						|
        if self.app:
 | 
						|
            linker = self.app._link_jupyter_server_extension
 | 
						|
        else:
 | 
						|
            linker = getattr(
 | 
						|
                self.module,
 | 
						|
                # Search for a _link_jupyter_extension
 | 
						|
                "_link_jupyter_server_extension",
 | 
						|
                # Otherwise return a dummy function.
 | 
						|
                lambda serverapp: None,
 | 
						|
            )
 | 
						|
        return linker
 | 
						|
 | 
						|
    def _get_loader(self):
 | 
						|
        """Get a loader."""
 | 
						|
        loc = self.app
 | 
						|
        if not loc:
 | 
						|
            loc = self.module
 | 
						|
        loader = get_loader(loc)
 | 
						|
        return loader
 | 
						|
 | 
						|
    def _get_starter(self):
 | 
						|
        """Get a starter function."""
 | 
						|
        if self.app:
 | 
						|
            linker = self.app._start_jupyter_server_extension
 | 
						|
        else:
 | 
						|
 | 
						|
            async def _noop_start(serverapp):
 | 
						|
                return
 | 
						|
 | 
						|
            linker = getattr(
 | 
						|
                self.module,
 | 
						|
                # Search for a _start_jupyter_extension
 | 
						|
                "_start_jupyter_server_extension",
 | 
						|
                # Otherwise return a no-op function.
 | 
						|
                _noop_start,
 | 
						|
            )
 | 
						|
        return linker
 | 
						|
 | 
						|
    def validate(self):
 | 
						|
        """Check that both a linker and loader exists."""
 | 
						|
        try:
 | 
						|
            self._get_linker()
 | 
						|
            self._get_loader()
 | 
						|
        except Exception:
 | 
						|
            return False
 | 
						|
        else:
 | 
						|
            return True
 | 
						|
 | 
						|
    def link(self, serverapp):
 | 
						|
        """Link the extension to a Jupyter ServerApp object.
 | 
						|
 | 
						|
        This looks for a `_link_jupyter_server_extension` function
 | 
						|
        in the extension's module or ExtensionApp class.
 | 
						|
        """
 | 
						|
        if not self.linked:
 | 
						|
            linker = self._get_linker()
 | 
						|
            linker(serverapp)
 | 
						|
            # Store this extension as already linked.
 | 
						|
            self._linked = True
 | 
						|
 | 
						|
    def load(self, serverapp):
 | 
						|
        """Load the extension in a Jupyter ServerApp object.
 | 
						|
 | 
						|
        This looks for a `_load_jupyter_server_extension` function
 | 
						|
        in the extension's module or ExtensionApp class.
 | 
						|
        """
 | 
						|
        loader = self._get_loader()
 | 
						|
        return loader(serverapp)
 | 
						|
 | 
						|
    async def start(self, serverapp):
 | 
						|
        """Call's the extensions 'start' hook where it can
 | 
						|
        start (possibly async) tasks _after_ the event loop is running.
 | 
						|
        """
 | 
						|
        starter = self._get_starter()
 | 
						|
        return await starter(serverapp)
 | 
						|
 | 
						|
 | 
						|
class ExtensionPackage(LoggingConfigurable):
 | 
						|
    """An API for interfacing with a Jupyter Server extension package.
 | 
						|
 | 
						|
    Usage:
 | 
						|
 | 
						|
    ext_name = "my_extensions"
 | 
						|
    extpkg = ExtensionPackage(name=ext_name)
 | 
						|
    """
 | 
						|
 | 
						|
    name = Unicode(help="Name of the an importable Python package.")
 | 
						|
    enabled = Bool(False, help="Whether the extension package is enabled.")
 | 
						|
 | 
						|
    _linked_points = Dict()
 | 
						|
    extension_points = Dict()
 | 
						|
    module = Any(allow_none=True, help="The module for this extension package. None if not enabled")
 | 
						|
    metadata = List(Dict(), help="Extension metadata loaded from the extension package.")
 | 
						|
    version = Unicode(
 | 
						|
        help="""
 | 
						|
            The version of this extension package, if it can be found.
 | 
						|
            Otherwise, an empty string.
 | 
						|
            """,
 | 
						|
    )
 | 
						|
 | 
						|
    @default("version")
 | 
						|
    def _load_version(self):
 | 
						|
        if not self.enabled:
 | 
						|
            return ""
 | 
						|
        return getattr(self.module, "__version__", "")
 | 
						|
 | 
						|
    def __init__(self, **kwargs):
 | 
						|
        """Initialize an extension package."""
 | 
						|
        super().__init__(**kwargs)
 | 
						|
        if self.enabled:
 | 
						|
            self._load_metadata()
 | 
						|
 | 
						|
    def _load_metadata(self):
 | 
						|
        """Import package and load metadata
 | 
						|
 | 
						|
        Only used if extension package is enabled
 | 
						|
        """
 | 
						|
        name = self.name
 | 
						|
        try:
 | 
						|
            self.module, self.metadata = get_metadata(name, logger=self.log)
 | 
						|
        except ImportError as e:
 | 
						|
            msg = (
 | 
						|
                f"The module '{name}' could not be found ({e}). Are you "
 | 
						|
                "sure the extension is installed?"
 | 
						|
            )
 | 
						|
            raise ExtensionModuleNotFound(msg) from None
 | 
						|
        # Create extension point interfaces for each extension path.
 | 
						|
        for m in self.metadata:
 | 
						|
            point = ExtensionPoint(metadata=m, log=self.log)
 | 
						|
            self.extension_points[point.name] = point
 | 
						|
        return name
 | 
						|
 | 
						|
    def validate(self):
 | 
						|
        """Validate all extension points in this package."""
 | 
						|
        return all(extension.validate() for extension in self.extension_points.values())
 | 
						|
 | 
						|
    def link_point(self, point_name, serverapp):
 | 
						|
        """Link an extension point."""
 | 
						|
        linked = self._linked_points.get(point_name, False)
 | 
						|
        if not linked:
 | 
						|
            point = self.extension_points[point_name]
 | 
						|
            point.link(serverapp)
 | 
						|
 | 
						|
    def load_point(self, point_name, serverapp):
 | 
						|
        """Load an extension point."""
 | 
						|
        point = self.extension_points[point_name]
 | 
						|
        return point.load(serverapp)
 | 
						|
 | 
						|
    async def start_point(self, point_name, serverapp):
 | 
						|
        """Load an extension point."""
 | 
						|
        point = self.extension_points[point_name]
 | 
						|
        return await point.start(serverapp)
 | 
						|
 | 
						|
    def link_all_points(self, serverapp):
 | 
						|
        """Link all extension points."""
 | 
						|
        for point_name in self.extension_points:
 | 
						|
            self.link_point(point_name, serverapp)
 | 
						|
 | 
						|
    def load_all_points(self, serverapp):
 | 
						|
        """Load all extension points."""
 | 
						|
        return [self.load_point(point_name, serverapp) for point_name in self.extension_points]
 | 
						|
 | 
						|
    async def start_all_points(self, serverapp):
 | 
						|
        """Load all extension points."""
 | 
						|
        for point_name in self.extension_points:
 | 
						|
            await self.start_point(point_name, serverapp)
 | 
						|
 | 
						|
 | 
						|
class ExtensionManager(LoggingConfigurable):
 | 
						|
    """High level interface for finding, validating,
 | 
						|
    linking, loading, and managing Jupyter Server extensions.
 | 
						|
 | 
						|
    Usage:
 | 
						|
    m = ExtensionManager(config_manager=...)
 | 
						|
    """
 | 
						|
 | 
						|
    config_manager = Instance(ExtensionConfigManager, allow_none=True)
 | 
						|
 | 
						|
    serverapp = Any()  # Use Any to avoid circular import of Instance(ServerApp)
 | 
						|
 | 
						|
    @default("config_manager")
 | 
						|
    def _load_default_config_manager(self):
 | 
						|
        config_manager = ExtensionConfigManager()
 | 
						|
        self._load_config_manager(config_manager)
 | 
						|
        return config_manager
 | 
						|
 | 
						|
    @observe("config_manager")
 | 
						|
    def _config_manager_changed(self, change):
 | 
						|
        if change.new:
 | 
						|
            self._load_config_manager(change.new)
 | 
						|
 | 
						|
    # The `extensions` attribute provides a dictionary
 | 
						|
    # with extension (package) names mapped to their ExtensionPackage interface
 | 
						|
    # (see above). This manager simplifies the interaction between the
 | 
						|
    # ServerApp and the extensions being appended.
 | 
						|
    extensions = Dict(
 | 
						|
        help="""
 | 
						|
        Dictionary with extension package names as keys
 | 
						|
        and ExtensionPackage objects as values.
 | 
						|
        """
 | 
						|
    )
 | 
						|
 | 
						|
    @property
 | 
						|
    def sorted_extensions(self):
 | 
						|
        """Returns an extensions dictionary, sorted alphabetically."""
 | 
						|
        return dict(sorted(self.extensions.items()))
 | 
						|
 | 
						|
    # The `_linked_extensions` attribute tracks when each extension
 | 
						|
    # has been successfully linked to a ServerApp. This helps prevent
 | 
						|
    # extensions from being re-linked recursively unintentionally if another
 | 
						|
    # extension attempts to link extensions again.
 | 
						|
    linked_extensions = Dict(
 | 
						|
        help="""
 | 
						|
        Dictionary with extension names as keys
 | 
						|
 | 
						|
        values are True if the extension is linked, False if not.
 | 
						|
        """
 | 
						|
    )
 | 
						|
 | 
						|
    @property
 | 
						|
    def extension_apps(self):
 | 
						|
        """Return mapping of extension names and sets of ExtensionApp objects."""
 | 
						|
        return {
 | 
						|
            name: {point.app for point in extension.extension_points.values() if point.app}
 | 
						|
            for name, extension in self.extensions.items()
 | 
						|
        }
 | 
						|
 | 
						|
    @property
 | 
						|
    def extension_points(self):
 | 
						|
        """Return mapping of extension point names and ExtensionPoint objects."""
 | 
						|
        return {
 | 
						|
            name: point
 | 
						|
            for value in self.extensions.values()
 | 
						|
            for name, point in value.extension_points.items()
 | 
						|
        }
 | 
						|
 | 
						|
    def from_config_manager(self, config_manager):
 | 
						|
        """Add extensions found by an ExtensionConfigManager"""
 | 
						|
        # load triggered via config_manager trait observer
 | 
						|
        self.config_manager = config_manager
 | 
						|
 | 
						|
    def _load_config_manager(self, config_manager):
 | 
						|
        """Actually load our config manager"""
 | 
						|
        jpserver_extensions = config_manager.get_jpserver_extensions()
 | 
						|
        self.from_jpserver_extensions(jpserver_extensions)
 | 
						|
 | 
						|
    def from_jpserver_extensions(self, jpserver_extensions):
 | 
						|
        """Add extensions from 'jpserver_extensions'-like dictionary."""
 | 
						|
        for name, enabled in jpserver_extensions.items():
 | 
						|
            self.add_extension(name, enabled=enabled)
 | 
						|
 | 
						|
    def add_extension(self, extension_name, enabled=False):
 | 
						|
        """Try to add extension to manager, return True if successful.
 | 
						|
        Otherwise, return False.
 | 
						|
        """
 | 
						|
        try:
 | 
						|
            extpkg = ExtensionPackage(name=extension_name, enabled=enabled)
 | 
						|
            self.extensions[extension_name] = extpkg
 | 
						|
            return True
 | 
						|
        # Raise a warning if the extension cannot be loaded.
 | 
						|
        except Exception as e:
 | 
						|
            if self.serverapp and self.serverapp.reraise_server_extension_failures:
 | 
						|
                raise
 | 
						|
            self.log.warning(
 | 
						|
                "%s | error adding extension (enabled: %s): %s",
 | 
						|
                extension_name,
 | 
						|
                enabled,
 | 
						|
                e,
 | 
						|
                exc_info=True,
 | 
						|
            )
 | 
						|
        return False
 | 
						|
 | 
						|
    def link_extension(self, name):
 | 
						|
        """Link an extension by name."""
 | 
						|
        linked = self.linked_extensions.get(name, False)
 | 
						|
        extension = self.extensions[name]
 | 
						|
        if not linked and extension.enabled:
 | 
						|
            try:
 | 
						|
                # Link extension and store links
 | 
						|
                extension.link_all_points(self.serverapp)
 | 
						|
                self.linked_extensions[name] = True
 | 
						|
                self.log.info("%s | extension was successfully linked.", name)
 | 
						|
            except Exception as e:
 | 
						|
                if self.serverapp and self.serverapp.reraise_server_extension_failures:
 | 
						|
                    raise
 | 
						|
                self.log.warning("%s | error linking extension: %s", name, e, exc_info=True)
 | 
						|
 | 
						|
    def load_extension(self, name):
 | 
						|
        """Load an extension by name."""
 | 
						|
        extension = self.extensions.get(name)
 | 
						|
 | 
						|
        if extension and extension.enabled:
 | 
						|
            try:
 | 
						|
                extension.load_all_points(self.serverapp)
 | 
						|
            except Exception as e:
 | 
						|
                if self.serverapp and self.serverapp.reraise_server_extension_failures:
 | 
						|
                    raise
 | 
						|
                self.log.warning(
 | 
						|
                    "%s | extension failed loading with message: %r", name, e, exc_info=True
 | 
						|
                )
 | 
						|
            else:
 | 
						|
                self.log.info("%s | extension was successfully loaded.", name)
 | 
						|
 | 
						|
    async def start_extension(self, name):
 | 
						|
        """Start an extension by name."""
 | 
						|
        extension = self.extensions.get(name)
 | 
						|
 | 
						|
        if extension and extension.enabled:
 | 
						|
            try:
 | 
						|
                await extension.start_all_points(self.serverapp)
 | 
						|
            except Exception as e:
 | 
						|
                if self.serverapp and self.serverapp.reraise_server_extension_failures:
 | 
						|
                    raise
 | 
						|
                self.log.warning(
 | 
						|
                    "%s | extension failed starting with message: %r", name, e, exc_info=True
 | 
						|
                )
 | 
						|
            else:
 | 
						|
                self.log.debug("%s | extension was successfully started.", name)
 | 
						|
 | 
						|
    async def stop_extension(self, name, apps):
 | 
						|
        """Call the shutdown hooks in the specified apps."""
 | 
						|
        for app in apps:
 | 
						|
            self.log.debug("%s | extension app %r stopping", name, app.name)
 | 
						|
            await app.stop_extension()
 | 
						|
            self.log.debug("%s | extension app %r stopped", name, app.name)
 | 
						|
 | 
						|
    def link_all_extensions(self):
 | 
						|
        """Link all enabled extensions
 | 
						|
        to an instance of ServerApp
 | 
						|
        """
 | 
						|
        # Sort the extension names to enforce deterministic linking
 | 
						|
        # order.
 | 
						|
        for name in self.sorted_extensions:
 | 
						|
            self.link_extension(name)
 | 
						|
 | 
						|
    def load_all_extensions(self):
 | 
						|
        """Load all enabled extensions and append them to
 | 
						|
        the parent ServerApp.
 | 
						|
        """
 | 
						|
        # Sort the extension names to enforce deterministic loading
 | 
						|
        # order.
 | 
						|
        for name in self.sorted_extensions:
 | 
						|
            self.load_extension(name)
 | 
						|
 | 
						|
    async def start_all_extensions(self):
 | 
						|
        """Start all enabled extensions."""
 | 
						|
        # Sort the extension names to enforce deterministic loading
 | 
						|
        # order.
 | 
						|
        await multi([self.start_extension(name) for name in self.sorted_extensions])
 | 
						|
 | 
						|
    async def stop_all_extensions(self):
 | 
						|
        """Call the shutdown hooks in all extensions."""
 | 
						|
        await multi(list(starmap(self.stop_extension, sorted(dict(self.extension_apps).items()))))
 | 
						|
 | 
						|
    def any_activity(self):
 | 
						|
        """Check for any activity currently happening across all extension applications."""
 | 
						|
        for _, apps in sorted(dict(self.extension_apps).items()):
 | 
						|
            for app in apps:
 | 
						|
                if app.current_activity():
 | 
						|
                    return True
 |