You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			327 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			327 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
"""Default classes for Comm and CommManager, for usage in IPython."""
 | 
						|
 | 
						|
# Copyright (c) IPython Development Team.
 | 
						|
# Distributed under the terms of the Modified BSD License.
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import contextlib
 | 
						|
import logging
 | 
						|
import typing as t
 | 
						|
import uuid
 | 
						|
 | 
						|
import comm
 | 
						|
 | 
						|
if t.TYPE_CHECKING:
 | 
						|
    from zmq.eventloop.zmqstream import ZMQStream
 | 
						|
 | 
						|
# Module-wide logger; the fixed name "Comm" is what handlers/filters key on.
logger = logging.getLogger("Comm")

# A complete Jupyter message as a plain dictionary.
MessageType = t.Dict[str, t.Any]
# Optional ``data`` / ``metadata`` payload of a comm message.
MaybeDict = t.Optional[t.Dict[str, t.Any]]
# Optional list of binary buffers sent alongside a comm message.
BuffersType = t.Optional[t.List[bytes]]
# Callback registered on a comm through ``on_msg`` / ``on_close``.
CommCallback = t.Callable[[MessageType], None]
# Callback registered for a target name; called with the new comm and the
# ``comm_open`` message.
CommTargetCallback = t.Callable[["BaseComm", MessageType], None]
 | 
						|
 | 
						|
 | 
						|
