You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			188 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			188 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Python
		
	
from __future__ import annotations
 | 
						|
 | 
						|
import typing
 | 
						|
 | 
						|
from .._models import Request, Response
 | 
						|
from .._types import AsyncByteStream
 | 
						|
from .base import AsyncBaseTransport
 | 
						|
 | 
						|
if typing.TYPE_CHECKING:  # pragma: no cover
 | 
						|
    import asyncio
 | 
						|
 | 
						|
    import trio
 | 
						|
 | 
						|
    Event = typing.Union[asyncio.Event, trio.Event]
 | 
						|
 | 
						|
 | 
						|
# In the ASGI interface both the connection scope and every event message are
# plain mutable mappings keyed by strings.
_Scope = typing.MutableMapping[str, typing.Any]
_Message = typing.MutableMapping[str, typing.Any]

# ``receive()`` takes no arguments and resolves to the next incoming message.
_Receive = typing.Callable[[], typing.Awaitable[_Message]]

# ``send(message)`` takes one outgoing message and resolves to ``None``.
_Send = typing.Callable[[_Message], typing.Awaitable[None]]

# The application itself is awaited as ``app(scope, receive, send)``.
_ASGIApp = typing.Callable[[_Scope, _Receive, _Send], typing.Awaitable[None]]
 | 
						|
 | 
						|
# Public API of this module: only the transport class is exported.
__all__ = ["ASGITransport"]
 | 
						|
 | 
						|
 | 
						|
def is_running_trio() -> bool:
    """Report whether the calling code is currently running under trio.

    Returns ``False`` when ``sniffio`` cannot be imported, and also whenever
    ``sniffio`` reports any async library other than ``"trio"``.
    """
    try:
        # trio depends on sniffio, so a missing sniffio means trio is not in use.
        # See https://github.com/python-trio/trio/issues/2802
        import sniffio

        library = sniffio.current_async_library()
    except ImportError:  # pragma: nocover
        return False
    return library == "trio"
 | 
						|
 | 
						|
 | 
						|
def create_event() -> Event:
    """Return a new, unset event object for the running async library.

    A ``trio.Event`` is created when ``is_running_trio()`` is true; otherwise an
    ``asyncio.Event`` is created. Both libraries are imported lazily.
    """
    if not is_running_trio():
        import asyncio

        return asyncio.Event()

    import trio

    return trio.Event()
 | 
						|
 | 
						|
 | 
						|
class ASGIResponseStream(AsyncByteStream):
    """Async byte stream over a response body already buffered in memory.

    The body is kept as a list of byte chunks. Iterating the stream produces
    exactly one item: all chunks concatenated together.
    """

    def __init__(self, body: list[bytes]) -> None:
        # Chunks are stored as given; joining happens when iteration starts.
        self._body = body

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        payload = b"".join(self._body)
        yield payload
 | 
						|
 | 
						|
 | 
						|
class ASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app.

    ```python
    transport = httpx.ASGITransport(
        app=app,
        root_path="/submount",
        client=("1.2.3.4", 123)
    )
    client = httpx.AsyncClient(transport=transport)
    ```

    Arguments:

    * `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.
       Defaults to `("127.0.0.1", 123)`.
    """

    def __init__(
        self,
        app: _ASGIApp,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """Run the ASGI app for one request and return its buffered response.

        An HTTP scope is built from `request`. The request body is streamed to the
        app through `receive`. The response start and body messages are
        collected through `send`. The whole response body is held in memory and
        returned as a single-chunk stream.

        Raises:
            Any exception raised by the app, when `raise_app_exceptions` is true.
            When it is false, the exception is suppressed and a 500 status is used
            if the app had not yet sent `http.response.start`.
            AssertionError: if the app returns without sending
            `http.response.start` or without a final body message.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # ASGI scope.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            # raw_path must not include the query string; it is sent separately.
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            "server": (request.url.host, request.url.port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Request.
        request_body_chunks = request.stream.__aiter__()
        request_complete = False

        # Response.
        status_code: int | None = None
        response_headers = None
        body_parts: list[bytes] = []
        response_started = False
        response_complete = create_event()

        # ASGI callables.

        async def receive() -> dict[str, typing.Any]:
            nonlocal request_complete

            if request_complete:
                # The body has been fully delivered. Park the app until the
                # response is finished, then report the client as gone.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            try:
                body = await request_body_chunks.__anext__()
            except StopAsyncIteration:
                request_complete = True
                return {"type": "http.request", "body": b"", "more_body": False}
            return {"type": "http.request", "body": body, "more_body": True}

        async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
            nonlocal status_code, response_headers, response_started

            if message["type"] == "http.response.start":
                assert not response_started

                status_code = message["status"]
                response_headers = message.get("headers", [])
                response_started = True

            elif message["type"] == "http.response.body":
                assert not response_complete.is_set()
                body = message.get("body", b"")
                more_body = message.get("more_body", False)

                # Responses to HEAD carry no body, even if the app emits one.
                if body and request.method != "HEAD":
                    body_parts.append(body)

                if not more_body:
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions:
                raise

            # Swallow the error and turn whatever was sent so far into a response;
            # if the app never started one, report a server error.
            response_complete.set()
            if status_code is None:
                status_code = 500
            if response_headers is None:
                response_headers = []

        assert (
            response_complete.is_set()
        ), "ASGI app returned without sending a final 'http.response.body' message"
        assert (
            status_code is not None
        ), "ASGI app did not send an 'http.response.start' message"
        assert response_headers is not None

        stream = ASGIResponseStream(body_parts)

        return Response(status_code, headers=response_headers, stream=stream)
 |