You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			357 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			357 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
from __future__ import annotations
 | 
						|
 | 
						|
import logging
 | 
						|
import re
 | 
						|
import threading
 | 
						|
import types
 | 
						|
import typing
 | 
						|
 | 
						|
import h2.config  # type: ignore[import-untyped]
 | 
						|
import h2.connection  # type: ignore[import-untyped]
 | 
						|
import h2.events  # type: ignore[import-untyped]
 | 
						|
 | 
						|
from .._base_connection import _TYPE_BODY
 | 
						|
from .._collections import HTTPHeaderDict
 | 
						|
from ..connection import HTTPSConnection, _get_default_user_agent
 | 
						|
from ..exceptions import ConnectionError
 | 
						|
from ..response import BaseHTTPResponse
 | 
						|
 | 
						|
orig_HTTPSConnection = HTTPSConnection
 | 
						|
 | 
						|
T = typing.TypeVar("T")
 | 
						|
 | 
						|
log = logging.getLogger(__name__)
 | 
						|
 | 
						|
# Legal HTTP/2 field names: one or more lowercase token characters.
# `\Z` (not `$`) anchors at the true end of the input: `$` also matches just
# before a trailing newline, which would let a name such as b"x-foo\n" pass.
RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+\Z")
# Matches NUL, CR or LF at any position, or whitespace at the very start or end.
RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\0\x00\x0a\x0d\r\n]|^[ \r\n\t]|[ \r\n\t]$")
 | 
						|
 | 
						|
 | 
						|
def _is_legal_header_name(name: bytes) -> bool:
    """
    Report whether ``name`` is acceptable as a regular HTTP/2 field name.

    RFC 9113 states: "An implementation that validates fields according to the
    definitions in Sections 5.1 and 5.5 of [HTTP] only needs an additional check
    that field names do not include uppercase characters."
    (https://httpwg.org/specs/rfc9113.html#n-field-validity)

    `http.client._is_legal_header_name` does not validate the field name according
    to the HTTP 1.1 spec, so both the token check and the no-uppercase check are
    performed here through ``RE_IS_LEGAL_HEADER_NAME``.

    The ``:`` character is not accepted, so this must not be used to validate
    pseudo-headers.
    """
    matched = RE_IS_LEGAL_HEADER_NAME.match(name)
    return matched is not None
 | 
						|
 | 
						|
 | 
						|
def _is_illegal_header_value(value: bytes) -> bool:
    """
    Report whether ``value`` is forbidden as an HTTP/2 field value.

    RFC 9113 states: "A field value MUST NOT contain the zero value (ASCII NUL,
    0x00), line feed (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any
    position. A field value MUST NOT start or end with an ASCII whitespace
    character (ASCII SP or HTAB, 0x20 or 0x09)."
    (https://httpwg.org/specs/rfc9113.html#n-field-validity)

    Returns True when ``RE_IS_ILLEGAL_HEADER_VALUE`` finds a violation anywhere
    in ``value``.
    """
    violation = RE_IS_ILLEGAL_HEADER_VALUE.search(value)
    return violation is not None
 | 
						|
 | 
						|
 | 
						|