class BaseComm:
    """Class for communicating between a Frontend and a Kernel

    Must be subclassed with a publish_msg method implementation which
    sends comm messages through the iopub channel.
    """

    def __init__(
        self,
        target_name: str = "comm",
        data: MaybeDict = None,
        metadata: MaybeDict = None,
        buffers: BuffersType = None,
        comm_id: str | None = None,
        primary: bool = True,
        target_module: str | None = None,
        topic: bytes | None = None,
        _open_data: MaybeDict = None,
        _close_data: MaybeDict = None,
        **kwargs: t.Any,
    ) -> None:
        """Create a comm and, if it is the primary side, open its peer.

        Parameters
        ----------
        target_name
            Target name sent with the ``comm_open`` message.
        data, metadata, buffers
            Forwarded to :meth:`open` when ``primary`` is true; unused otherwise.
        comm_id
            Identifier of the comm. A falsy value is replaced by a fresh
            ``uuid4`` hex string.
        primary
            When true the comm opens itself immediately through :meth:`open`.
            When false the comm is only marked as open.
        target_module
            Target module sent with the ``comm_open`` message.
        topic
            Topic stored on the comm. A falsy value is replaced by
            ``b"comm-<comm_id>"``.
        _open_data, _close_data
            Default ``data`` payloads used by :meth:`open` / :meth:`close`
            when they are called without ``data``.
        **kwargs
            Passed on to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)

        self.comm_id = comm_id if comm_id else uuid.uuid4().hex
        self.primary = primary
        self.target_name = target_name
        self.target_module = target_module
        self.topic = topic if topic else (f"comm-{self.comm_id}").encode("ascii")

        self._open_data = _open_data if _open_data else {}
        self._close_data = _close_data if _close_data else {}

        self._msg_callback: CommCallback | None = None
        self._close_callback: CommCallback | None = None

        # Start out closed so that a failing open() leaves the comm in a
        # state where close() (e.g. from __del__) is a no-op.
        self._closed = True

        if self.primary:
            # I am primary, open my peer.
            self.open(data=data, metadata=metadata, buffers=buffers)
        else:
            self._closed = False

    def publish_msg(
        self,
        msg_type: str,
        data: MaybeDict = None,
        metadata: MaybeDict = None,
        buffers: BuffersType = None,
        **keys: t.Any,
    ) -> None:
        """Send a comm message; must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        msg = "publish_msg Comm method is not implemented"
        raise NotImplementedError(msg)

    def __del__(self) -> None:
        """trigger close on gc"""
        with contextlib.suppress(Exception):
            # any number of things can have gone horribly wrong
            # when called during interpreter teardown
            self.close(deleting=True)

    # publishing messages

    def open(
        self, data: MaybeDict = None, metadata: MaybeDict = None, buffers: BuffersType = None
    ) -> None:
        """Open the frontend-side version of this comm

        Registers the comm with the comm manager returned by
        ``comm.get_comm_manager()`` and publishes a ``comm_open`` message.
        When ``data`` is None the ``_open_data`` given at construction is sent.

        Raises
        ------
        RuntimeError
            If no comm manager is available.
        Exception
            Whatever ``publish_msg`` raises; the comm is unregistered again
            before the exception propagates.
        """

        if data is None:
            data = self._open_data
        comm_manager = comm.get_comm_manager()
        if comm_manager is None:
            msg = "Comms cannot be opened without a comm_manager."  # type:ignore[unreachable]
            raise RuntimeError(msg)

        comm_manager.register_comm(self)
        try:
            self.publish_msg(
                "comm_open",
                data=data,
                metadata=metadata,
                buffers=buffers,
                target_name=self.target_name,
                target_module=self.target_module,
            )
            self._closed = False
        except Exception:
            # Roll back the registration so a half-opened comm is not kept.
            comm_manager.unregister_comm(self)
            raise

    def close(
        self,
        data: MaybeDict = None,
        metadata: MaybeDict = None,
        buffers: BuffersType = None,
        deleting: bool = False,
    ) -> None:
        """Close the frontend-side version of this comm

        Publishes a ``comm_close`` message; when ``data`` is None the
        ``_close_data`` given at construction is sent. Does nothing if the
        comm is already closed. Unless ``deleting`` is true, the comm is
        then unregistered from the comm manager.
        """
        if self._closed:
            # only close once
            return
        # Flip the flag before publishing so a re-entrant close() is a no-op.
        self._closed = True
        if data is None:
            data = self._close_data
        self.publish_msg(
            "comm_close",
            data=data,
            metadata=metadata,
            buffers=buffers,
        )
        if not deleting:
            # If deleting, the comm can't be registered
            comm.get_comm_manager().unregister_comm(self)

    def send(
        self, data: MaybeDict = None, metadata: MaybeDict = None, buffers: BuffersType = None
    ) -> None:
        """Send a message to the frontend-side version of this comm"""
        self.publish_msg(
            "comm_msg",
            data=data,
            metadata=metadata,
            buffers=buffers,
        )

    # registering callbacks

    def on_close(self, callback: CommCallback | None) -> None:
        """Register a callback for comm_close

        The callback is invoked by :meth:`handle_close` with the message
        it was given.

        Call `on_close(None)` to disable an existing callback.
        """
        self._close_callback = callback

    def on_msg(self, callback: CommCallback | None) -> None:
        """Register a callback for comm_msg

        The callback is invoked by :meth:`handle_msg` with the message
        it was given.

        Call `on_msg(None)` to disable an existing callback.
        """
        self._msg_callback = callback

    # handling of incoming messages

    def handle_close(self, msg: MessageType) -> None:
        """Handle a comm_close message

        Logs the message at debug level and passes it to the callback
        registered with :meth:`on_close`, if any.
        """
        logger.debug("handle_close[%s](%s)", self.comm_id, msg)
        if self._close_callback:
            self._close_callback(msg)

    def handle_msg(self, msg: MessageType) -> None:
        """Handle a comm_msg message

        Logs the message at debug level and passes it to the callback
        registered with :meth:`on_msg`, if any. When an IPython shell is
        active, its ``pre_execute`` and ``post_execute`` events are triggered
        around the callback.
        """
        logger.debug("handle_msg[%s](%s)", self.comm_id, msg)
        if self._msg_callback:
            # Imported lazily so IPython is only needed once a callback runs.
            from IPython import get_ipython

            shell = get_ipython()
            if shell:
                shell.events.trigger("pre_execute")
            self._msg_callback(msg)
            if shell:
                shell.events.trigger("post_execute")
 | 
						|
 | 
						|
 | 
						|
