You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			831 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			831 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
"""A kernel gateway client."""
 | 
						|
 | 
						|
# Copyright (c) Jupyter Development Team.
 | 
						|
# Distributed under the terms of the Modified BSD License.
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import asyncio
 | 
						|
import json
 | 
						|
import logging
 | 
						|
import os
 | 
						|
import typing as ty
 | 
						|
from abc import ABC, ABCMeta, abstractmethod
 | 
						|
from datetime import datetime, timezone
 | 
						|
from email.utils import parsedate_to_datetime
 | 
						|
from http.cookies import SimpleCookie
 | 
						|
from socket import gaierror
 | 
						|
 | 
						|
from jupyter_events import EventLogger
 | 
						|
from tornado import web
 | 
						|
from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
 | 
						|
from traitlets import (
 | 
						|
    Bool,
 | 
						|
    Float,
 | 
						|
    Instance,
 | 
						|
    Int,
 | 
						|
    TraitError,
 | 
						|
    Type,
 | 
						|
    Unicode,
 | 
						|
    default,
 | 
						|
    observe,
 | 
						|
    validate,
 | 
						|
)
 | 
						|
from traitlets.config import LoggingConfigurable, SingletonConfigurable
 | 
						|
 | 
						|
from jupyter_server import DEFAULT_EVENTS_SCHEMA_PATH, JUPYTER_SERVER_EVENTS_URI
 | 
						|
 | 
						|
ERROR_STATUS = "error"
 | 
						|
SUCCESS_STATUS = "success"
 | 
						|
STATUS_KEY = "status"
 | 
						|
STATUS_CODE_KEY = "status_code"
 | 
						|
MESSAGE_KEY = "msg"
 | 
						|
 | 
						|
if ty.TYPE_CHECKING:
 | 
						|
    from http.cookies import Morsel
 | 
						|
 | 
						|
 | 
						|
class GatewayTokenRenewerMeta(ABCMeta, type(LoggingConfigurable)):  # type: ignore[misc]
 | 
						|
    """The metaclass necessary for proper ABC behavior in a Configurable."""
 | 
						|
 | 
						|
 | 
						|
class GatewayTokenRenewerBase(  # type:ignore[misc]
 | 
						|
    ABC, LoggingConfigurable, metaclass=GatewayTokenRenewerMeta
 | 
						|
):
 | 
						|
    """
 | 
						|
    Abstract base class for refreshing tokens used between this server and a Gateway
 | 
						|
    server.  Implementations requiring additional configuration can extend their class
 | 
						|
    with appropriate configuration values or convey those values via appropriate
 | 
						|
    environment variables relative to the implementation.
 | 
						|
    """
 | 
						|
 | 
						|
    @abstractmethod
 | 
						|
    def get_token(
 | 
						|
        self,
 | 
						|
        auth_header_key: str,
 | 
						|
        auth_scheme: ty.Union[str, None],
 | 
						|
        auth_token: str,
 | 
						|
        **kwargs: ty.Any,
 | 
						|
    ) -> str:
 | 
						|
        """
 | 
						|
        Given the current authorization header key, scheme, and token, this method returns
 | 
						|
        a (potentially renewed) token for use against the Gateway server.
 | 
						|
        """
 | 
						|
 | 
						|
 | 
						|
class NoOpTokenRenewer(GatewayTokenRenewerBase):  # type:ignore[misc]
 | 
						|
    """NoOpTokenRenewer is the default value to the GatewayClient trait
 | 
						|
    `gateway_token_renewer` and merely returns the provided token.
 | 
						|
    """
 | 
						|
 | 
						|
    def get_token(
 | 
						|
        self,
 | 
						|
        auth_header_key: str,
 | 
						|
        auth_scheme: ty.Union[str, None],
 | 
						|
        auth_token: str,
 | 
						|
        **kwargs: ty.Any,
 | 
						|
    ) -> str:
 | 
						|
        """This implementation simply returns the current authorization token."""
 | 
						|
        return auth_token
 | 
						|
 | 
						|
 | 
						|
class GatewayClient(SingletonConfigurable):
 | 
						|
    """This class manages the configuration.  It's its own singleton class so
 | 
						|
    that we can share these values across all objects.  It also contains some
 | 
						|
    options.
 | 
						|
    helper methods to build request arguments out of the various config
 | 
						|
    """
 | 
						|
 | 
						|
    event_schema_id = JUPYTER_SERVER_EVENTS_URI + "/gateway_client/v1"
 | 
						|
    event_logger = Instance(EventLogger).tag(config=True)
 | 
						|
 | 
						|
    @default("event_logger")
 | 
						|
    def _default_event_logger(self):
 | 
						|
        if self.parent and hasattr(self.parent, "event_logger"):
 | 
						|
            # Event logger is attached from serverapp.
 | 
						|
            return self.parent.event_logger
 | 
						|
        else:
 | 
						|
            # If parent does not have an event logger, create one.
 | 
						|
            logger = EventLogger()
 | 
						|
            schema_path = DEFAULT_EVENTS_SCHEMA_PATH / "gateway_client" / "v1.yaml"
 | 
						|
            logger.register_event_schema(schema_path)
 | 
						|
            self.log.info("Event is registered in GatewayClient.")
 | 
						|
            return logger
 | 
						|
 | 
						|
    def emit(self, data):
 | 
						|
        """Emit event using the core event schema from Jupyter Server's Gateway Client."""
 | 
						|
        self.event_logger.emit(schema_id=self.event_schema_id, data=data)
 | 
						|
 | 
						|
    url = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The url of the Kernel or Enterprise Gateway server where
 | 
						|
