You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
145 lines
4.4 KiB
Python
145 lines
4.4 KiB
Python
"""REST and WebSocket handlers for emitting and subscribing to Jupyter Server events.
|
|
|
|
.. versionadded:: 2.0
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
|
|
from jupyter_core.utils import ensure_async
|
|
from tornado import web, websocket
|
|
|
|
from jupyter_server.auth.decorator import authorized, ws_authenticated
|
|
from jupyter_server.base.handlers import JupyterHandler
|
|
|
|
from ...base.handlers import APIHandler
|
|
|
|
# Authorization resource name used by both the REST and websocket handlers below.
AUTH_RESOURCE: str = "events"
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
import jupyter_events.logger
|
|
|
|
|
|
class SubscribeWebsocket(
    JupyterHandler,
    websocket.WebSocketHandler,
):
    """Websocket handler for subscribing to events.

    Each open connection registers :meth:`event_listener` with
    ``self.event_logger`` and forwards every event it receives to the client
    as a JSON-encoded message; the listener is removed again in
    :meth:`on_close`.
    """

    auth_resource = AUTH_RESOURCE

    async def pre_get(self):
        """Handles authorization when
        attempting to subscribe to events emitted by
        Jupyter Server's eventbus.

        :raises web.HTTPError: 403 if the current user is not authorized to
            ``execute`` on this handler's auth resource.
        """
        user = self.current_user
        # authorize the user.
        # NOTE: named ``is_allowed`` so it does not shadow the module-level
        # ``authorized`` decorator import.
        is_allowed = await ensure_async(
            self.authorizer.is_authorized(self, user, "execute", self.auth_resource)
        )
        if not is_allowed:
            raise web.HTTPError(403)

    @ws_authenticated
    async def get(self, *args, **kwargs):
        """Get an event socket.

        Runs :meth:`pre_get` first, then delegates to the websocket
        handshake of the parent class, awaiting its result when it is not
        ``None``.
        """
        await ensure_async(self.pre_get())
        res = super().get(*args, **kwargs)
        if res is not None:
            await res

    async def event_listener(
        self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict[str, Any]
    ) -> None:
        """Write an event message.

        The message is ``data`` with an added ``schema_id`` key, serialized
        with ``json.dumps``.
        """
        capsule = dict(schema_id=schema_id, **data)
        try:
            # Awaiting the write surfaces failures that happen while the frame
            # is being flushed instead of leaving a failed, unretrieved future.
            await self.write_message(json.dumps(capsule))
        except websocket.WebSocketClosedError:
            # The client disconnected before on_close() removed this listener;
            # there is nobody left to deliver the event to.
            pass

    def open(self):
        """Routes events that are emitted by Jupyter Server's
        EventBus to a WebSocket client in the browser.
        """
        self.event_logger.add_listener(listener=self.event_listener)

    def on_close(self):
        """Handle a socket close by unregistering this connection's listener."""
        self.event_logger.remove_listener(listener=self.event_listener)
|
|
|
|
|
|
def validate_model(
    data: dict[str, Any], registry: jupyter_events.schema_registry.SchemaRegistry
) -> None:
    """Validates for required fields in the JSON request body and verifies that
    a registered schema/version exists.

    :param data: the decoded JSON request body.
    :param registry: registry used to look up the schema named by ``schema_id``.
    :raises Exception: if a required key is missing, or if ``version`` does not
        match the registered schema's version. Errors for an unknown
        ``schema_id`` propagate from ``registry.get``.
    """
    # A tuple rather than a set so that, when several keys are missing, the
    # reported one is always the same (set order varies with hash seeding).
    required_keys = ("schema_id", "version", "data")
    for key in required_keys:
        if key not in data:
            message = f"Missing `{key}` in the JSON request body."
            raise Exception(message)
    schema_id = cast(str, data.get("schema_id"))
    # The case where a given schema_id isn't found,
    # jupyter_events raises a useful error, so there's no need to
    # handle that case here.
    schema = registry.get(schema_id)
    # Compare both sides as strings: clients may send the version as a JSON
    # number, and the registered version may not be a str depending on the
    # jupyter_events release (TODO confirm for the pinned version).
    version = str(data.get("version"))
    if str(schema.version) != version:
        message = f"Unregistered version: {version!r}≠{schema.version!r} for `{schema_id}`"
        raise Exception(message)
