You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			145 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			145 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
"""Websocket and REST handlers for Jupyter Server events.

.. versionadded:: 2.0
"""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import json
 | 
						|
from datetime import datetime
 | 
						|
from typing import TYPE_CHECKING, Any, Optional, cast
 | 
						|
 | 
						|
from jupyter_core.utils import ensure_async
 | 
						|
from tornado import web, websocket
 | 
						|
 | 
						|
from jupyter_server.auth.decorator import authorized, ws_authenticated
 | 
						|
from jupyter_server.base.handlers import JupyterHandler
 | 
						|
 | 
						|
from ...base.handlers import APIHandler
 | 
						|
 | 
						|
# Resource name that both event handlers in this module expose as ``auth_resource``.
AUTH_RESOURCE: str = "events"
 | 
						|
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    import jupyter_events.logger
 | 
						|
 | 
						|
 | 
						|
class SubscribeWebsocket(
    JupyterHandler,
    websocket.WebSocketHandler,
):
    """Websocket handler for subscribing to events.

    While the connection is open, events delivered to ``event_listener`` by
    the server's event logger are serialized to JSON and pushed to the
    client. The listener is registered in ``open`` and removed in
    ``on_close``.
    """

    auth_resource = AUTH_RESOURCE

    async def pre_get(self):
        """Authorize the current user to subscribe to events.

        Requires ``execute`` permission on this handler's ``auth_resource``
        in order to subscribe to events emitted by Jupyter Server's event bus.

        Raises
        ------
        tornado.web.HTTPError
            403 if the authorizer denies access.
        """
        user = self.current_user
        # Read the resource from ``self.auth_resource`` rather than a literal.
        # This keeps the websocket check in sync with the resource that the
        # ``@authorized`` REST handlers use. It also avoids a local name that
        # shadows the imported ``authorized`` decorator.
        is_allowed = await ensure_async(
            self.authorizer.is_authorized(self, user, "execute", self.auth_resource)
        )
        if not is_allowed:
            raise web.HTTPError(403)

    @ws_authenticated
    async def get(self, *args, **kwargs):
        """Get an event socket: authorize first, then perform the websocket upgrade."""
        await ensure_async(self.pre_get())
        res = super().get(*args, **kwargs)
        # The parent ``get`` may hand back an awaitable; wait for it if so.
        if res is not None:
            await res

    async def event_listener(
        self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict[str, Any]
    ) -> None:
        """Write one event to the client as a JSON text message.

        The message is ``data`` with a ``schema_id`` key naming the event's
        schema, which is always placed first and always set to ``schema_id``.

        Parameters
        ----------
        logger : jupyter_events.logger.EventLogger
            The logger that emitted the event (unused here).
        schema_id : str
            Identifier of the event's schema.
        data : dict[str, Any]
            The event payload.
        """
        # ``dict(schema_id=..., **data)`` raises TypeError when the payload
        # already has a ``schema_id`` key, which drops the event. Instead,
        # merge the payload and then force the envelope key. The insertion
        # order (schema_id first) is the same as before.
        capsule = {"schema_id": schema_id, **data}
        capsule["schema_id"] = schema_id
        try:
            await self.write_message(json.dumps(capsule))
        except websocket.WebSocketClosedError:
            # The socket can close before ``on_close`` has removed this
            # listener. There is no client left to deliver to, so drop the event.
            pass

    def open(self):
        """Routes events that are emitted by Jupyter Server's
        EventBus to a WebSocket client in the browser.
        """
        self.event_logger.add_listener(listener=self.event_listener)

    def on_close(self):
        """Handle a socket close by unregistering this socket's listener."""
        # Each ``self.event_listener`` access creates a new bound-method
        # object. Bound methods of the same function and instance compare
        # (and hash) equal, so this matches the object registered in ``open``.
        self.event_logger.remove_listener(listener=self.event_listener)
 | 
						|
 | 
						|
 | 
						|
