You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			150 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			150 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
from __future__ import annotations
 | 
						|
 | 
						|
import io
 | 
						|
import itertools
 | 
						|
import sys
 | 
						|
import typing
 | 
						|
 | 
						|
from .._models import Request, Response
 | 
						|
from .._types import SyncByteStream
 | 
						|
from .base import BaseTransport
 | 
						|
 | 
						|
if typing.TYPE_CHECKING:
 | 
						|
    from _typeshed import OptExcInfo  # pragma: no cover
 | 
						|
    from _typeshed.wsgi import WSGIApplication  # pragma: no cover
 | 
						|
 | 
						|
# Element type preserved by `_skip_leading_empty_chunks`.
_T = typing.TypeVar("_T")

# The transport class is the only public name exported by this module.
__all__ = ["WSGITransport"]


def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
 | 
						|
    body = iter(body)
 | 
						|
    for chunk in body:
 | 
						|
        if chunk:
 | 
						|
            return itertools.chain([chunk], body)
 | 
						|
    return []
 | 
						|
 | 
						|
 | 
						|
class WSGIByteStream(SyncByteStream):
    """
    Synchronous byte stream over the iterable returned by a WSGI application.

    Leading empty chunks are consumed at construction time. `close()`
    forwards to the iterable's own `close()` method when it has one.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Captured before iteration starts so it is available on every path.
        self._close = getattr(result, "close", None)
        try:
            # Pulls chunks eagerly up to the first non-empty one; for apps
            # that call `start_response` lazily from inside a generator, this
            # is where that call happens.
            self._result = _skip_leading_empty_chunks(result)
        except BaseException:
            # PEP 3333 requires the iterable's close() to be called even when
            # the application fails mid-iteration. If construction raises, no
            # caller ever receives this stream to close, so do it here.
            self.close()
            raise

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self._result

    def close(self) -> None:
        """Call the wrapped iterable's `close()` method, if it defines one."""
        if self._close is not None:
            self._close()


class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.

    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - Text stream exposed to the application as `wsgi.errors`.
       When `None`, `sys.stderr` is used.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def handle_request(self, request: Request) -> Response:
        """
        Call the WSGI app with an environ built from `request` and wrap its result.

        The request body is read fully into memory and exposed as `wsgi.input`.
        If the app passes `exc_info` to `start_response` and
        `raise_app_exceptions` is true, the app's iterable is closed and the
        reported exception is re-raised instead of returning a response.
        """
        request.read()
        wsgi_input = io.BytesIO(request.content)

        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": wsgi_input,
            "wsgi.errors": self.wsgi_errors or sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            # PEP 3333 environ strings carry raw bytes as ISO-8859-1 code
            # points. Re-encode the text path as UTF-8 and map those bytes onto
            # latin-1 so apps can recover them; ASCII-only paths are unchanged.
            "PATH_INFO": request.url.path.encode("utf-8").decode("latin-1"),
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("ascii").upper().replace("-", "_")
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
            # Header values are ISO-8859-1 "native strings" per PEP 3333;
            # latin-1 is a superset of ASCII, so ASCII values are unchanged.
            value = header_value.decode("latin-1")
            # Fold repeated headers into one comma-separated value (as
            # wsgiref.simple_server does) instead of silently keeping the last.
            if key in environ:
                value = f"{environ[key]},{value}"
            environ[key] = value

        seen_status = None
        seen_response_headers = None
        seen_exc_info = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            # Record what the app reports; the response is assembled after the
            # app returns its iterable.
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            # The legacy `write()` callable is not supported: anything passed
            # to it is discarded.
            return lambda _: None

        result = self.app(environ, start_response)

        # Constructing the stream iterates past leading empty chunks, which is
        # what triggers `start_response` for apps that call it lazily.
        stream = WSGIByteStream(result)

        assert seen_status is not None
        assert seen_response_headers is not None
        if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
            # No response will be returned, so nothing else would ever close
            # the app's iterable; release it before propagating the error.
            try:
                stream.close()
            finally:
                raise seen_exc_info[1]

        status_code = int(seen_status.split()[0])
        headers = [
            (key.encode("ascii"), value.encode("latin-1"))
            for key, value in seen_response_headers
        ]

        return Response(status_code, headers=headers, stream=stream)