class _LockedObject(typing.Generic[T]):
    """
    Guard a single object with a re-entrant lock.

    Some objects cannot safely be touched from several threads at once. Wrap
    such an object in this class and use the wrapper as a context manager:
    entering acquires the lock and yields the wrapped object, leaving releases
    the lock again.
    """

    __slots__ = ("lock", "_obj")

    def __init__(self, obj: T):
        """Store ``obj`` and create the ``threading.RLock`` that protects it."""
        self._obj = obj
        self.lock = threading.RLock()

    def __enter__(self) -> T:
        """Acquire the lock, then hand out the protected object."""
        self.lock.acquire()
        protected = self._obj
        return protected

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release the lock; returns None, so exceptions are not suppressed."""
        self.lock.release()
 | 
						|
 | 
						|
 | 
						|
class HTTP2Connection(HTTPSConnection):
    """
    An ``HTTPSConnection`` that exchanges HTTP/2 frames through an ``h2``
    connection state machine instead of writing HTTP/1.1 text.

    The ``h2`` connection is kept behind a ``_LockedObject`` and every access
    goes through ``with self._h2_conn as conn``. One request at a time is
    tracked: ``putrequest`` picks the stream id, ``putheader`` accumulates
    fields in ``self._headers``, ``endheaders``/``send`` emit them, and
    ``getresponse`` reads until the stream ends.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """
        Create the connection object (no network I/O happens here).

        :raises NotImplementedError: if ``proxy``/``proxy_config`` is passed, or
            if the base class ends up with a tunnel host configured.
        """
        # HTTP/2 state is set up before the base-class initializer runs.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream: int | None = None
        self._headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

        if self._tunnel_host is not None:
            raise NotImplementedError("Tunneling isn't supported with HTTP/2")

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Build a fresh client-side ``H2Connection`` wrapped in a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def _flush_h2_data(self, conn: h2.connection.H2Connection) -> None:
        """Write whatever bytes ``conn`` has queued (if any) to the socket."""
        if data_to_send := conn.data_to_send():
            self.sock.sendall(data_to_send)

    def connect(self) -> None:
        """Open the socket via the base class, then start the HTTP/2 connection."""
        super().connect()
        with self._h2_conn as conn:
            conn.initiate_connection()
            self._flush_h2_data(conn)

    def putrequest(  # type: ignore[override]
        self,
        method: str,
        url: str,
        **kwargs: typing.Any,
    ) -> None:
        """putrequest
        This deviates from the HTTPConnection method signature since we never need to override
        sending accept-encoding headers or the host header.

        Queues the ``:scheme``, ``:method``, ``:authority`` and ``:path``
        pseudo-headers and reserves the next available stream id.

        :raises NotImplementedError: if ``skip_host`` or ``skip_accept_encoding``
            is passed.
        """
        if "skip_host" in kwargs:
            raise NotImplementedError("`skip_host` isn't supported")
        if "skip_accept_encoding" in kwargs:
            raise NotImplementedError("`skip_accept_encoding` isn't supported")

        # An empty target is normalised to "/" and that normalised value is
        # what gets validated and sent: RFC 9113 forbids an empty `:path`
        # for https requests.
        self._request_url = url or "/"
        self._validate_path(self._request_url)  # type: ignore[attr-defined]

        # A host containing ":" is treated as an IPv6 literal and bracketed.
        if ":" in self.host:
            authority = f"[{self.host}]:{self.port or 443}"
        else:
            authority = f"{self.host}:{self.port or 443}"

        self._headers.append((b":scheme", b"https"))
        self._headers.append((b":method", method.encode()))
        self._headers.append((b":authority", authority.encode()))
        self._headers.append((b":path", self._request_url.encode()))

        with self._h2_conn as conn:
            self._h2_stream = conn.get_next_available_stream_id()

    def putheader(self, header: str | bytes, *values: str | bytes) -> None:  # type: ignore[override]
        """
        Queue one ``(header, value)`` pair per entry in ``values``.

        ``str`` inputs are encoded to bytes and the name is lower-cased before
        validation.

        :raises ValueError: if the name or any value fails validation.
        """
        # TODO SKIPPABLE_HEADERS from urllib3 are ignored.
        header = header.encode() if isinstance(header, str) else header
        header = header.lower()  # A lot of upstream code uses capitalized headers.
        if not _is_legal_header_name(header):
            raise ValueError(f"Illegal header name {str(header)}")

        for value in values:
            value = value.encode() if isinstance(value, str) else value
            if _is_illegal_header_value(value):
                raise ValueError(f"Illegal header value {str(value)}")
            self._headers.append((header, value))

    def endheaders(self, message_body: typing.Any = None) -> None:  # type: ignore[override]
        """
        Send the queued headers on the current stream.

        The stream is ended together with the headers when ``message_body`` is
        ``None``; otherwise it stays open for a following ``send``.

        :raises ConnectionError: if ``putrequest`` has not been called.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._headers,
                end_stream=(message_body is None),
            )
            self._flush_h2_data(conn)
        self._headers = []  # Reset headers for the next request.

    def send(self, data: typing.Any) -> None:
        """Send data to the server.
        `data` can be: `str`, `bytes`, an iterable, or file-like objects
        that support a .read() method.

        The stream is always ended after the data, and the end-of-stream frame
        is flushed to the socket before returning.

        :raises ConnectionError: if ``putrequest`` has not been called.
        :raises TypeError: if ``data`` is none of the supported kinds.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        # NOTE(review): chunks are handed to h2 unchanged; confirm that large
        # bodies respect h2's frame-size and flow-control limits.
        with self._h2_conn as conn:
            self._flush_h2_data(conn)

            if hasattr(data, "read"):  # file-like objects
                while chunk := data.read(self.blocksize):
                    if isinstance(chunk, str):
                        chunk = chunk.encode()  # pragma: no cover
                    conn.send_data(self._h2_stream, chunk, end_stream=False)
                    self._flush_h2_data(conn)
                conn.end_stream(self._h2_stream)
                # Flush END_STREAM now; otherwise it stays queued in h2 while
                # getresponse() is already waiting on recv().
                self._flush_h2_data(conn)
                return

            if isinstance(data, str):  # str -> bytes
                data = data.encode()

            try:
                if isinstance(data, bytes):
                    conn.send_data(self._h2_stream, data, end_stream=True)
                    self._flush_h2_data(conn)
                else:
                    for chunk in data:
                        conn.send_data(self._h2_stream, chunk, end_stream=False)
                        self._flush_h2_data(conn)
                    conn.end_stream(self._h2_stream)
                    # Same as the file-like branch: push END_STREAM out now.
                    self._flush_h2_data(conn)
            except TypeError:
                raise TypeError(
                    "`data` should be str, bytes, iterable, or file. got %r"
                    % type(data)
                )

    def set_tunnel(
        self,
        host: str,
        port: int | None = None,
        headers: typing.Mapping[str, str] | None = None,
        scheme: str = "http",
    ) -> None:
        """Always raises ``NotImplementedError``; tunnels are not supported."""
        raise NotImplementedError(
            "HTTP/2 does not support setting up a tunnel through a proxy"
        )

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """
        Read from the socket until the stream ends and return the response.

        The whole body is buffered in memory before returning.

        :raises ConnectionError: if the peer closes the connection before the
            stream ends, or if the stream ends without a ``:status`` header.
        """
        status = None
        # Bound up front so the name exists even if no response headers arrive.
        headers = HTTPHeaderDict()
        data = bytearray()
        with self._h2_conn as conn:
            end_stream = False
            while not end_stream:
                # TODO: Arbitrary read value.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # recv() returning no bytes means the peer closed the
                    # connection; looping again would spin forever.
                    raise ConnectionError(
                        "Connection closed before the HTTP/2 response was complete"
                    )

                for event in conn.receive_data(received_data):
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )

                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )

                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True

                # Send anything h2 queued while processing the received data.
                self._flush_h2_data(conn)

        # Explicit check instead of `assert`, which is stripped under `-O`.
        if status is None:
            raise ConnectionError("HTTP/2 stream ended without a `:status` header")

        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def request(  # type: ignore[override]
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        *,
        preload_content: bool = True,
        decode_content: bool = True,
        enforce_content_length: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Send an HTTP/2 request

        ``preload_content``, ``decode_content``, ``enforce_content_length`` and
        any extra keyword arguments are accepted but not used by this method.
        """
        # TODO: `chunked` is often present in kwargs from upstream; it is
        # ignored here rather than rejected.

        if self.sock is not None:
            self.sock.settimeout(self.timeout)

        self.putrequest(method, url)

        for k, v in (headers or {}).items():
            # `transfer-encoding: chunked` is dropped rather than forwarded.
            if k.lower() == "transfer-encoding" and v == "chunked":
                continue
            self.putheader(k, v)

        if not any(name == b"user-agent" for name, _ in self._headers):
            self.putheader(b"user-agent", _get_default_user_agent())

        if body:
            self.endheaders(message_body=body)
            self.send(body)
        else:
            self.endheaders()

    def close(self) -> None:
        """
        Close the HTTP/2 connection, reset per-connection state and close the
        underlying connection via the base class.
        """
        with self._h2_conn as conn:
            try:
                conn.close_connection()
                self._flush_h2_data(conn)
            except Exception:
                # Deliberately best-effort: a failure while saying goodbye
                # must not prevent the state reset and socket close below.
                pass

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._headers = []

        super().close()
 | 
						|
 | 
						|
 | 
						|
class HTTP2Response(BaseHTTPResponse):
    """
    Response object for an HTTP/2 exchange whose body is already fully in memory.
    """

    # TODO: This is a woefully incomplete response object, but works for non-streaming.
    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        """Initialise the base response and keep ``data`` as the body."""
        super().__init__(
            headers=headers,
            status=status,
            # HTTP/2 has no reason phrase.
            reason=None,
            # As in CPython, an HTTP version is encoded as major * 10 + minor.
            version=20,
            version_string="HTTP/2",
            request_url=request_url,
            decode_content=decode_content,
        )
        self._data = data
        self.length_remaining = 0

    @property
    def data(self) -> bytes:
        """The body bytes passed to the constructor."""
        return self._data

    def get_redirect_location(self) -> None:
        """Always returns ``None``."""
        return None

    def close(self) -> None:
        """Do nothing."""
 |