def validate_model(
    data: dict[str, Any], registry: jupyter_events.schema_registry.SchemaRegistry
) -> None:
    """Validates for required fields in the JSON request body and verifies that
    a registered schema/version exists.

    Parameters
    ----------
    data : dict[str, Any]
        The decoded JSON request body. It must contain ``schema_id``,
        ``version`` and ``data``.
    registry : jupyter_events.schema_registry.SchemaRegistry
        Registry used to look up the schema named by ``data["schema_id"]``.

    Raises
    ------
    ValueError
        If a required field is missing, or if the requested version does not
        match the registered schema's version. (``ValueError`` subclasses
        ``Exception``, so callers catching ``Exception`` are unaffected.)
    Exception
        Whatever ``registry.get`` raises for an unknown ``schema_id``.
    """
    # An ordered tuple, not a set: when several keys are missing, the one
    # reported is always the first in this order. It does not depend on
    # string-hash randomization.
    required_keys = ("schema_id", "version", "data")
    for key in required_keys:
        if key not in data:
            message = f"Missing `{key}` in the JSON request body."
            raise ValueError(message)
    schema_id = cast(str, data.get("schema_id"))
    # The case where a given schema_id isn't found,
    # jupyter_events raises a useful error, so there's no need to
    # handle that case here.
    schema = registry.get(schema_id)
    # Stringify both sides. Clients may send the version as a number, and
    # ``schema.version`` may not be a str either.
    # NOTE(review): its type depends on the jupyter_events release; confirm.
    version = str(data.get("version"))
    if str(schema.version) != version:
        message = f"Unregistered version: {version!r}≠{schema.version!r} for `{schema_id}`"
        raise ValueError(message)
 | 
						|
 | 
						|
 | 
						|
def get_timestamp(data: dict[str, Any]) -> Optional[datetime]:
    """Parses the optional ``timestamp`` field from the JSON request body.

    The original format, an offset followed by a literal ``Z``
    (``2022-05-26T13:50:00+05:00Z``), is tried first. Standard ISO 8601 /
    RFC 3339 timestamps are also accepted: either a UTC offset or a bare
    ``Z``, optionally with fractional seconds. Examples:
    ``2022-05-26T13:50:00+05:00``, ``2022-05-26T13:50:00Z`` and
    ``2022-05-26T13:50:00.123Z``.

    Parameters
    ----------
    data : dict[str, Any]
        The decoded JSON request body.

    Returns
    -------
    datetime or None
        A timezone-aware datetime, or ``None`` if no ``timestamp`` is present.

    Raises
    ------
    tornado.web.HTTPError
        400 if ``timestamp`` is present but is not a string in an accepted
        format.
    """
    if "timestamp" not in data:
        return None

    raw = data["timestamp"]
    # Legacy format first; ``%z`` in the others also accepts a bare "Z".
    formats = (
        "%Y-%m-%dT%H:%M:%S%zZ",
        "%Y-%m-%dT%H:%M:%S%z",
        "%Y-%m-%dT%H:%M:%S.%f%z",
    )
    error: Optional[Exception] = None
    for fmt in formats:
        try:
            return datetime.strptime(raw, fmt)
        except (TypeError, ValueError) as e:
            # TypeError: value is not a string; ValueError: format mismatch.
            error = e

    raise web.HTTPError(
        400,
        "Failed to parse timestamp from JSON request body, "
        "an ISO format datetime string with UTC offset is expected, "
        "for example, 2022-05-26T13:50:00+05:00Z",
    ) from error
 | 
						|
 | 
						|
 | 
						|
class EventHandler(APIHandler):
    """REST API handler for emitting events onto the server's event logger."""

    auth_resource = AUTH_RESOURCE

    @web.authenticated
    @authorized
    async def post(self):
        """Emit an event.

        The JSON body must be an object containing ``schema_id``, ``version``
        and ``data``, and may contain a ``timestamp``. On success the
        response is ``204 No Content``.

        Raises
        ------
        tornado.web.HTTPError
            400 if the body is missing, is not a JSON object, has a bad
            timestamp, or fails validation or emission.
        """
        payload = self.get_json_body()
        if payload is None:
            raise web.HTTPError(400, "No JSON data provided")
        # A list or scalar body would otherwise fail later with a confusing
        # attribute or type error message.
        if not isinstance(payload, dict):
            raise web.HTTPError(400, "JSON request body must be an object")

        try:
            validate_model(payload, self.event_logger.schemas)
            self.event_logger.emit(
                schema_id=cast(str, payload.get("schema_id")),
                data=cast("dict[str, Any]", payload.get("data")),
                timestamp_override=get_timestamp(payload),
            )
        except web.HTTPError:
            # get_timestamp already raises a 400 with a specific message.
            # Re-wrapping it would bury that message inside
            # "HTTP 400: Bad Request (...)".
            raise
        except Exception as e:
            # All known exceptions are raised by bad requests, e.g., bad
            # version, unregistered schema, invalid emission data payload, etc.
            raise web.HTTPError(400, str(e)) from e

        # Outside the try, so a failure while finishing the response is not
        # misreported as a client error.
        self.set_status(204)
        self.finish()
 | 
						|
 | 
						|
 | 
						|
# Tornado (pattern, handler) routes for this service: the REST endpoint that
# emits events and the websocket endpoint that streams them to subscribers.
default_handlers = [
    ("/api/events", EventHandler),
    ("/api/events/subscribe", SubscribeWebsocket),
]
 |