You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
6.6 KiB
Python
211 lines
6.6 KiB
Python
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
"""Testing utils."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
from http.cookies import SimpleCookie
|
|
from pathlib import Path
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
import tornado.httpclient
|
|
import tornado.web
|
|
from openapi_core import V30RequestValidator, V30ResponseValidator
|
|
from openapi_core.spec.paths import Spec
|
|
from openapi_core.validation.request.datatypes import RequestParameters
|
|
from tornado.httpclient import HTTPRequest, HTTPResponse
|
|
from werkzeug.datastructures import Headers, ImmutableMultiDict
|
|
|
|
from jupyterlab_server.spec import get_openapi_spec
|
|
|
|
# Directory containing this module, with symlinks resolved.
HERE = Path(os.path.dirname(__file__)).resolve()

# The "comment" value of the unicode-extension plugin entry in the
# app-settings overrides fixture (presumably used by tests as a large
# non-ASCII payload; confirm against callers).
big_unicode_string = json.loads(
    (HERE / "test_data" / "app-settings" / "overrides.json").read_text(encoding="utf-8")
)["@jupyterlab/unicode-extension:plugin"]["comment"]
|
|
|
|
|
|
class TornadoOpenAPIRequest:
    """
    Converts a tornado request to an OpenAPI one.

    Wraps a tornado ``HTTPRequest`` and exposes ``parameters``, ``host_url``,
    ``path``, ``method``, ``body``, ``mimetype`` and ``content_type`` for the
    ``openapi_core`` request validator.
    """

    def __init__(self, request: HTTPRequest, spec: Spec):
        """Initialize the request.

        Parameters
        ----------
        request:
            The tornado client request to adapt.
        spec:
            The OpenAPI spec; its ``"paths"`` mapping is used to resolve
            :attr:`path`.

        Raises
        ------
        RuntimeError
            If the request has no URL.
        """
        self.request = request
        self.spec = spec
        if request.url is None:
            msg = "Request URL is missing"  # type:ignore[unreachable]
            raise RuntimeError(msg)
        # Parse the URL once; the query string is used below and the path by
        # the ``path`` property.
        self._url_parsed = urlparse(request.url)

        # A client request sends its cookies in the ``Cookie`` header;
        # ``Set-Cookie`` is a response header and never appears on a request.
        cookie: SimpleCookie = SimpleCookie()
        cookie.load(request.headers.get("Cookie", ""))
        cookies = {key: morsel.value for key, morsel in cookie.items()}

        # Path parameters get deduced by the path finder against the spec.
        path: dict = {}

        self.parameters = RequestParameters(
            query=ImmutableMultiDict(parse_qs(self._url_parsed.query)),
            header=dict(request.headers),
            cookie=ImmutableMultiDict(cookies),
            path=path,
        )

    @property
    def content_type(self) -> str:
        """Content type reported to the validator (always JSON)."""
        return "application/json"

    @property
    def host_url(self) -> str:
        """The request URL up to (not including) the first ``/lab``.

        Raises ``ValueError`` (from ``str.index``) if ``/lab`` is absent.
        """
        url = self.request.url
        return url[: url.index("/lab")]

    @property
    def path(self) -> str:
        """Return the first spec path that matches the request URL.

        Works around the lack of support for path parameters which can contain
        slashes (https://github.com/OAI/OpenAPI-Specification/issues/892): a
        templated spec path is matched on its literal prefix, and the rest of
        the URL is replaced by the dummy segment ``foo``.

        Raises
        ------
        ValueError
            If no path in the spec matches the request URL.
        """
        request_path = self._url_parsed.path
        url: str | None = None
        for path_ in self.spec["paths"]:
            has_arg = "{" in path_
            prefix = path_[: path_.index("{")] if has_arg else path_
            if prefix not in request_path:
                continue
            tail = request_path[request_path.index(prefix) :]
            if not has_arg and len(tail) == len(prefix):
                url = tail
            elif has_arg and not tail.endswith("/"):
                url = tail[: len(prefix)] + "foo"
            if url is not None:
                # First match in spec order wins.
                break

        if url is None:
            msg = f"Could not find matching pattern for {request_path}"
            raise ValueError(msg)
        return url

    @property
    def method(self) -> str:
        """The lower-cased HTTP method, or ``""`` when it is unset/empty."""
        method = self.request.method
        return method.lower() if method else ""

    @property
    def body(self) -> bytes | None:
        """The raw request body, or ``None`` when there is none.

        Raises ``AssertionError`` if the body is present but not ``bytes``.
        """
        if self.request.body is None:
            return None  # type:ignore[unreachable]
        if not isinstance(self.request.body, bytes):
            msg = "Request body is invalid"  # type:ignore[unreachable]
            raise AssertionError(msg)
        return self.request.body

    @property
    def mimetype(self) -> str:
        """The Content-Type header, else the Accept header, else JSON."""
        # Order matters because all tornado requests
        # include Accept */* which does not necessarily match the content type
        request = self.request
        return (
            request.headers.get("Content-Type")
            or request.headers.get("Accept")
            or "application/json"
        )
|
|
|
|
|
|
class TornadoOpenAPIResponse:
    """Adapter presenting a tornado HTTP response to the OpenAPI response validator."""

    def __init__(self, response: HTTPResponse):
        """Keep a reference to the wrapped tornado *response*."""
        self.response = response

    @property
    def data(self) -> bytes | None:
        """The raw response body; raises ``AssertionError`` unless it is ``bytes``."""
        payload = self.response.body
        if isinstance(payload, bytes):
            return payload
        msg = "Response body is invalid"  # type:ignore[unreachable]
        raise AssertionError(msg)

    @property
    def status_code(self) -> int:
        """The HTTP status code, coerced to ``int``."""
        code = self.response.code
        return int(code)

    @property
    def content_type(self) -> str:
        """Content type reported to the validator (always JSON)."""
        return "application/json"

    @property
    def mimetype(self) -> str:
        """The response Content-Type header, defaulting to JSON when absent."""
        header_value = self.response.headers.get("Content-Type", "application/json")
        return str(header_value)

    @property
    def headers(self) -> Headers:
        """The response headers as a werkzeug ``Headers`` object."""
        # NOTE(review): going through dict() may collapse repeated header
        # fields into one entry — confirm if multi-valued headers matter.
        flat_headers = dict(self.response.headers)
        return Headers(flat_headers)
|
|
|
|
|
|
def validate_request(response: HTTPResponse) -> None:
    """Validate an API exchange against the server's OpenAPI spec.

    Both the request that produced *response* (``response.request``) and the
    response itself are checked, using the OpenAPI 3.0 validators from
    ``openapi_core``; any validation failure surfaces from their ``validate``
    calls.
    """
    spec = get_openapi_spec()

    # Request side first; the response check reuses the same adapted request.
    openapi_request = TornadoOpenAPIRequest(response.request, spec)
    request_validator = V30RequestValidator(spec)
    request_validator.validate(openapi_request)

    openapi_response = TornadoOpenAPIResponse(response)
    response_validator = V30ResponseValidator(spec)
    response_validator.validate(openapi_request, openapi_response)
|
|
|
|
|
|
def maybe_patch_ioloop() -> None:
    """Swap in the selector event-loop policy where tornado needs it.

    Only acts on Windows, with tornado older than 6.1, on Python 3.8+: if the
    current asyncio policy is exactly ``WindowsProactorEventLoopPolicy`` it is
    replaced by ``WindowsSelectorEventLoopPolicy``. Otherwise it does nothing.
    """
    affected = (
        sys.platform.startswith("win")
        and tornado.version_info < (6, 1)
        and sys.version_info >= (3, 8)
    )
    if not affected:
        return

    try:
        from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
    except ImportError:
        # not affected
        return

    from asyncio import get_event_loop_policy, set_event_loop_policy

    # Exact type check (not isinstance): only the stock Proactor policy is swapped.
    if type(get_event_loop_policy()) is not WindowsProactorEventLoopPolicy:
        return
    # WindowsProactorEventLoopPolicy is not compatible with tornado 6;
    # fall back to the pre-3.8 default of Selector.
    set_event_loop_policy(WindowsSelectorEventLoopPolicy())
|
|
|
|
|
|
def expected_http_error(
    error: Exception, expected_code: int, expected_message: str | None = None
) -> bool:
    """Check that the error matches the expected output error.

    Parameters
    ----------
    error:
        Despite the annotation, its ``.value`` attribute is read as the raised
        exception (presumably a ``pytest.ExceptionInfo`` from
        ``pytest.raises`` — confirm against callers).
    expected_code:
        The HTTP status code the exception must carry.
    expected_message:
        Optional expected message. For ``tornado.web.HTTPError`` it is
        compared with ``str(e)``; for client errors it is compared with the
        ``"message"`` field of the JSON response body.

    Returns
    -------
    bool
        ``True`` when the exception matches. ``False`` otherwise, including for
        any other exception type, and for client errors whose response is
        missing or whose body is not a JSON object with a ``"message"`` field.
    """
    e = error.value  # type:ignore[attr-defined]

    if isinstance(e, tornado.web.HTTPError):
        if expected_code != e.status_code:
            return False
        return expected_message is None or expected_message == str(e)

    if isinstance(e, (tornado.httpclient.HTTPClientError, tornado.httpclient.HTTPError)):
        if expected_code != e.code:
            return False
        if not expected_message:
            return True
        # A client error may have no response attached, so there is no body to
        # compare against; treat that as a mismatch rather than crashing.
        if e.response is None:
            return False
        try:
            message = json.loads(e.response.body.decode())["message"]
        except (ValueError, KeyError, TypeError):
            # Undecodable / non-JSON body, or JSON without a "message" field.
            return False
        return expected_message == message

    return False
|