class CommManager:
    """Default CommManager singleton implementation for Comms in the Kernel"""

    # Public APIs

    def __init__(self) -> None:
        """Create a manager with no registered comms or targets."""
        # Open comms, keyed by comm_id.
        self.comms: dict[str, BaseComm] = {}
        # Target callbacks, keyed by target name.
        self.targets: dict[str, CommTargetCallback] = {}

    def register_target(self, target_name: str, f: CommTargetCallback | str) -> None:
        """Register a callable f for a given target name

        f will be called with two arguments when a comm_open message is
        received with a matching ``target_name``:

        - the Comm instance
        - the `comm_open` message itself.

        f can be a Python callable or an import string for one. A dotted
        string ``"pkg.mod.obj"`` is resolved to attribute ``obj`` of module
        ``pkg.mod``; an un-dotted string is imported as a module and that
        module object is registered.

        Raises
        ------
        ImportError
            If the module cannot be imported, or the attribute named by a
            dotted string does not exist on it.
        """
        if isinstance(f, str):
            parts = f.rsplit(".", 1)
            if len(parts) == 2:
                # called with 'foo.bar....'
                package, obj = parts
                module = __import__(package, fromlist=[obj])
                try:
                    f = getattr(module, obj)
                except AttributeError as e:
                    error_msg = f"No module named {obj}"
                    raise ImportError(error_msg) from e
            else:
                # called with un-dotted string
                f = __import__(parts[0])

        # The type is quoted so the cast never evaluates the alias at runtime.
        self.targets[target_name] = t.cast("CommTargetCallback", f)

    def unregister_target(self, target_name: str, f: CommTargetCallback) -> CommTargetCallback:  # noqa: ARG002
        """Unregister a callable registered with register_target

        ``f`` is accepted but not used; the callable stored under
        ``target_name`` is removed and returned.

        Raises
        ------
        KeyError
            If no target is registered under ``target_name``.
        """
        return self.targets.pop(target_name)

    def register_comm(self, comm: BaseComm) -> str:
        """Register a new comm and return its ``comm_id``"""
        comm_id = comm.comm_id
        self.comms[comm_id] = comm
        return comm_id

    def unregister_comm(self, comm: BaseComm) -> None:
        """Remove a comm from the registry

        Only the registry entry is dropped; no message is sent.

        Raises
        ------
        KeyError
            If the comm is not registered.
        """
        # unlike get_comm, this should raise a KeyError
        self.comms.pop(comm.comm_id)

    def get_comm(self, comm_id: str) -> BaseComm | None:
        """Get a comm with a particular id

        Returns the comm if found, otherwise None.

        This will not raise an error,
        it will log messages if the comm cannot be found.
        """
        try:
            return self.comms[comm_id]
        except KeyError:
            logger.warning("No such comm: %s", comm_id)
            if logger.isEnabledFor(logging.DEBUG):
                # don't create the list of keys if debug messages aren't enabled
                logger.debug("Current comms: %s", list(self.comms.keys()))
            return None

    # Message handlers

    def comm_open(self, stream: ZMQStream, ident: str, msg: MessageType) -> None:  # noqa: ARG002
        """Handler for comm_open messages

        Creates a non-primary comm for the ``comm_id`` / ``target_name`` in
        the message content, registers it, and calls the target callback
        with the comm and the message. If no target is registered, or the
        callback raises, the error is logged and the new comm is closed.
        """
        from comm import create_comm

        content = msg["content"]
        comm_id = content["comm_id"]
        target_name = content["target_name"]
        f = self.targets.get(target_name, None)
        comm = create_comm(
            comm_id=comm_id,
            primary=False,
            target_name=target_name,
        )
        self.register_comm(comm)
        if f is None:
            logger.error("No such comm target registered: %s", target_name)
        else:
            try:
                f(comm, msg)
                return
            except Exception:
                logger.exception("Exception opening comm with target: %s", target_name)

        # Failure.
        try:
            comm.close()
        except Exception:
            logger.exception(
                """Could not close comm during `comm_open` failure
                clean-up.  The comm may not have been opened yet.""",
            )

    def comm_msg(self, stream: ZMQStream, ident: str, msg: MessageType) -> None:  # noqa: ARG002
        """Handler for comm_msg messages

        Passes the message to the ``handle_msg`` method of the comm named
        in the content. Messages for unknown comms are dropped (``get_comm``
        logs them); exceptions raised by the comm are logged, not re-raised.
        """
        content = msg["content"]
        comm_id = content["comm_id"]
        comm = self.get_comm(comm_id)
        if comm is None:
            return

        try:
            comm.handle_msg(msg)
        except Exception:
            logger.exception("Exception in comm_msg for %s", comm_id)

    def comm_close(self, stream: ZMQStream, ident: str, msg: MessageType) -> None:  # noqa: ARG002
        """Handler for comm_close messages

        Marks the comm named in the content as closed, removes it from the
        registry, then calls its ``handle_close``. Messages for unknown
        comms are dropped; exceptions raised by the comm are logged, not
        re-raised.
        """
        content = msg["content"]
        comm_id = content["comm_id"]
        comm = self.get_comm(comm_id)
        if comm is None:
            return

        # Mark closed first so the comm does not publish its own comm_close
        # in response to one it has just received.
        self.comms[comm_id]._closed = True
        del self.comms[comm_id]

        try:
            comm.handle_close(msg)
        except Exception:
            logger.exception("Exception in comm_close for %s", comm_id)
 | 
						|
 | 
						|
 | 
						|
# Names exported by ``from <this module> import *``.
__all__ = ["BaseComm", "CommManager"]
 |