You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			352 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			352 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
"""Defines a KernelClient that provides thread-safe sockets with async callbacks on message
replies.
"""
 | 
						|
import asyncio
 | 
						|
import atexit
 | 
						|
import time
 | 
						|
from concurrent.futures import Future
 | 
						|
from functools import partial
 | 
						|
from threading import Thread
 | 
						|
from typing import Any, Dict, List, Optional
 | 
						|
 | 
						|
import zmq
 | 
						|
from tornado.ioloop import IOLoop
 | 
						|
from traitlets import Instance, Type
 | 
						|
from traitlets.log import get_logger
 | 
						|
from zmq.eventloop import zmqstream
 | 
						|
 | 
						|
from .channels import HBChannel
 | 
						|
from .client import KernelClient
 | 
						|
from .session import Session
 | 
						|
 | 
						|
# Local imports
 | 
						|
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
 | 
						|
# during garbage collection of threads at exit
 | 
						|
 | 
						|
 | 
						|
class ThreadedZMQSocketChannel:
    """A ZMQ socket invoking a callback in the ioloop"""

    session = None
    socket = None
    ioloop = None
    stream = None
    # Optional callable; when set, _handle_recv passes it each deserialized
    # message before call_handlers runs.
    _inspect = None
    # Toggled by start()/stop() and reported by is_alive().
    _is_alive = False

    def __init__(
        self,
        socket: Optional[zmq.Socket],
        session: Optional[Session],
        loop: Optional[IOLoop],
    ) -> None:
        """Create a channel.

        The ZMQStream is created from a callback scheduled on ``loop``; this
        constructor waits up to 10 seconds for that callback to finish and
        re-raises any error it produced.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            A tornado ioloop to connect the socket to using a ZMQStream
        """
        super().__init__()

        self.socket = socket
        self.session = session
        self.ioloop = loop
        f: Future = Future()

        def setup_stream() -> None:
            # Scheduled on the ioloop; the outcome is handed back through ``f``.
            try:
                assert self.socket is not None
                self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
                self.stream.on_recv(self._handle_recv)
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(None)

        assert self.ioloop is not None
        self.ioloop.add_callback(setup_stream)
        # don't wait forever, raise any errors
        f.result(timeout=10)

    def is_alive(self) -> bool:
        """Whether the channel is alive."""
        return self._is_alive

    def start(self) -> None:
        """Start the channel."""
        self._is_alive = True

    def stop(self) -> None:
        """Stop the channel."""
        self._is_alive = False

    def close(self) -> None:
        """Close the channel.

        The stream is closed from a callback scheduled on the ioloop, waiting
        up to 5 seconds for it; a failure or timeout there is logged as a
        warning rather than raised. The socket is then closed best-effort and
        ``self.socket`` is reset to None.
        """
        if self.stream is not None and self.ioloop is not None:
            # c.f.Future for threadsafe results
            f: Future = Future()

            def close_stream() -> None:
                try:
                    if self.stream is not None:
                        self.stream.close(linger=0)
                        self.stream = None
                except Exception as e:
                    f.set_exception(e)
                else:
                    f.set_result(None)

            self.ioloop.add_callback(close_stream)
            # wait for result
            try:
                f.result(timeout=5)
            except Exception as e:
                # Logger methods treat extra positional arguments as %-format
                # arguments, so only pass the values the template consumes.
                get_logger().warning(
                    "Error closing stream %s: %s", self.stream, e, stacklevel=2
                )

        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # Deliberate best-effort: the socket reference is dropped below
                # regardless of whether closing succeeded.
                pass
            self.socket = None

    def send(self, msg: Dict[str, Any]) -> None:
        """Queue a message to be sent from the IOLoop's thread.

        Parameters
        ----------
        msg : message to send

        This is threadsafe, as it uses IOLoop.add_callback to give the loop's
        thread control of the action.
        """

        def thread_send() -> None:
            assert self.session is not None
            self.session.send(self.stream, msg)

        assert self.ioloop is not None
        self.ioloop.add_callback(thread_send)

    def _handle_recv(self, msg_list: List) -> None:
        """Callback for stream.on_recv.

        Unpacks message, and calls handlers with it.
        """
        assert self.ioloop is not None
        assert self.session is not None
        ident, smsg = self.session.feed_identities(msg_list)
        msg = self.session.deserialize(smsg)
        # let client inspect messages
        if self._inspect:
            self._inspect(msg)  # type:ignore[unreachable]
        self.call_handlers(msg)

    def call_handlers(self, msg: Dict[str, Any]) -> None:
        """This method is called in the ioloop thread when a message arrives.

        Subclasses should override this method to handle incoming messages.
        It is important to remember that this method is called in the ioloop
        thread, so some logic must be done to ensure that the application
        level handlers are called in the application thread.
        """
        pass

    def process_events(self) -> None:
        """Subclasses should override this with a method
        processing any pending GUI events.
        """
        pass

    def flush(self, timeout: float = 1.0) -> None:
        """Immediately processes all pending messages on this channel.

        This is only used for the IOPub channel.

        Callers should use this method to ensure that :meth:`call_handlers`
        has been called for all messages that have been received on the
        0MQ SUB socket of this channel.

        This method is thread safe.

        Parameters
        ----------
        timeout : float, optional
            The maximum amount of time to spend flushing, in seconds. The
            default is one second.

        Raises
        ------
        OSError
            If the stream is missing or already closed.
        """
        # Before Python 3.11 ``Future.result`` raises a TimeoutError class
        # that is distinct from the builtin; from 3.11 on it is an alias.
        # Catching both keeps "timeout means stop waiting" true everywhere.
        from concurrent.futures import TimeoutError as FutureTimeoutError

        # We do the IOLoop callback process twice to ensure that the IOLoop
        # gets to perform at least one full poll.
        stop_time = time.monotonic() + timeout
        assert self.ioloop is not None
        if self.stream is None or self.stream.closed():
            # don't bother scheduling flush on a thread if we're closed
            _msg = "Attempt to flush closed stream"
            raise OSError(_msg)

        def run_flush(f: Any) -> None:
            try:
                self._flush()
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(None)

        for _ in range(2):
            f: Future = Future()
            self.ioloop.add_callback(partial(run_flush, f))
            # wait for async flush, re-raise any errors
            remaining = max(stop_time - time.monotonic(), 0)
            try:
                f.result(remaining)
            except (TimeoutError, FutureTimeoutError):
                # flush with a timeout means stop waiting, not raise
                return

    def _flush(self) -> None:
        """Callback for :method:`self.flush`."""
        assert self.stream is not None
        self.stream.flush()
        self._flushed = True
 | 
						|
 | 
						|
 | 
						|
