You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			243 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			243 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
"""Kernel Provisioner Classes"""
 | 
						|
# Copyright (c) Jupyter Development Team.
 | 
						|
# Distributed under the terms of the Modified BSD License.
 | 
						|
import asyncio
 | 
						|
import os
 | 
						|
import signal
 | 
						|
import sys
 | 
						|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
						|
 | 
						|
from ..connect import KernelConnectionInfo, LocalPortCache
 | 
						|
from ..launcher import launch_kernel
 | 
						|
from ..localinterfaces import is_local_ip, local_ips
 | 
						|
from .provisioner_base import KernelProvisionerBase
 | 
						|
 | 
						|
 | 
						|
class LocalProvisioner(KernelProvisionerBase):  # type:ignore[misc]
    """
    :class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
    and is the out-of-box default implementation used when no kernel provisioner is
    specified in the kernel specification (``kernel.json``).  It provides functional
    parity to existing applications by launching the kernel locally and using
    :class:`subprocess.Popen` to manage its lifecycle.

    This class is intended to be subclassed for customizing local kernel environments
    and serve as a reference implementation for other custom provisioners.
    """

    # Handle returned by ``launch_kernel``; cleared by ``wait()`` once the process exits.
    process = None
    # NOTE(review): not referenced within this class; presumably used by subclasses/base.
    _exit_future = None
    # Kernel process id and process-group id, persisted via get/load_provisioner_info.
    pid = None
    pgid = None
    ip = None
    # True once pre_launch() has reserved ports from LocalPortCache; cleanup()
    # hands them back and resets this flag.
    ports_cached = False

    @property
    def has_process(self) -> bool:
        """Return True while a kernel process handle is held."""
        return self.process is not None

    async def poll(self) -> Optional[int]:
        """Poll the provisioner.

        Returns ``None`` while the kernel process is still running, the process's
        return code once it has exited, or ``0`` when no process is held.
        """
        ret = 0
        if self.process:
            ret = self.process.poll()  # type:ignore[unreachable]
        return ret

    async def wait(self) -> Optional[int]:
        """Wait for the provisioner process.

        Returns the process's return code (``0`` when no process is held).  After
        the process exits, its stdin/stdout/stderr pipes are closed and
        ``self.process`` is cleared so that :attr:`has_process` becomes False.
        """
        ret = 0
        if self.process:
            # Use busy loop at 100ms intervals, polling until the process is
            # not alive.  If we find the process is no longer alive, complete
            # its cleanup via the blocking wait().  Callers are responsible for
            # issuing calls to wait() using a timeout (see kill()).
            while await self.poll() is None:  # type:ignore[unreachable]
                await asyncio.sleep(0.1)

            # Process is no longer alive, so this wait() does not block; it reaps it.
            ret = self.process.wait()
            # Make sure all the fds get closed.
            for attr in ["stdout", "stderr", "stdin"]:
                fid = getattr(self.process, attr)
                if fid:
                    fid.close()
            self.process = None  # allow has_process to now return False
        return ret

    async def send_signal(self, signum: int) -> None:
        """Sends a signal to the process group of the kernel (this
        usually includes the kernel and any subprocesses spawned by
        the kernel).

        Note that since only SIGTERM is supported on Windows, we will
        check if the desired signal is for interrupt and apply the
        applicable code on Windows in that case.

        Falls back to signalling only the kernel process when the process group
        is unknown or ``os.killpg`` fails; errors from that fallback propagate
        to the caller.
        """
        if self.process:
            if signum == signal.SIGINT and sys.platform == "win32":  # type:ignore[unreachable]
                from ..win_interrupt import send_interrupt

                send_interrupt(self.process.win32_interrupt_event)
                return

            # Prefer process-group over process
            if self.pgid and hasattr(os, "killpg"):
                try:
                    os.killpg(self.pgid, signum)
                    return
                except OSError:
                    pass  # We'll retry sending the signal to only the process below

            # If we're here, send the signal to the process and let caller handle exceptions
            self.process.send_signal(signum)
            return

    async def kill(self, restart: bool = False) -> None:
        """Kill the provisioner and optionally restart.

        Sends SIGKILL (to the process group when possible) where the platform
        defines it; otherwise, or if signalling fails, falls back to
        ``Popen.kill()``.  Errors meaning "process already gone" are ignored.
        """
        if self.process:
            if hasattr(signal, "SIGKILL"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process-group over `kill()`.
                try:
                    await self.send_signal(signal.SIGKILL)
                    return
                except OSError:
                    pass
            try:
                self.process.kill()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    async def terminate(self, restart: bool = False) -> None:
        """Terminate the provisioner and optionally restart.

        Sends SIGTERM (to the process group when possible) where the platform
        defines it; otherwise, or if signalling fails, falls back to
        ``Popen.terminate()``.  Errors meaning "process already gone" are ignored.
        """
        if self.process:
            if hasattr(signal, "SIGTERM"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process group over `terminate()`.
                try:
                    await self.send_signal(signal.SIGTERM)
                    return
                except OSError:
                    pass
            try:
                self.process.terminate()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    @staticmethod
    def _tolerate_no_process(os_error: OSError) -> None:
        """Ignore ``os_error`` if it only means the process has already exited.

        Any other error is re-raised.  The error is raised explicitly (rather than
        with a bare ``raise``) so this helper does not depend on being called from
        inside an active ``except`` block.

        :raises OSError: ``os_error`` itself, when it is not a "no such process" error.
        """
        if sys.platform == "win32":
            # In Windows, we will get an Access Denied error (winerror 5) if the
            # process has already terminated. Ignore it.
            if os_error.winerror != 5:
                raise os_error
        else:
            # On Unix, we may get an ESRCH error (surfaced as a ProcessLookupError
            # instance) if the process has already terminated. Ignore it.
            from errno import ESRCH

            if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
                raise os_error

    async def cleanup(self, restart: bool = False) -> None:
        """Clean up the resources used by the provisioner and optionally restart.

        When ports were reserved from :class:`LocalPortCache` and this is not a
        restart, the five kernel ports are handed back to the cache exactly once.
        """
        if self.ports_cached and not restart:
            # provisioner is about to be destroyed, return cached ports
            lpc = LocalPortCache.instance()
            ports = (
                self.connection_info["shell_port"],
                self.connection_info["iopub_port"],
                self.connection_info["stdin_port"],
                self.connection_info["hb_port"],
                self.connection_info["control_port"],
            )
            for port in ports:
                if TYPE_CHECKING:
                    assert isinstance(port, int)
                lpc.return_port(port)
            # Ports no longer belong to this provisioner; a repeated cleanup must not
            # release ports that may since have been handed to another kernel.
            self.ports_cached = False

    async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
        """Perform any steps in preparation for kernel process launch.

        This includes applying additional substitutions to the kernel launch command and env.
        It also includes preparation of launch parameters.

        Returns the updated kwargs.

        :raises RuntimeError: if the parent manager uses TCP on a non-local IP.
        """
        # build the Popen cmd (removed from kwargs in both branches below)
        extra_arguments = kwargs.pop("extra_arguments", [])

        # This should be considered temporary until a better division of labor can be defined.
        km = self.parent
        if km:
            if km.transport == "tcp" and not is_local_ip(km.ip):
                msg = (
                    "Can only launch a kernel on a local interface. "
                    f"This one is not: {km.ip}. "
                    "Make sure that the '*_address' attributes are "
                    "configured properly. "
                    f"Currently valid addresses are: {local_ips()}"
                )
                raise RuntimeError(msg)

            # write connection file / get default ports
            # TODO - change when handshake pattern is adopted
            if km.cache_ports and not self.ports_cached:
                lpc = LocalPortCache.instance()
                km.shell_port = lpc.find_available_port(km.ip)
                km.iopub_port = lpc.find_available_port(km.ip)
                km.stdin_port = lpc.find_available_port(km.ip)
                km.hb_port = lpc.find_available_port(km.ip)
                km.control_port = lpc.find_available_port(km.ip)
                self.ports_cached = True
            if "env" in kwargs:
                # An explicit ``env=None`` is tolerated (treated as an empty env).
                env = kwargs["env"] or {}
                jupyter_session = env.get("JPY_SESSION_NAME", "")
                km.write_connection_file(jupyter_session=jupyter_session)
            else:
                km.write_connection_file()
            self.connection_info = km.get_connection_info()

            kernel_cmd = km.format_kernel_cmd(
                extra_arguments=extra_arguments
            )  # This needs to remain here for b/c
        else:
            kernel_cmd = self.kernel_spec.argv + extra_arguments

        return await super().pre_launch(cmd=kernel_cmd, **kwargs)

    async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
        """Launch a kernel with a command.

        Records the process handle, pid and (where ``os.getpgid`` exists) the
        process-group id, then returns the current connection info.
        """
        scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
        self.process = launch_kernel(cmd, **scrubbed_kwargs)
        pgid = None
        if hasattr(os, "getpgid"):
            try:
                pgid = os.getpgid(self.process.pid)
            except OSError:
                pass  # pgid stays None; send_signal then signals the process directly

        self.pid = self.process.pid
        self.pgid = pgid
        return self.connection_info

    @staticmethod
    def _scrub_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Remove any keyword arguments that Popen does not tolerate.

        Returns a shallow copy of ``kwargs``; the input dict is not modified.
        """
        keywords_to_scrub: List[str] = ["extra_arguments", "kernel_id"]
        scrubbed_kwargs = kwargs.copy()
        for kw in keywords_to_scrub:
            scrubbed_kwargs.pop(kw, None)
        return scrubbed_kwargs

    async def get_provisioner_info(self) -> Dict:
        """Captures the base information necessary for persistence relative to this instance.

        Extends the base class's info with ``pid``, ``pgid`` and ``ip``.
        """
        provisioner_info = await super().get_provisioner_info()
        provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
        return provisioner_info

    async def load_provisioner_info(self, provisioner_info: Dict) -> None:
        """Loads the base information necessary for persistence relative to this instance.

        :raises KeyError: if ``pid``, ``pgid`` or ``ip`` is missing from ``provisioner_info``.
        """
        await super().load_provisioner_info(provisioner_info)
        self.pid = provisioner_info["pid"]
        self.pgid = provisioner_info["pgid"]
        self.ip = provisioner_info["ip"]
 |