You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
160 lines
5.9 KiB
Python
160 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
import ssl
|
|
import typing
|
|
|
|
import trio
|
|
|
|
from .._exceptions import (
|
|
ConnectError,
|
|
ConnectTimeout,
|
|
ExceptionMapping,
|
|
ReadError,
|
|
ReadTimeout,
|
|
WriteError,
|
|
WriteTimeout,
|
|
map_exceptions,
|
|
)
|
|
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
|
|
|
|
|
class TrioStream(AsyncNetworkStream):
    """An ``AsyncNetworkStream`` backed by a ``trio.abc.Stream``.

    Trio exceptions raised during I/O are translated into httpcore's
    exception types through ``map_exceptions``.
    """

    def __init__(self, stream: trio.abc.Stream) -> None:
        # Expected to be a ``trio.SocketStream``, possibly wrapped in one or
        # more ``trio.SSLStream`` layers (see ``_get_socket_stream``).
        self._stream = stream

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Receive up to ``max_bytes`` bytes from the stream.

        :param max_bytes: Upper bound on the number of bytes returned.
        :param timeout: Seconds to wait; ``None`` waits without limit.
        :returns: The bytes produced by ``receive_some``.
        :raises ReadTimeout: If ``timeout`` elapses before data arrives.
        :raises ReadError: If the stream is broken or already closed.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ReadTimeout,
            trio.BrokenResourceError: ReadError,
            trio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
                return data

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Send all of ``buffer`` over the stream.

        An empty ``buffer`` is a no-op and touches neither the stream nor
        the timeout.

        :param buffer: The bytes to send.
        :param timeout: Seconds to wait; ``None`` waits without limit.
        :raises WriteTimeout: If ``timeout`` elapses before sending finishes.
        :raises WriteError: If the stream is broken or already closed.
        """
        if not buffer:
            return

        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: WriteTimeout,
            trio.BrokenResourceError: WriteError,
            trio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def aclose(self) -> None:
        """Close the wrapped trio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake over this stream.

        :param ssl_context: Context used to configure the TLS session.
        :param server_hostname: Hostname passed to ``trio.SSLStream``.
        :param timeout: Seconds allowed for the handshake; ``None`` waits
            without limit.
        :returns: A new ``TrioStream`` wrapping the ``trio.SSLStream``.
        :raises ConnectTimeout: If the handshake does not finish in time.
        :raises ConnectError: If the underlying stream breaks mid-handshake.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
        }
        ssl_stream = trio.SSLStream(
            self._stream,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            https_compatible=True,
            server_side=False,
        )
        with map_exceptions(exc_map):
            try:
                with trio.fail_after(timeout_or_inf):
                    await ssl_stream.do_handshake()
            except Exception:  # pragma: nocover
                # The caller never receives the SSL stream on failure, so close
                # the transport here rather than leaving it open.
                await self.aclose()
                # Bare ``raise`` keeps the original traceback intact.
                raise
        return TrioStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return connection metadata identified by ``info``.

        Supported keys: ``"ssl_object"`` (only when the stream is a
        ``trio.SSLStream``), ``"client_addr"``, ``"server_addr"``,
        ``"socket"`` and ``"is_readable"``. Any other key returns ``None``.
        """
        if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
            # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
            # Tracked at https://github.com/python-trio/trio/issues/542
            return self._stream._ssl_object  # type: ignore[attr-defined]
        if info == "client_addr":
            return self._get_socket_stream().socket.getsockname()
        if info == "server_addr":
            return self._get_socket_stream().socket.getpeername()
        if info == "socket":
            return self._get_socket_stream().socket
        if info == "is_readable":
            return self._get_socket_stream().socket.is_readable()
        return None

    def _get_socket_stream(self) -> trio.SocketStream:
        """Unwrap any ``trio.SSLStream`` layers and return the raw socket stream."""
        stream = self._stream
        while isinstance(stream, trio.SSLStream):
            stream = stream.transport_stream
        assert isinstance(stream, trio.SocketStream)
        return stream
|
|
|
|
|
|
class TrioBackend(AsyncNetworkBackend):
    """``AsyncNetworkBackend`` that opens TCP and Unix-socket connections with trio."""

    # Shared by every connect method; ``map_exceptions`` only reads it.
    _CONNECT_EXC_MAP: typing.ClassVar[ExceptionMapping] = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
        OSError: ConnectError,
    }

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to ``host``:``port``.

        :param host: Remote host to connect to.
        :param port: Remote port to connect to.
        :param timeout: Seconds allowed to establish the connection; ``None``
            waits without limit.
        :param local_address: Local address to bind to, passed through to
            ``trio.open_tcp_stream``.
        :param socket_options: ``setsockopt`` argument tuples applied after
            connecting.
        :returns: A ``TrioStream`` wrapping the connected socket.
        :raises ConnectTimeout: If ``timeout`` elapses first.
        :raises ConnectError: On a broken resource or any ``OSError``,
            including a failure while applying ``socket_options``.
        """
        # By default for TCP sockets, trio enables TCP_NODELAY.
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
        timeout_or_inf = float("inf") if timeout is None else timeout
        with map_exceptions(self._CONNECT_EXC_MAP):
            with trio.fail_after(timeout_or_inf):
                stream = await trio.open_tcp_stream(
                    host=host, port=port, local_address=local_address
                )
            await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at ``path``.

        :param path: Filesystem path of the socket.
        :param timeout: Seconds allowed to connect; ``None`` waits without limit.
        :param socket_options: ``setsockopt`` argument tuples applied after
            connecting.
        :returns: A ``TrioStream`` wrapping the connected socket.
        :raises ConnectTimeout: If ``timeout`` elapses first.
        :raises ConnectError: On a broken resource or any ``OSError``,
            including a failure while applying ``socket_options``.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        with map_exceptions(self._CONNECT_EXC_MAP):
            with trio.fail_after(timeout_or_inf):
                stream = await trio.open_unix_socket(path)
            await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    @staticmethod
    async def _apply_socket_options(
        stream: trio.SocketStream,
        socket_options: typing.Iterable[SOCKET_OPTION] | None,
    ) -> None:
        """Apply each ``setsockopt`` tuple to ``stream``, closing it on failure."""
        if socket_options is None:
            return
        try:
            for option in socket_options:
                stream.setsockopt(*option)  # pragma: no cover
        except Exception:
            # The caller never gets the stream back, so release the socket
            # before the error propagates instead of leaking it.
            await stream.aclose()
            raise

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` seconds using ``trio.sleep``."""
        await trio.sleep(seconds)  # pragma: nocover
|