class IOLoopThread(Thread):
    """Run a pyzmq ioloop in a thread to send and receive messages"""

    # Checked once per second by _async_run; set by stop() on the instance
    # and by _notice_exit() on the class at interpreter exit.
    _exiting = False
    # Assigned inside the thread by run(); None until then and after stop().
    ioloop = None

    def __init__(self) -> None:
        """Initialize an io loop thread."""
        super().__init__()
        self.daemon = True

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if IOLoopThread is not None:
            IOLoopThread._exiting = True

    def start(self) -> None:
        """Start the IOLoop thread

        Don't return until self.ioloop is defined,
        which is created in the thread.

        Waits up to 10 seconds for the thread to report startup and re-raises
        any error that :meth:`run` recorded while setting up the loop.
        """
        self._start_future: Future = Future()
        Thread.start(self)
        # wait for start, re-raise any errors
        self._start_future.result(timeout=10)

    def run(self) -> None:
        """Run my loop, ignoring EINTR events in the poller

        Creates a fresh asyncio event loop for this thread, records the
        matching tornado IOLoop on ``self.ioloop``, reports the outcome through
        ``self._start_future`` and then runs until ``_exiting`` is set. If
        setup fails the error is reported and the thread returns immediately.
        """
        loop = None
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            async def assign_ioloop() -> None:
                self.ioloop = IOLoop.current()

            loop.run_until_complete(assign_ioloop())
        except Exception as e:
            self._start_future.set_exception(e)
            # Nothing usable was set up: release the loop (if it was created)
            # and end the thread instead of continuing with a broken or
            # unbound loop.
            if loop is not None:
                loop.close()
            return

        self._start_future.set_result(None)
        loop.run_until_complete(self._async_run())

    async def _async_run(self) -> None:
        """Run forever (until self._exiting is set)"""
        while not self._exiting:
            await asyncio.sleep(1)

    def stop(self) -> None:
        """Stop the channel's event loop and join its thread.

        This calls :meth:`~threading.Thread.join` and returns when the thread
        terminates. :class:`RuntimeError` will be raised if
        :meth:`~threading.Thread.start` is called again.
        """
        self._exiting = True
        self.join()
        self.close()
        self.ioloop = None

    def __del__(self) -> None:
        self.close()

    def close(self) -> None:
        """Close the io loop thread."""
        if self.ioloop is not None:
            try:
                self.ioloop.close(all_fds=True)
            except Exception:
                # Best-effort cleanup; this is also reached from __del__.
                pass
 | 
						|
 | 
						|
 | 
						|
