You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
5.8 KiB
Python
168 lines
5.8 KiB
Python
"""Base websocket classes."""
|
|
|
|
import re
|
|
import warnings
|
|
from typing import Optional, no_type_check
|
|
from urllib.parse import urlparse
|
|
|
|
from tornado import ioloop, web
|
|
from tornado.iostream import IOStream
|
|
|
|
from jupyter_server.base.handlers import JupyterHandler
|
|
from jupyter_server.utils import JupyterServerAuthWarning
|
|
|
|
# Default time between websocket keep-alive pings, in milliseconds
# (30 seconds); handlers may override it with the ``ws_ping_interval`` setting.
WS_PING_INTERVAL = 30_000
|
|
|
|
|
|
class WebSocketMixin:
    """Mixin for common websocket options.

    Provides origin checking, authentication for handlers that do not inherit
    from ``JupyterHandler``, and a periodic keep-alive ping with a pong
    timeout. It calls ``super().prepare`` / ``super().open`` and relies on
    ``self.settings``, ``self.request``, ``self.log``, ``self.ws_connection``,
    ``self.ping`` and ``self.close``, so it must precede a tornado websocket
    handler class in the MRO.
    """

    # PeriodicCallback that drives send_ping; created in open() when pings are enabled.
    ping_callback = None
    # IOLoop timestamps (seconds) of the last ping sent and the last pong received.
    last_ping = 0.0
    last_pong = 0.0
    stream: Optional[IOStream] = None

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Read from the ``ws_ping_interval`` setting, defaulting to
        ``WS_PING_INTERVAL``. Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)  # type:ignore[attr-defined]

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds, close the connection.

        VPNs, etc. can fail to cleanly close ws connections, so a silent peer
        is treated as gone. Read from the ``ws_ping_timeout`` setting; the
        default is the max of 3 ping intervals or 30 seconds.
        """
        return self.settings.get(  # type:ignore[attr-defined]
            "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    @no_type_check
    def check_origin(self, origin: Optional[str] = None) -> bool:
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        Parameters
        ----------
        origin : str, optional
            The origin to check. When omitted, ``self.get_origin()`` is used.

        Returns
        -------
        bool
            True if the websocket request may proceed.
        """
        if self.allow_origin == "*" or (
            hasattr(self, "skip_check_origin") and self.skip_check_origin()
        ):
            return True

        host = self.request.headers.get("Host")
        if origin is None:
            origin = self.get_origin()

        # If no origin or host header is provided, assume from script
        if origin is None or host is None:
            return True

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host. Host names are case-insensitive and the
        # origin was lowercased above, so normalize the Host header the same way.
        if origin_host == host.lower():
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            # NOTE(review): re.match anchors only at the start of the origin, so
            # a pattern without a trailing "$" also accepts longer origins.
            allow = bool(re.match(self.allow_origin_pat, origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning(
                "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin,
                host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""

    @no_type_check
    def _maybe_auth(self):
        """Verify authentication if required.

        Only used when the websocket class does not inherit from JupyterHandler.

        Raises
        ------
        tornado.web.HTTPError
            403 when authentication is required and either the request has no
            method or there is no authenticated ``current_user``.
        """
        if not self.settings.get("allow_unauthenticated_access", False):
            if not self.request.method:
                raise web.HTTPError(403)
            method = getattr(self, self.request.method.lower())
            # A handler method can opt out of authentication via this marker attribute.
            if not getattr(method, "__allow_unauthenticated", False):
                # rather than re-using `web.authenticated` which also redirects
                # to login page on GET, just raise 403 if user is not known
                user = self.current_user
                if user is None:
                    self.log.warning("Couldn't authenticate WebSocket connection")
                    raise web.HTTPError(403)

    @no_type_check
    def prepare(self, *args, **kwargs):
        """Run authentication checks before the websocket request is handled.

        Handlers that do not inherit from ``JupyterHandler`` are authenticated
        here via ``_maybe_auth``. A warning is emitted when a custom identity
        provider is configured, because it cannot be honored on that path.
        ``JupyterHandler`` subclasses delegate to their own ``prepare`` without
        redirecting to the login page.
        """
        if not isinstance(self, JupyterHandler):
            should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
            if "identity_provider" in self.settings and should_authenticate:
                warnings.warn(
                    "WebSocketMixin sub-class does not inherit from JupyterHandler"
                    " preventing proper authentication using custom identity provider.",
                    JupyterServerAuthWarning,
                    stacklevel=2,
                )
            self._maybe_auth()
            return super().prepare(*args, **kwargs)
        return super().prepare(*args, **kwargs, _redirect_to_login=False)

    @no_type_check
    def open(self, *args, **kwargs):
        """Open the websocket and, if enabled, start the keep-alive pings."""
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping,
                self.ping_interval,
            )
            self.ping_callback.start()
        return super().open(*args, **kwargs)

    @no_type_check
    def send_ping(self):
        """Send a ping to keep the websocket alive.

        Stops the periodic callback once the connection is gone. Closes the
        socket if the client has terminated, or if no pong arrived within
        ``ping_timeout``. Otherwise sends an empty ping frame.
        """
        if self.ws_connection is None:
            # The connection is gone: stop pinging (if pinging was started)
            # rather than dereferencing ``None`` below.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        if self.ws_connection.client_terminated:
            self.close()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b"")
        self.last_ping = now

    def on_pong(self, data):
        """Handle a pong message by recording the time it arrived."""
        self.last_pong = ioloop.IOLoop.current().time()