You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
3.5 KiB
Python
96 lines
3.5 KiB
Python
"""Tornado handlers for WebSocket <-> ZMQ sockets."""
|
|
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
from jupyter_core.utils import ensure_async
|
|
from tornado import web
|
|
from tornado.websocket import WebSocketHandler
|
|
|
|
from jupyter_server.auth.decorator import ws_authenticated
|
|
from jupyter_server.base.handlers import JupyterHandler
|
|
from jupyter_server.base.websocket import WebSocketMixin
|
|
|
|
# Resource name that KernelWebsocketHandler assigns to its ``auth_resource`` attribute.
AUTH_RESOURCE = "kernels"
|
|
|
|
|
|
class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler):  # type:ignore[misc]
    """Websocket handler that delegates kernel traffic to a connection object.

    The handler itself only authorizes the request and wires up the
    websocket lifecycle callbacks.  Connecting, message handling and
    disconnecting are delegated to an instance of
    ``kernel_websocket_connection_class`` stored on ``self.connection``.
    """

    auth_resource = AUTH_RESOURCE

    @property
    def kernel_websocket_connection_class(self):
        """The kernel websocket connection class, looked up in the settings."""
        return self.settings.get("kernel_websocket_connection_class")

    def set_default_headers(self):
        """Undo the set_default_headers in JupyterHandler,
        which doesn't make sense for websockets.
        """

    def get_compression_options(self):
        """Return the ``websocket_compression_options`` setting, or None if unset."""
        return self.settings.get("websocket_compression_options", None)

    async def pre_get(self):
        """Authorize the user and build the kernel connection before the handshake.

        Raises
        ------
        tornado.web.HTTPError
            403 if the authorizer denies the "execute" action on this
            handler's ``auth_resource``.
        """
        user = self.current_user

        # Authorize the user against the resource declared on the class
        # instead of repeating the literal resource name here.
        authorized = await ensure_async(
            self.authorizer.is_authorized(self, user, "execute", self.auth_resource)
        )
        if not authorized:
            raise web.HTTPError(403)

        kernel = self.kernel_manager.get_kernel(self.kernel_id)
        self.connection = self.kernel_websocket_connection_class(
            parent=kernel, websocket_handler=self, config=self.config
        )

        # Read the query argument once; an empty or missing value is treated
        # as "no session ID".
        session_id = self.get_argument("session_id", None)
        if session_id:
            self.connection.session.session = session_id
        else:
            self.log.warning("No session ID specified")

        # For backwards compatibility with older versions
        # of the websocket connection, call a prepare method if found.
        if hasattr(self.connection, "prepare"):
            await self.connection.prepare()

    @ws_authenticated
    async def get(self, kernel_id):
        """Handle a get request for a kernel.

        Records ``kernel_id`` on the handler, runs :meth:`pre_get`, then
        defers to the parent ``get``.
        """
        self.kernel_id = kernel_id
        await self.pre_get()
        await super().get(kernel_id=kernel_id)

    async def open(self, kernel_id):
        """Open a kernel websocket and connect it to the kernel."""
        # Need to call super here to make sure we
        # begin a ping-pong loop with the client.
        super().open()
        # Wait for the kernel to emit an idle status.
        self.log.info("Connecting to kernel %s.", self.kernel_id)
        await self.connection.connect()

    def on_message(self, ws_message):
        """Get a kernel message from the websocket and turn it into a ZMQ message."""
        self.connection.handle_incoming_message(ws_message)

    def on_close(self):
        """Handle a socket closure.

        Disconnects the kernel connection if one exists and clears the
        reference.  The socket may close before ``pre_get`` has created the
        connection, or after it was already cleared, so a missing or ``None``
        connection is tolerated rather than raising ``AttributeError``.
        """
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.disconnect()
        self.connection = None

    def select_subprotocol(self, subprotocols):
        """Select the sub protocol for the socket.

        The connection's ``kernel_ws_protocol`` is the preferred protocol:
        ``None`` means "use the v1 protocol" and the empty string means
        "no preference".  The preferred protocol is returned only when the
        client offered it in ``subprotocols``; otherwise ``None`` is returned.
        """
        preferred_protocol = self.connection.kernel_ws_protocol
        if preferred_protocol is None:
            preferred_protocol = "v1.kernel.websocket.jupyter.org"
        elif preferred_protocol == "":
            preferred_protocol = None
        selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
        # None is the default, "legacy" protocol
        return selected_subprotocol
|