|
|
|
|
|
|
def _parse_timestamp(value: Any) -> datetime:
    """Parse an ISO 8601 string carrying a UTC offset into an aware datetime.

    Raises ``TypeError`` for non-string input and ``ValueError`` when the
    string is malformed or has no UTC offset.
    """
    if not isinstance(value, str):
        raise TypeError(f"expected a string, got {type(value).__name__}")
    try:
        # Legacy format this endpoint has always accepted: an explicit offset
        # followed by a literal "Z", e.g. 2022-05-26T13:50:00+05:00Z.
        return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S%zZ")
    except ValueError:
        pass
    # datetime.fromisoformat() only understands a trailing "Z" from Python
    # 3.11 on, so spell it out as an explicit +00:00 offset.
    text = value[:-1] + "+00:00" if value.endswith("Z") else value
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        raise ValueError(f"timestamp {value!r} has no UTC offset")
    return parsed


def get_timestamp(data: dict[str, Any]) -> Optional[datetime]:
    """Parses timestamp from the JSON request body.

    Accepts ISO format datetime strings with a UTC offset, such as
    ``2022-05-26T13:50:00+05:00``, ``2022-05-26T13:50:00Z`` or
    ``2022-05-26T13:50:00.123+00:00``. The legacy form with a trailing ``Z``
    after an explicit offset (``2022-05-26T13:50:00+05:00Z``) is still
    accepted.

    :param data: the decoded JSON request body.
    :returns: a timezone-aware datetime, or ``None`` when the body has no
        ``timestamp`` key.
    :raises web.HTTPError: 400 if ``timestamp`` is present but cannot be
        parsed or carries no UTC offset.
    """
    try:
        timestamp = _parse_timestamp(data["timestamp"]) if "timestamp" in data else None
    except Exception as e:
        raise web.HTTPError(
            400,
            """Failed to parse timestamp from JSON request body,
an ISO format datetime string with UTC offset is expected,
for example, 2022-05-26T13:50:00+05:00Z""",
        ) from e

    return timestamp
|
|
|
|
|
|
class EventHandler(APIHandler):
    """REST api handler for events"""

    auth_resource = AUTH_RESOURCE

    @web.authenticated
    @authorized
    async def post(self):
        """Emit an event.

        The JSON body must contain ``schema_id``, ``version`` and ``data`` and
        may contain ``timestamp``. On success the event is passed to
        ``self.event_logger.emit`` and the response is ``204``.

        :raises web.HTTPError: 400 if no JSON body is provided or the event is
            rejected (bad version, unregistered schema, invalid emission data,
            unparsable timestamp, ...).
        """
        payload = self.get_json_body()
        if payload is None:
            raise web.HTTPError(400, "No JSON data provided")

        try:
            validate_model(payload, self.event_logger.schemas)
            self.event_logger.emit(
                schema_id=cast(str, payload.get("schema_id")),
                data=cast("dict[str, Any]", payload.get("data")),
                timestamp_override=get_timestamp(payload),
            )
        except web.HTTPError:
            # Already a well-formed client error (e.g. from get_timestamp);
            # re-wrapping it would nest its message inside a second 400.
            raise
        except Exception as e:
            # All known exceptions are raised by bad requests, e.g., bad
            # version, unregistered schema, invalid emission data payload, etc.
            raise web.HTTPError(400, str(e)) from e

        # Only reached once the event was accepted; kept outside the try so a
        # failure while finishing is not misreported as a client error.
        self.set_status(204)
        self.finish()
|
|
|
|
|
|
# Route table exported by this module: the REST endpoint that emits events and
# the websocket endpoint that streams them to subscribers.
default_handlers = [
    ("/api/events", EventHandler),
    ("/api/events/subscribe", SubscribeWebsocket),
]
|