kernel specifications are defined and kernel management takes place.
 | 
						|
If defined, this Notebook server acts as a proxy for all kernel
 | 
						|
management and kernel specification retrieval.  (JUPYTER_GATEWAY_URL env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
 | 
						|
    url_env = "JUPYTER_GATEWAY_URL"
 | 
						|
 | 
						|
    @default("url")
 | 
						|
    def _url_default(self):
 | 
						|
        return os.environ.get(self.url_env)
 | 
						|
 | 
						|
    @validate("url")
 | 
						|
    def _url_validate(self, proposal):
 | 
						|
        value = proposal["value"]
 | 
						|
        # Ensure value, if present, starts with 'http'
 | 
						|
        if value is not None and len(value) > 0 and not str(value).lower().startswith("http"):
 | 
						|
            message = "GatewayClient url must start with 'http': '%r'" % value
 | 
						|
            self.emit(data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 400, MESSAGE_KEY: message})
 | 
						|
            raise TraitError(message)
 | 
						|
        return value
 | 
						|
 | 
						|
    ws_url = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The websocket url of the Kernel or Enterprise Gateway server.  If not provided, this value
 | 
						|
will correspond to the value of the Gateway url with 'ws' in place of 'http'.  (JUPYTER_GATEWAY_WS_URL env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
 | 
						|
    ws_url_env = "JUPYTER_GATEWAY_WS_URL"
 | 
						|
 | 
						|
    @default("ws_url")
 | 
						|
    def _ws_url_default(self):
 | 
						|
        default_value = os.environ.get(self.ws_url_env)
 | 
						|
        if self.url is not None and default_value is None and self.gateway_enabled:
 | 
						|
            default_value = self.url.lower().replace("http", "ws")
 | 
						|
        return default_value
 | 
						|
 | 
						|
    @validate("ws_url")
 | 
						|
    def _ws_url_validate(self, proposal):
 | 
						|
        value = proposal["value"]
 | 
						|
        # Ensure value, if present, starts with 'ws'
 | 
						|
        if value is not None and len(value) > 0 and not str(value).lower().startswith("ws"):
 | 
						|
            message = "GatewayClient ws_url must start with 'ws': '%r'" % value
 | 
						|
            self.emit(data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 400, MESSAGE_KEY: message})
 | 
						|
            raise TraitError(message)
 | 
						|
        return value
 | 
						|
 | 
						|
    kernels_endpoint_default_value = "/api/kernels"
 | 
						|
    kernels_endpoint_env = "JUPYTER_GATEWAY_KERNELS_ENDPOINT"
 | 
						|
    kernels_endpoint = Unicode(
 | 
						|
        default_value=kernels_endpoint_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("kernels_endpoint")
 | 
						|
    def _kernels_endpoint_default(self):
 | 
						|
        return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value)
 | 
						|
 | 
						|
    kernelspecs_endpoint_default_value = "/api/kernelspecs"
 | 
						|
    kernelspecs_endpoint_env = "JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT"
 | 
						|
    kernelspecs_endpoint = Unicode(
 | 
						|
        default_value=kernelspecs_endpoint_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("kernelspecs_endpoint")
 | 
						|
    def _kernelspecs_endpoint_default(self):
 | 
						|
        return os.environ.get(
 | 
						|
            self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value
 | 
						|
        )
 | 
						|
 | 
						|
    kernelspecs_resource_endpoint_default_value = "/kernelspecs"
 | 
						|
    kernelspecs_resource_endpoint_env = "JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT"
 | 
						|
    kernelspecs_resource_endpoint = Unicode(
 | 
						|
        default_value=kernelspecs_resource_endpoint_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The gateway endpoint for accessing kernelspecs resources
 | 
						|
(JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("kernelspecs_resource_endpoint")
 | 
						|
    def _kernelspecs_resource_endpoint_default(self):
 | 
						|
        return os.environ.get(
 | 
						|
            self.kernelspecs_resource_endpoint_env,
 | 
						|
            self.kernelspecs_resource_endpoint_default_value,
 | 
						|
        )
 | 
						|
 | 
						|
    connect_timeout_default_value = 40.0
 | 
						|
    connect_timeout_env = "JUPYTER_GATEWAY_CONNECT_TIMEOUT"
 | 
						|
    connect_timeout = Float(
 | 
						|
        default_value=connect_timeout_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The time allowed for HTTP connection establishment with the Gateway server.
 | 
						|
(JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("connect_timeout")
 | 
						|
    def _connect_timeout_default(self):
 | 
						|
        return float(os.environ.get(self.connect_timeout_env, self.connect_timeout_default_value))
 | 
						|
 | 
						|
    request_timeout_default_value = 42.0
 | 
						|
    request_timeout_env = "JUPYTER_GATEWAY_REQUEST_TIMEOUT"
 | 
						|
    request_timeout = Float(
 | 
						|
        default_value=request_timeout_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("request_timeout")
 | 
						|
    def _request_timeout_default(self):
 | 
						|
        return float(os.environ.get(self.request_timeout_env, self.request_timeout_default_value))
 | 
						|
 | 
						|
    client_key = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The filename for client SSL key, if any.  (JUPYTER_GATEWAY_CLIENT_KEY env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
    client_key_env = "JUPYTER_GATEWAY_CLIENT_KEY"
 | 
						|
 | 
						|
    @default("client_key")
 | 
						|
    def _client_key_default(self):
 | 
						|
        return os.environ.get(self.client_key_env)
 | 
						|
 | 
						|
    client_cert = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The filename for client SSL certificate, if any.  (JUPYTER_GATEWAY_CLIENT_CERT env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
    client_cert_env = "JUPYTER_GATEWAY_CLIENT_CERT"
 | 
						|
 | 
						|
    @default("client_cert")
 | 
						|
    def _client_cert_default(self):
 | 
						|
        return os.environ.get(self.client_cert_env)
 | 
						|
 | 
						|
    ca_certs = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The filename of CA certificates or None to use defaults.  (JUPYTER_GATEWAY_CA_CERTS env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
    ca_certs_env = "JUPYTER_GATEWAY_CA_CERTS"
 | 
						|
 | 
						|
    @default("ca_certs")
 | 
						|
    def _ca_certs_default(self):
 | 
						|
        return os.environ.get(self.ca_certs_env)
 | 
						|
 | 
						|
    http_user = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
    http_user_env = "JUPYTER_GATEWAY_HTTP_USER"
 | 
						|
 | 
						|
    @default("http_user")
 | 
						|
    def _http_user_default(self):
 | 
						|
        return os.environ.get(self.http_user_env)
 | 
						|
 | 
						|
    http_pwd = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The password for HTTP authentication.  (JUPYTER_GATEWAY_HTTP_PWD env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
    http_pwd_env = "JUPYTER_GATEWAY_HTTP_PWD"  # noqa: S105
 | 
						|
 | 
						|
    @default("http_pwd")
 | 
						|
    def _http_pwd_default(self):
 | 
						|
        return os.environ.get(self.http_pwd_env)
 | 
						|
 | 
						|
    headers_default_value = "{}"
 | 
						|
    headers_env = "JUPYTER_GATEWAY_HEADERS"
 | 
						|
    headers = Unicode(
 | 
						|
        default_value=headers_default_value,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""Additional HTTP headers to pass on the request.  This value will be converted to a dict.
 | 
						|
          (JUPYTER_GATEWAY_HEADERS env var)
 | 
						|
        """,
 | 
						|
    )
 | 
						|
 | 
						|
    @default("headers")
 | 
						|
    def _headers_default(self):
 | 
						|
        return os.environ.get(self.headers_env, self.headers_default_value)
 | 
						|
 | 
						|
    auth_header_key_default_value = "Authorization"
 | 
						|
    auth_header_key = Unicode(
 | 
						|
        config=True,
 | 
						|
        help="""The authorization header's key name (typically 'Authorization') used in the HTTP headers. The
 | 
						|
header will be formatted as::
 | 
						|
 | 
						|
{'{auth_header_key}': '{auth_scheme} {auth_token}'}
 | 
						|
 | 
						|
If the authorization header key takes a single value, `auth_scheme` should be set to None and
 | 
						|
'auth_token' should be configured to use the appropriate value.
 | 
						|
 | 
						|
(JUPYTER_GATEWAY_AUTH_HEADER_KEY env var)""",
 | 
						|
    )
 | 
						|
    auth_header_key_env = "JUPYTER_GATEWAY_AUTH_HEADER_KEY"
 | 
						|
 | 
						|
    @default("auth_header_key")
 | 
						|
    def _auth_header_key_default(self):
 | 
						|
        return os.environ.get(self.auth_header_key_env, self.auth_header_key_default_value)
 | 
						|
 | 
						|
    auth_token_default_value = ""
 | 
						|
    auth_token = Unicode(
 | 
						|
        default_value=None,
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The authorization token used in the HTTP headers. The header will be formatted as::
 | 
						|
 | 
						|
{'{auth_header_key}': '{auth_scheme} {auth_token}'}
 | 
						|
 | 
						|
(JUPYTER_GATEWAY_AUTH_TOKEN env var)""",
 | 
						|
    )
 | 
						|
    auth_token_env = "JUPYTER_GATEWAY_AUTH_TOKEN"  # noqa: S105
 | 
						|
 | 
						|
    @default("auth_token")
 | 
						|
    def _auth_token_default(self):
 | 
						|
        return os.environ.get(self.auth_token_env, self.auth_token_default_value)
 | 
						|
 | 
						|
    auth_scheme_default_value = "token"  # This value is purely for backwards compatibility
 | 
						|
    auth_scheme = Unicode(
 | 
						|
        allow_none=True,
 | 
						|
        config=True,
 | 
						|
        help="""The auth scheme, added as a prefix to the authorization token used in the HTTP headers.
 | 
						|
(JUPYTER_GATEWAY_AUTH_SCHEME env var)""",
 | 
						|
    )
 | 
						|
    auth_scheme_env = "JUPYTER_GATEWAY_AUTH_SCHEME"
 | 
						|
 | 
						|
    @default("auth_scheme")
 | 
						|
    def _auth_scheme_default(self):
 | 
						|
        return os.environ.get(self.auth_scheme_env, self.auth_scheme_default_value)
 | 
						|
 | 
						|
    validate_cert_default_value = True
 | 
						|
    validate_cert_env = "JUPYTER_GATEWAY_VALIDATE_CERT"
 | 
						|
    validate_cert = Bool(
 | 
						|
        default_value=validate_cert_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""For HTTPS requests, determines if server's certificate should be validated or not.
 | 
						|
(JUPYTER_GATEWAY_VALIDATE_CERT env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("validate_cert")
 | 
						|
    def _validate_cert_default(self):
 | 
						|
        return bool(
 | 
						|
            os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value))
 | 
						|
            not in ["no", "false"]
 | 
						|
        )
 | 
						|
 | 
						|
    allowed_envs_default_value = ""
 | 
						|
    allowed_envs_env = "JUPYTER_GATEWAY_ALLOWED_ENVS"
 | 
						|
    allowed_envs = Unicode(
 | 
						|
        default_value=allowed_envs_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""A comma-separated list of environment variable names that will be included, along with
 | 
						|
their values, in the kernel startup request.  The corresponding `client_envs` configuration
 | 
						|
value must also be set on the Gateway server - since that configuration value indicates which
 | 
						|
environmental values to make available to the kernel. (JUPYTER_GATEWAY_ALLOWED_ENVS env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("allowed_envs")
 | 
						|
    def _allowed_envs_default(self):
 | 
						|
        return os.environ.get(
 | 
						|
            self.allowed_envs_env,
 | 
						|
            os.environ.get("JUPYTER_GATEWAY_ENV_WHITELIST", self.allowed_envs_default_value),
 | 
						|
        )
 | 
						|
 | 
						|
    env_whitelist = Unicode(
 | 
						|
        default_value=allowed_envs_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""Deprecated, use `GatewayClient.allowed_envs`""",
 | 
						|
    )
 | 
						|
 | 
						|
    gateway_retry_interval_default_value = 1.0
 | 
						|
    gateway_retry_interval_env = "JUPYTER_GATEWAY_RETRY_INTERVAL"
 | 
						|
    gateway_retry_interval = Float(
 | 
						|
        default_value=gateway_retry_interval_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The time allowed for HTTP reconnection with the Gateway server for the first time.
 | 
						|
Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries
 | 
						|
but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX.
 | 
						|
(JUPYTER_GATEWAY_RETRY_INTERVAL env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("gateway_retry_interval")
 | 
						|
    def _gateway_retry_interval_default(self):
 | 
						|
        return float(
 | 
						|
            os.environ.get(
 | 
						|
                self.gateway_retry_interval_env,
 | 
						|
                self.gateway_retry_interval_default_value,
 | 
						|
            )
 | 
						|
        )
 | 
						|
 | 
						|
    gateway_retry_interval_max_default_value = 30.0
 | 
						|
    gateway_retry_interval_max_env = "JUPYTER_GATEWAY_RETRY_INTERVAL_MAX"
 | 
						|
    gateway_retry_interval_max = Float(
 | 
						|
        default_value=gateway_retry_interval_max_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The maximum time allowed for HTTP reconnection retry with the Gateway server.
 | 
						|
(JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("gateway_retry_interval_max")
 | 
						|
    def _gateway_retry_interval_max_default(self):
 | 
						|
        return float(
 | 
						|
            os.environ.get(
 | 
						|
                self.gateway_retry_interval_max_env,
 | 
						|
                self.gateway_retry_interval_max_default_value,
 | 
						|
            )
 | 
						|
        )
 | 
						|
 | 
						|
    gateway_retry_max_default_value = 5
 | 
						|
    gateway_retry_max_env = "JUPYTER_GATEWAY_RETRY_MAX"
 | 
						|
    gateway_retry_max = Int(
 | 
						|
        default_value=gateway_retry_max_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""The maximum retries allowed for HTTP reconnection with the Gateway server.
 | 
						|
(JUPYTER_GATEWAY_RETRY_MAX env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("gateway_retry_max")
 | 
						|
    def _gateway_retry_max_default(self):
 | 
						|
        return int(os.environ.get(self.gateway_retry_max_env, self.gateway_retry_max_default_value))
 | 
						|
 | 
						|
    gateway_token_renewer_class_default_value = (
 | 
						|
        "jupyter_server.gateway.gateway_client.NoOpTokenRenewer"  # noqa: S105
 | 
						|
    )
 | 
						|
    gateway_token_renewer_class_env = "JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS"  # noqa: S105
 | 
						|
    gateway_token_renewer_class = Type(
 | 
						|
        klass=GatewayTokenRenewerBase,
 | 
						|
        config=True,
 | 
						|
        help="""The class to use for Gateway token renewal. (JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("gateway_token_renewer_class")
 | 
						|
    def _gateway_token_renewer_class_default(self):
 | 
						|
        return os.environ.get(
 | 
						|
            self.gateway_token_renewer_class_env, self.gateway_token_renewer_class_default_value
 | 
						|
        )
 | 
						|
 | 
						|
    launch_timeout_pad_default_value = 2.0
 | 
						|
    launch_timeout_pad_env = "JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD"
 | 
						|
    launch_timeout_pad = Float(
 | 
						|
        default_value=launch_timeout_pad_default_value,
 | 
						|
        config=True,
 | 
						|
        help="""Timeout pad to be ensured between KERNEL_LAUNCH_TIMEOUT and request_timeout
 | 
						|
such that request_timeout >= KERNEL_LAUNCH_TIMEOUT + launch_timeout_pad.
 | 
						|
(JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("launch_timeout_pad")
 | 
						|
    def _launch_timeout_pad_default(self):
 | 
						|
        return float(
 | 
						|
            os.environ.get(
 | 
						|
                self.launch_timeout_pad_env,
 | 
						|
                self.launch_timeout_pad_default_value,
 | 
						|
            )
 | 
						|
        )
 | 
						|
 | 
						|
    accept_cookies_value = False
 | 
						|
    accept_cookies_env = "JUPYTER_GATEWAY_ACCEPT_COOKIES"
 | 
						|
    accept_cookies = Bool(
 | 
						|
        default_value=accept_cookies_value,
 | 
						|
        config=True,
 | 
						|
        help="""Accept and manage cookies sent by the service side. This is often useful
 | 
						|
        for load balancers to decide which backend node to use.
 | 
						|
        (JUPYTER_GATEWAY_ACCEPT_COOKIES env var)""",
 | 
						|
    )
 | 
						|
 | 
						|
    @default("accept_cookies")
 | 
						|
    def _accept_cookies_default(self):
 | 
						|
        return bool(
 | 
						|
            os.environ.get(self.accept_cookies_env, str(self.accept_cookies_value).lower())
 | 
						|
            not in ["no", "false"]
 | 
						|
        )
 | 
						|
 | 
						|
    _deprecated_traits = {
 | 
						|
        "env_whitelist": ("allowed_envs", "2.0"),
 | 
						|
    }
 | 
						|
 | 
						|
    # Method copied from
 | 
						|
    # https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161
 | 
						|
    @observe(*list(_deprecated_traits))
 | 
						|
    def _deprecated_trait(self, change):
 | 
						|
        """observer for deprecated traits"""
 | 
						|
        old_attr = change.name
 | 
						|
        new_attr, version = self._deprecated_traits[old_attr]
 | 
						|
        new_value = getattr(self, new_attr)
 | 
						|
        if new_value != change.new:
 | 
						|
            # only warn if different
 | 
						|
            # protects backward-compatible config from warnings
 | 
						|
            # if they set the same value under both names
 | 
						|
            self.log.warning(
 | 
						|
                f"{self.__class__.__name__}.{old_attr} is deprecated in jupyter_server "
 | 
						|
                f"{version}, use {self.__class__.__name__}.{new_attr} instead"
 | 
						|
            )
 | 
						|
            setattr(self, new_attr, change.new)
 | 
						|
 | 
						|
    @property
 | 
						|
    def gateway_enabled(self):
 | 
						|
        return bool(self.url is not None and len(self.url) > 0)
 | 
						|
 | 
						|
    # Ensure KERNEL_LAUNCH_TIMEOUT has a default value.
 | 
						|
    KERNEL_LAUNCH_TIMEOUT = int(os.environ.get("KERNEL_LAUNCH_TIMEOUT", 40))
 | 
						|
 | 
						|
    _connection_args: dict[str, ty.Any]  # initialized on first use
 | 
						|
 | 
						|
    gateway_token_renewer: GatewayTokenRenewerBase
 | 
						|
 | 
						|
    def __init__(self, **kwargs):
 | 
						|
        """Initialize a gateway client."""
 | 
						|
        super().__init__(**kwargs)
 | 
						|
        self._connection_args = {}  # initialized on first use
 | 
						|
        self.gateway_token_renewer = self.gateway_token_renewer_class(parent=self, log=self.log)  # type:ignore[abstract]
 | 
						|
 | 
						|
        # store of cookies with store time
 | 
						|
        self._cookies: dict[str, tuple[Morsel[ty.Any], datetime]] = {}
 | 
						|
 | 
						|
    def init_connection_args(self):
 | 
						|
        """Initialize arguments used on every request.  Since these are primarily static values,
 | 
						|
        we'll perform this operation once.
 | 
						|
        """
 | 
						|
        # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are in sync, taking the
 | 
						|
        #  greater value of the two and taking into account the following relation:
 | 
						|
        #  request_timeout = KERNEL_LAUNCH_TIME + padding
 | 
						|
        minimum_request_timeout = (
 | 
						|
            float(GatewayClient.KERNEL_LAUNCH_TIMEOUT) + self.launch_timeout_pad
 | 
						|
        )
 | 
						|
        if self.request_timeout < minimum_request_timeout:
 | 
						|
            self.request_timeout = minimum_request_timeout
 | 
						|
        elif self.request_timeout > minimum_request_timeout:
 | 
						|
            GatewayClient.KERNEL_LAUNCH_TIMEOUT = int(
 | 
						|
                self.request_timeout - self.launch_timeout_pad
 | 
						|
            )
 | 
						|
        # Ensure any adjustments are reflected in env.
 | 
						|
        os.environ["KERNEL_LAUNCH_TIMEOUT"] = str(GatewayClient.KERNEL_LAUNCH_TIMEOUT)
 | 
						|
 | 
						|
        if self.headers:
 | 
						|
            self._connection_args["headers"] = json.loads(self.headers)
 | 
						|
            if self.auth_header_key not in self._connection_args["headers"]:
 | 
						|
                self._connection_args["headers"].update(
 | 
						|
                    {f"{self.auth_header_key}": f"{self.auth_scheme} {self.auth_token}"}
 | 
						|
                )
 | 
						|
        self._connection_args["connect_timeout"] = self.connect_timeout
 | 
						|
        self._connection_args["request_timeout"] = self.request_timeout
 | 
						|
        self._connection_args["validate_cert"] = self.validate_cert
 | 
						|
        if self.client_cert:
 | 
						|
            self._connection_args["client_cert"] = self.client_cert
 | 
						|
            self._connection_args["client_key"] = self.client_key
 | 
						|
            if self.ca_certs:
 | 
						|
                self._connection_args["ca_certs"] = self.ca_certs
 | 
						|
        if self.http_user:
 | 
						|
            self._connection_args["auth_username"] = self.http_user
 | 
						|
        if self.http_pwd:
 | 
						|
            self._connection_args["auth_password"] = self.http_pwd
 | 
						|
 | 
						|
    def load_connection_args(self, **kwargs):
 | 
						|
        """Merges the static args relative to the connection, with the given keyword arguments.  If static
 | 
						|
        args have yet to be initialized, we'll do that here.
 | 
						|
 | 
						|
        """
 | 
						|
        if len(self._connection_args) == 0:
 | 
						|
            self.init_connection_args()
 | 
						|
 | 
						|
        # Give token renewal a shot at renewing the token
 | 
						|
        prev_auth_token = self.auth_token
 | 
						|
        if self.auth_token is not None:
 | 
						|
            try:
 | 
						|
                self.auth_token = self.gateway_token_renewer.get_token(
 | 
						|
                    self.auth_header_key, self.auth_scheme, self.auth_token
 | 
						|
                )
 | 
						|
            except Exception as ex:
 | 
						|
                self.log.error(
 | 
						|
                    f"An exception occurred attempting to renew the "
 | 
						|
                    f"Gateway authorization token using an instance of class "
 | 
						|
                    f"'{self.gateway_token_renewer_class}'.  The request will "
 | 
						|
                    f"proceed using the current token value.  Exception was: {ex}"
 | 
						|
                )
 | 
						|
                self.auth_token = prev_auth_token
 | 
						|
 | 
						|
        for arg, value in self._connection_args.items():
 | 
						|
            if arg == "headers":
 | 
						|
                given_value = kwargs.setdefault(arg, {})
 | 
						|
                if isinstance(given_value, dict):
 | 
						|
                    given_value.update(value)
 | 
						|
                    # Ensure the auth header is current
 | 
						|
                    given_value.update(
 | 
						|
                        {f"{self.auth_header_key}": f"{self.auth_scheme} {self.auth_token}"}
 | 
						|
                    )
 | 
						|
            else:
 | 
						|
                kwargs[arg] = value
 | 
						|
 | 
						|
        if self.accept_cookies:
 | 
						|
            self._update_cookie_header(kwargs)
 | 
						|
 | 
						|
        return kwargs
 | 
						|
 | 
						|
    def update_cookies(self, cookie: SimpleCookie) -> None:
 | 
						|
        """Update cookies from existing requests for load balancers"""
 | 
						|
        if not self.accept_cookies:
 | 
						|
            return
 | 
						|
 | 
						|
        store_time = datetime.now(tz=timezone.utc)
 | 
						|
        for key, item in cookie.items():
 | 
						|
            # Convert "expires" arg into "max-age" to facilitate expiration management.
 | 
						|
            # As "max-age" has precedence, ignore "expires" when "max-age" exists.
 | 
						|
            if item.get("expires") and not item.get("max-age"):
 | 
						|
                expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time
 | 
						|
                item["max-age"] = str(expire_timedelta.total_seconds())
 | 
						|
 | 
						|
            self._cookies[key] = (item, store_time)
 | 
						|
 | 
						|
    def _clear_expired_cookies(self) -> None:
 | 
						|
        """Clear expired cookies."""
 | 
						|
        check_time = datetime.now(tz=timezone.utc)
 | 
						|
        expired_keys = []
 | 
						|
 | 
						|
        for key, (morsel, store_time) in self._cookies.items():
 | 
						|
            cookie_max_age = morsel.get("max-age")
 | 
						|
            if not cookie_max_age:
 | 
						|
                continue
 | 
						|
            expired_timedelta = check_time - store_time
 | 
						|
            if expired_timedelta.total_seconds() > float(cookie_max_age):
 | 
						|
                expired_keys.append(key)
 | 
						|
 | 
						|
        for key in expired_keys:
 | 
						|
            self._cookies.pop(key)
 | 
						|
 | 
						|
    def _update_cookie_header(self, connection_args: dict[str, ty.Any]) -> None:
 | 
						|
        """Update a cookie header."""
 | 
						|
        self._clear_expired_cookies()
 | 
						|
 | 
						|
        gateway_cookie_values = "; ".join(
 | 
						|
            f"{name}={morsel.coded_value}" for name, (morsel, _time) in self._cookies.items()
 | 
						|
        )
 | 
						|
        if gateway_cookie_values:
 | 
						|
            headers = connection_args.get("headers", {})
 | 
						|
 | 
						|
            # As headers are case-insensitive, we get existing name of cookie header,
 | 
						|
            #  or use "Cookie" by default.
 | 
						|
            cookie_header_name = next(
 | 
						|
                (header_key for header_key in headers if header_key.lower() == "cookie"),
 | 
						|
                "Cookie",
 | 
						|
            )
 | 
						|
            existing_cookie = headers.get(cookie_header_name)
 | 
						|
 | 
						|
            # merge gateway-managed cookies with cookies already in arguments
 | 
						|
            if existing_cookie:
 | 
						|
                gateway_cookie_values = existing_cookie + "; " + gateway_cookie_values
 | 
						|
            headers[cookie_header_name] = gateway_cookie_values
 | 
						|
 | 
						|
            connection_args["headers"] = headers
 | 
						|
 | 
						|
 | 
						|
class RetryableHTTPClient:
 | 
						|
    """
 | 
						|
    Inspired by urllib.util.Retry (https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html),
 | 
						|
    this class is initialized with desired retry characteristics, uses a recursive method `fetch()` against an instance
 | 
						|
    of `AsyncHTTPClient` which tracks the current retry count across applicable request retries.
 | 
						|
    """
 | 
						|
 | 
						|
    MAX_RETRIES_DEFAULT = 2
 | 
						|
    MAX_RETRIES_CAP = 10  # The upper limit to max_retries value.
 | 
						|
    max_retries: int = int(os.getenv("JUPYTER_GATEWAY_MAX_REQUEST_RETRIES", MAX_RETRIES_DEFAULT))
 | 
						|
    max_retries = max(0, min(max_retries, MAX_RETRIES_CAP))  # Enforce boundaries
 | 
						|
    retried_methods: set[str] = {"GET", "DELETE"}
 | 
						|
    retried_errors: set[int] = {502, 503, 504, 599}
 | 
						|
    retried_exceptions: set[type] = {ConnectionError}
 | 
						|
    backoff_factor: float = 0.1
 | 
						|
 | 
						|
    def __init__(self):
 | 
						|
        """Initialize the retryable http client."""
 | 
						|
        self.retry_count: int = 0
 | 
						|
        self.client: AsyncHTTPClient = AsyncHTTPClient()
 | 
						|
 | 
						|
    async def fetch(self, endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
 | 
						|
        """
 | 
						|
        Retryable AsyncHTTPClient.fetch() method.  When the request fails, this method will
 | 
						|
        recurse up to max_retries times if the condition deserves a retry.
 | 
						|
        """
 | 
						|
        self.retry_count = 0
 | 
						|
        return await self._fetch(endpoint, **kwargs)
 | 
						|
 | 
						|
    async def _fetch(self, endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
 | 
						|
        """
 | 
						|
        Performs the fetch against the contained AsyncHTTPClient instance and determines
 | 
						|
        if retry is necessary on any exceptions.  If so, retry is performed recursively.
 | 
						|
        """
 | 
						|
        try:
 | 
						|
            response: HTTPResponse = await self.client.fetch(endpoint, **kwargs)
 | 
						|
        except Exception as e:
 | 
						|
            is_retryable: bool = await self._is_retryable(kwargs["method"], e)
 | 
						|
            if not is_retryable:
 | 
						|
                raise e
 | 
						|
            logging.getLogger("ServerApp").info(
 | 
						|
                f"Attempting retry ({self.retry_count}) against "
 | 
						|
                f"endpoint '{endpoint}'.  Retried error: '{e!r}'"
 | 
						|
            )
 | 
						|
            response = await self._fetch(endpoint, **kwargs)
 | 
						|
        return response
 | 
						|
 | 
						|
    async def _is_retryable(self, method: str, exception: Exception) -> bool:
 | 
						|
        """Determines if the given exception is retryable based on object's configuration."""
 | 
						|
 | 
						|
        if method not in self.retried_methods:
 | 
						|
            return False
 | 
						|
        if self.retry_count == self.max_retries:
 | 
						|
            return False
 | 
						|
 | 
						|
        # Determine if error is retryable...
 | 
						|
        if isinstance(exception, HTTPClientError):
 | 
						|
            hce: HTTPClientError = exception
 | 
						|
            if hce.code not in self.retried_errors:
 | 
						|
                return False
 | 
						|
        elif not any(isinstance(exception, error) for error in self.retried_exceptions):
 | 
						|
            return False
 | 
						|
 | 
						|
        # Is retryable, wait for backoff, then increment count
 | 
						|
        await asyncio.sleep(self.backoff_factor * (2**self.retry_count))
 | 
						|
        self.retry_count += 1
 | 
						|
        return True
 | 
						|
 | 
						|
 | 
						|
async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
 | 
						|
    """Make an async request to kernel gateway endpoint, returns a response"""
 | 
						|
    gateway_client = GatewayClient.instance()
 | 
						|
    kwargs = gateway_client.load_connection_args(**kwargs)
 | 
						|
    rhc = RetryableHTTPClient()
 | 
						|
    try:
 | 
						|
        response = await rhc.fetch(endpoint, **kwargs)
 | 
						|
        gateway_client.emit(
 | 
						|
            data={STATUS_KEY: SUCCESS_STATUS, STATUS_CODE_KEY: 200, MESSAGE_KEY: "success"}
 | 
						|
        )
 | 
						|
    # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect
 | 
						|
    # or the server is not running.
 | 
						|
    # NOTE: We do this here since this handler is called during the server's startup and subsequent refreshes
 | 
						|
    # of the tree view.
 | 
						|
    except HTTPClientError as e:
 | 
						|
        gateway_client.emit(
 | 
						|
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: e.code, MESSAGE_KEY: str(e.message)}
 | 
						|
        )
 | 
						|
        error_reason = (
 | 
						|
            f"Exception while attempting to connect to Gateway server url '{gateway_client.url}'"
 | 
						|
        )
 | 
						|
        error_message = e.message
 | 
						|
        if e.response:
 | 
						|
            try:
 | 
						|
                error_payload = json.loads(e.response.body)
 | 
						|
                error_reason = error_payload.get("reason") or error_reason
 | 
						|
                error_message = error_payload.get("message") or error_message
 | 
						|
            except json.decoder.JSONDecodeError:
 | 
						|
                error_reason = e.response.body.decode()
 | 
						|
 | 
						|
        raise web.HTTPError(
 | 
						|
            e.code,
 | 
						|
            f"Error from Gateway: [{error_message}] {error_reason}. "
 | 
						|
            "Ensure gateway url is valid and the Gateway instance is running.",
 | 
						|
        ) from e
 | 
						|
    except ConnectionError as e:
 | 
						|
        gateway_client.emit(
 | 
						|
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 503, MESSAGE_KEY: str(e)}
 | 
						|
        )
 | 
						|
        raise web.HTTPError(
 | 
						|
            503,
 | 
						|
            f"ConnectionError was received from Gateway server url '{gateway_client.url}'.  "
 | 
						|
            "Check to be sure the Gateway instance is running.",
 | 
						|
        ) from e
 | 
						|
    except gaierror as e:
 | 
						|
        gateway_client.emit(
 | 
						|
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 404, MESSAGE_KEY: str(e)}
 | 
						|
        )
 | 
						|
        raise web.HTTPError(
 | 
						|
            404,
 | 
						|
            f"The Gateway server specified in the gateway_url '{gateway_client.url}' doesn't "
 | 
						|
            f"appear to be valid.  Ensure gateway url is valid and the Gateway instance is running.",
 | 
						|
        ) from e
 | 
						|
    except Exception as e:
 | 
						|
        gateway_client.emit(
 | 
						|
            data={STATUS_KEY: ERROR_STATUS, STATUS_CODE_KEY: 505, MESSAGE_KEY: str(e)}
 | 
						|
        )
 | 
						|
        logging.getLogger("ServerApp").error(
 | 
						|
            "Exception while trying to launch kernel via Gateway URL %s: %s",
 | 
						|
            gateway_client.url,
 | 
						|
            e,
 | 
						|
        )
 | 
						|
        raise e
 | 
						|
 | 
						|
    if gateway_client.accept_cookies:
 | 
						|
        # Update cookies on GatewayClient from server if configured.
 | 
						|
        cookie_values = response.headers.get("Set-Cookie")
 | 
						|
        if cookie_values:
 | 
						|
            cookie: SimpleCookie = SimpleCookie()
 | 
						|
            cookie.load(cookie_values)
 | 
						|
            gateway_client.update_cookies(cookie)
 | 
						|
    return response
 |