class ThreadedKernelClient(KernelClient):
    """A KernelClient that provides thread-safe sockets with async callbacks on message replies."""

    # Thread created by start_channels(); its ioloop backs the `ioloop` property.
    ioloop_thread = Instance(IOLoopThread, allow_none=True)

    # Channel implementations used for each channel.
    iopub_channel_class = Type(ThreadedZMQSocketChannel)  # type:ignore[arg-type]
    shell_channel_class = Type(ThreadedZMQSocketChannel)  # type:ignore[arg-type]
    stdin_channel_class = Type(ThreadedZMQSocketChannel)  # type:ignore[arg-type]
    hb_channel_class = Type(HBChannel)  # type:ignore[arg-type]
    control_channel_class = Type(ThreadedZMQSocketChannel)  # type:ignore[arg-type]

    @property
    def ioloop(self) -> Optional[IOLoop]:  # type:ignore[override]
        """The ioloop of the current ioloop thread, or None when there is no thread."""
        thread = self.ioloop_thread
        return thread.ioloop if thread else None

    def start_channels(
        self,
        shell: bool = True,
        iopub: bool = True,
        stdin: bool = True,
        hb: bool = True,
        control: bool = True,
    ) -> None:
        """Start the channels on the client.

        A new IOLoopThread is created and started first. When ``shell`` is
        true, the shell channel's inspect hook is pointed at
        :meth:`_check_kernel_info_reply` before the base implementation runs.
        """
        thread = IOLoopThread()
        self.ioloop_thread = thread
        thread.start()

        if shell:
            self.shell_channel._inspect = self._check_kernel_info_reply

        super().start_channels(shell, iopub, stdin, hb, control)

    def _check_kernel_info_reply(self, msg: Dict[str, Any]) -> None:
        """This is run in the ioloop thread when the kernel info reply is received"""
        if msg["msg_type"] != "kernel_info_reply":
            return
        self._handle_kernel_info_reply(msg)
        # The hook removes itself once the reply has been handled.
        self.shell_channel._inspect = None

    def stop_channels(self) -> None:
        """Stop the channels on the client, then stop the ioloop thread if it is alive."""
        super().stop_channels()
        thread = self.ioloop_thread
        if thread and thread.is_alive():
            thread.stop()

    def is_alive(self) -> bool:
        """Is the kernel process still running?"""
        heartbeat = self._hb_channel
        if heartbeat is None:
            # no heartbeat and not local, we can't tell if it's running,
            # so naively return True
            return True
        # We don't have access to the KernelManager,
        # so we use the heartbeat.
        return heartbeat.is_beating()
 |