You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			272 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			272 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
from __future__ import annotations
 | 
						|
 | 
						|
import io
 | 
						|
import socket
 | 
						|
import ssl
 | 
						|
import typing
 | 
						|
 | 
						|
from ..exceptions import ProxySchemeUnsupported
 | 
						|
 | 
						|
if typing.TYPE_CHECKING:
 | 
						|
    from typing_extensions import Self
 | 
						|
 | 
						|
    from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
 | 
						|
 | 
						|
 | 
						|
_WriteBuffer = typing.Union[bytearray, memoryview]
 | 
						|
_ReturnValue = typing.TypeVar("_ReturnValue")
 | 
						|
 | 
						|
SSL_BLOCKSIZE = 16384
 | 
						|
 | 
						|
 | 
						|
class SSLTransport:
 | 
						|
    """
 | 
						|
    The SSLTransport wraps an existing socket and establishes an SSL connection.
 | 
						|
 | 
						|
    Contrary to Python's implementation of SSLSocket, it allows you to chain
 | 
						|
    multiple TLS connections together. It's particularly useful if you need to
 | 
						|
    implement TLS within TLS.
 | 
						|
 | 
						|
    The class supports most of the socket API operations.
 | 
						|
    """
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
 | 
						|
        """
 | 
						|
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
 | 
						|
        for TLS in TLS.
 | 
						|
 | 
						|
        The only requirement is that the ssl_context provides the 'wrap_bio'
 | 
						|
        methods.
 | 
						|
        """
 | 
						|
 | 
						|
        if not hasattr(ssl_context, "wrap_bio"):
 | 
						|
            raise ProxySchemeUnsupported(
 | 
						|
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
 | 
						|
                "available on non-native SSLContext"
 | 
						|
            )
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        socket: socket.socket,
 | 
						|
        ssl_context: ssl.SSLContext,
 | 
						|
        server_hostname: str | None = None,
 | 
						|
        suppress_ragged_eofs: bool = True,
 | 
						|
    ) -> None:
 | 
						|
        """
 | 
						|
        Create an SSLTransport around socket using the provided ssl_context.
 | 
						|
        """
 | 
						|
        self.incoming = ssl.MemoryBIO()
 | 
						|
        self.outgoing = ssl.MemoryBIO()
 | 
						|
 | 
						|
        self.suppress_ragged_eofs = suppress_ragged_eofs
 | 
						|
        self.socket = socket
 | 
						|
 | 
						|
        self.sslobj = ssl_context.wrap_bio(
 | 
						|
            self.incoming, self.outgoing, server_hostname=server_hostname
 | 
						|
        )
 | 
						|
 | 
						|
        # Perform initial handshake.
 | 
						|
        self._ssl_io_loop(self.sslobj.do_handshake)
 | 
						|
 | 
						|
    def __enter__(self) -> Self:
 | 
						|
        return self
 | 
						|
 | 
						|
    def __exit__(self, *_: typing.Any) -> None:
 | 
						|
        self.close()
 | 
						|
 | 
						|
    def fileno(self) -> int:
 | 
						|
        return self.socket.fileno()
 | 
						|
 | 
						|
    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
 | 
						|
        return self._wrap_ssl_read(len, buffer)
 | 
						|
 | 
						|
    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
 | 
						|
        if flags != 0:
 | 
						|
            raise ValueError("non-zero flags not allowed in calls to recv")
 | 
						|
        return self._wrap_ssl_read(buflen)
 | 
						|
 | 
						|
    def recv_into(
 | 
						|
        self,
 | 
						|
        buffer: _WriteBuffer,
 | 
						|
        nbytes: int | None = None,
 | 
						|
        flags: int = 0,
 | 
						|
    ) -> None | int | bytes:
 | 
						|
        if flags != 0:
 | 
						|
            raise ValueError("non-zero flags not allowed in calls to recv_into")
 | 
						|
        if nbytes is None:
 | 
						|
            nbytes = len(buffer)
 | 
						|
        return self.read(nbytes, buffer)
 | 
						|
 | 
						|
    def sendall(self, data: bytes, flags: int = 0) -> None:
 | 
						|
        if flags != 0:
 | 
						|
            raise ValueError("non-zero flags not allowed in calls to sendall")
 | 
						|
        count = 0
 | 
						|
        with memoryview(data) as view, view.cast("B") as byte_view:
 | 
						|
            amount = len(byte_view)
 | 
						|
            while count < amount:
 | 
						|
                v = self.send(byte_view[count:])
 | 
						|
                count += v
 | 
						|
 | 
						|
    def send(self, data: bytes, flags: int = 0) -> int:
 | 
						|
        if flags != 0:
 | 
						|
            raise ValueError("non-zero flags not allowed in calls to send")
 | 
						|
        return self._ssl_io_loop(self.sslobj.write, data)
 | 
						|
 | 
						|
    def makefile(
 | 
						|
        self,
 | 
						|
        mode: str,
 | 
						|
        buffering: int | None = None,
 | 
						|
        *,
 | 
						|
        encoding: str | None = None,
 | 
						|
        errors: str | None = None,
 | 
						|
        newline: str | None = None,
 | 
						|
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
 | 
						|
        """
 | 
						|
        Python's httpclient uses makefile and buffered io when reading HTTP
 | 
						|
        messages and we need to support it.
 | 
						|
 | 
						|
        This is unfortunately a copy and paste of socket.py makefile with small
 | 
						|
        changes to point to the socket directly.
 | 
						|
        """
 | 
						|
        if not set(mode) <= {"r", "w", "b"}:
 | 
						|
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
 | 
						|
 | 
						|
        writing = "w" in mode
 | 
						|
        reading = "r" in mode or not writing
 | 
						|
        assert reading or writing
 | 
						|
        binary = "b" in mode
 | 
						|
        rawmode = ""
 | 
						|
        if reading:
 | 
						|
            rawmode += "r"
 | 
						|
        if writing:
 | 
						|
            rawmode += "w"
 | 
						|
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
 | 
						|
        self.socket._io_refs += 1  # type: ignore[attr-defined]
 | 
						|
        if buffering is None:
 | 
						|
            buffering = -1
 | 
						|
        if buffering < 0:
 | 
						|
            buffering = io.DEFAULT_BUFFER_SIZE
 | 
						|
        if buffering == 0:
 | 
						|
            if not binary:
 | 
						|
                raise ValueError("unbuffered streams must be binary")
 | 
						|
            return raw
 | 
						|
        buffer: typing.BinaryIO
 | 
						|
        if reading and writing:
 | 
						|
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
 | 
						|
        elif reading:
 | 
						|
            buffer = io.BufferedReader(raw, buffering)
 | 
						|
        else:
 | 
						|
            assert writing
 | 
						|
            buffer = io.BufferedWriter(raw, buffering)
 | 
						|
        if binary:
 | 
						|
            return buffer
 | 
						|
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
 | 
						|
        text.mode = mode  # type: ignore[misc]
 | 
						|
        return text
 | 
						|
 | 
						|
    def unwrap(self) -> None:
 | 
						|
        self._ssl_io_loop(self.sslobj.unwrap)
 | 
						|
 | 
						|
    def close(self) -> None:
 | 
						|
        self.socket.close()
 | 
						|
 | 
						|
    @typing.overload
 | 
						|
    def getpeercert(
 | 
						|
        self, binary_form: typing.Literal[False] = ...
 | 
						|
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...
 | 
						|
 | 
						|
    @typing.overload
 | 
						|
    def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ...
 | 
						|
 | 
						|
    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
 | 
						|
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]
 | 
						|
 | 
						|
    def version(self) -> str | None:
 | 
						|
        return self.sslobj.version()
 | 
						|
 | 
						|
    def cipher(self) -> tuple[str, str, int] | None:
 | 
						|
        return self.sslobj.cipher()
 | 
						|
 | 
						|
    def selected_alpn_protocol(self) -> str | None:
 | 
						|
        return self.sslobj.selected_alpn_protocol()
 | 
						|
 | 
						|
    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
 | 
						|
        return self.sslobj.shared_ciphers()
 | 
						|
 | 
						|
    def compression(self) -> str | None:
 | 
						|
        return self.sslobj.compression()
 | 
						|
 | 
						|
    def settimeout(self, value: float | None) -> None:
 | 
						|
        self.socket.settimeout(value)
 | 
						|
 | 
						|
    def gettimeout(self) -> float | None:
 | 
						|
        return self.socket.gettimeout()
 | 
						|
 | 
						|
    def _decref_socketios(self) -> None:
 | 
						|
        self.socket._decref_socketios()  # type: ignore[attr-defined]
 | 
						|
 | 
						|
    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
 | 
						|
        try:
 | 
						|
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
 | 
						|
        except ssl.SSLError as e:
 | 
						|
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
 | 
						|
                return 0  # eof, return 0.
 | 
						|
            else:
 | 
						|
                raise
 | 
						|
 | 
						|
    # func is sslobj.do_handshake or sslobj.unwrap
 | 
						|
    @typing.overload
 | 
						|
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ...
 | 
						|
 | 
						|
    # func is sslobj.write, arg1 is data
 | 
						|
    @typing.overload
 | 
						|
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ...
 | 
						|
 | 
						|
    # func is sslobj.read, arg1 is len, arg2 is buffer
 | 
						|
    @typing.overload
 | 
						|
    def _ssl_io_loop(
 | 
						|
        self,
 | 
						|
        func: typing.Callable[[int, bytearray | None], bytes],
 | 
						|
        arg1: int,
 | 
						|
        arg2: bytearray | None,
 | 
						|
    ) -> bytes: ...
 | 
						|
 | 
						|
    def _ssl_io_loop(
 | 
						|
        self,
 | 
						|
        func: typing.Callable[..., _ReturnValue],
 | 
						|
        arg1: None | bytes | int = None,
 | 
						|
        arg2: bytearray | None = None,
 | 
						|
    ) -> _ReturnValue:
 | 
						|
        """Performs an I/O loop between incoming/outgoing and the socket."""
 | 
						|
        should_loop = True
 | 
						|
        ret = None
 | 
						|
 | 
						|
        while should_loop:
 | 
						|
            errno = None
 | 
						|
            try:
 | 
						|
                if arg1 is None and arg2 is None:
 | 
						|
                    ret = func()
 | 
						|
                elif arg2 is None:
 | 
						|
                    ret = func(arg1)
 | 
						|
                else:
 | 
						|
                    ret = func(arg1, arg2)
 | 
						|
            except ssl.SSLError as e:
 | 
						|
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
 | 
						|
                    # WANT_READ, and WANT_WRITE are expected, others are not.
 | 
						|
                    raise e
 | 
						|
                errno = e.errno
 | 
						|
 | 
						|
            buf = self.outgoing.read()
 | 
						|
            self.socket.sendall(buf)
 | 
						|
 | 
						|
            if errno is None:
 | 
						|
                should_loop = False
 | 
						|
            elif errno == ssl.SSL_ERROR_WANT_READ:
 | 
						|
                buf = self.socket.recv(SSL_BLOCKSIZE)
 | 
						|
                if buf:
 | 
						|
                    self.incoming.write(buf)
 | 
						|
                else:
 | 
						|
                    self.incoming.write_eof()
 | 
						|
        return typing.cast(_ReturnValue, ret)
 |