You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			301 lines
		
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			301 lines
		
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
from __future__ import annotations
 | 
						|
 | 
						|
import io
 | 
						|
import mimetypes
 | 
						|
import os
 | 
						|
import re
 | 
						|
import typing
 | 
						|
from pathlib import Path
 | 
						|
 | 
						|
from ._types import (
 | 
						|
    AsyncByteStream,
 | 
						|
    FileContent,
 | 
						|
    FileTypes,
 | 
						|
    RequestData,
 | 
						|
    RequestFiles,
 | 
						|
    SyncByteStream,
 | 
						|
)
 | 
						|
from ._utils import (
 | 
						|
    peek_filelike_length,
 | 
						|
    primitive_value_to_str,
 | 
						|
    to_bytes,
 | 
						|
)
 | 
						|
 | 
						|
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
 | 
						|
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
 | 
						|
    {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
 | 
						|
)
 | 
						|
_HTML5_FORM_ENCODING_RE = re.compile(
 | 
						|
    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
def _format_form_param(name: str, value: str) -> bytes:
    """
    Render the pair as the bytes `name="value"` for a multipart header.

    Every character of `value` matched by `_HTML5_FORM_ENCODING_RE` is
    replaced by its entry in `_HTML5_FORM_ENCODING_REPLACEMENTS`; `name` is
    emitted unchanged. The result is encoded with the default `str.encode()`
    codec (UTF-8).
    """
    escaped = _HTML5_FORM_ENCODING_RE.sub(
        lambda found: _HTML5_FORM_ENCODING_REPLACEMENTS[found.group(0)], value
    )
    return f'{name}="{escaped}"'.encode()
 | 
						|
 | 
						|
 | 
						|
def _guess_content_type(filename: str | None) -> str | None:
 | 
						|
    """
 | 
						|
    Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
 | 
						|
 | 
						|
    Returns `None` if `filename` is `None` or empty.
 | 
						|
    """
 | 
						|
    if filename:
 | 
						|
        return mimetypes.guess_type(filename)[0] or "application/octet-stream"
 | 
						|
    return None
 | 
						|
 | 
						|
 | 
						|
def get_multipart_boundary_from_content_type(
    content_type: bytes | None,
) -> bytes | None:
    """
    Extract the `boundary` parameter from a `multipart/form-data` Content-Type.

    Returns `None` when `content_type` is `None`/empty, is not a
    `multipart/form-data` media type, or has no `boundary=` parameter.
    Double quotes surrounding the boundary value are stripped.
    """
    if not content_type:
        return None
    # Type, subtype and parameter names are case-insensitive (RFC 2045 5.1),
    # so the media type is compared lowercased, like the parameter name below.
    if not content_type.lower().startswith(b"multipart/form-data"):
        return None
    # parse boundary according to
    # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
    # The first section is the media type itself and can never match, so no
    # separate check for the presence of b";" is needed.
    for section in content_type.split(b";"):
        param = section.strip()
        if param.lower().startswith(b"boundary="):
            return param[len(b"boundary=") :].strip(b'"')
    return None
 | 
						|
 | 
						|
 | 
						|
class DataField:
    """
    A plain name/value entry of a multipart form (not a file upload).
    """

    def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
        if not isinstance(name, str):
            raise TypeError(
                f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
            )
        allowed_value_types = (str, bytes, int, float)
        if not (value is None or isinstance(value, allowed_value_types)):
            raise TypeError(
                "Invalid type for value. Expected primitive type,"
                f" got {type(value)}: {value!r}"
            )
        self.name = name
        # Bytes are stored untouched; any other accepted value (including None)
        # is converted to text by `primitive_value_to_str`.
        if isinstance(value, bytes):
            self.value: str | bytes = value
        else:
            self.value = primitive_value_to_str(value)

    def render_headers(self) -> bytes:
        """Return the Content-Disposition header block, computed once and cached."""
        try:
            return self._headers
        except AttributeError:
            pass
        self._headers = (
            b"Content-Disposition: form-data; "
            + _format_form_param("name", self.name)
            + b"\r\n\r\n"
        )
        return self._headers

    def render_data(self) -> bytes:
        """Return the field value as bytes, computed once and cached."""
        try:
            return self._data
        except AttributeError:
            self._data = to_bytes(self.value)
            return self._data

    def get_length(self) -> int:
        """Return the byte length of the rendered headers plus the rendered value."""
        header_bytes = self.render_headers()
        value_bytes = self.render_data()
        return len(header_bytes) + len(value_bytes)

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block, then the value bytes."""
        yield self.render_headers()
        yield self.render_data()
 | 
						|
 | 
						|
 | 
						|
class FileField:
    """
    A single file field item, within a multipart form field.

    `value` is either the file content itself (a file-like object, or
    `str`/`bytes`) or one of the tuples `(filename, fileobj)`,
    `(filename, fileobj, content_type)` or
    `(filename, fileobj, content_type, headers)`.
    """

    # Number of bytes requested per `read()` call when streaming a file object.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, name: str, value: FileTypes) -> None:
        """
        Build the file field `name` from `value`.

        Raises `TypeError` if the file object is an `io.StringIO` or any other
        text-mode (`io.TextIOBase`) stream.
        """
        self.name = name

        fileobj: FileContent

        headers: dict[str, str] = {}
        content_type: str | None = None

        # This large tuple based API largely mirrors requests' API.
        # It would be good to think of better APIs for this that we could
        # include in httpx 2.0 since variable length tuples (especially of 4
        # elements) are quite unwieldy.
        if isinstance(value, tuple):
            if len(value) == 2:
                # neither the 3rd parameter (content_type) nor the 4th (headers)
                # was included
                filename, fileobj = value
            elif len(value) == 3:
                filename, fileobj, content_type = value
            else:
                # all 4 parameters included
                filename, fileobj, content_type, user_headers = value  # type: ignore
                # Work on a private copy: adding Content-Type below must not
                # mutate the caller's mapping, which may be shared between
                # several files or be read-only.
                headers = dict(user_headers)
        else:
            # A bare file object: use the last path component of its `name`
            # attribute, falling back to "upload".
            filename = Path(str(getattr(value, "name", "upload"))).name
            fileobj = value

        if content_type is None:
            content_type = _guess_content_type(filename)

        # Compare the whole header name (case-insensitively); a substring test
        # would wrongly treat e.g. "X-Content-Type-Options" as Content-Type.
        has_content_type_header = any(
            key.lower() == "content-type" for key in headers
        )
        if content_type is not None and not has_content_type_header:
            # note that unlike requests, we ignore the content_type provided in
            # the 3rd tuple element if it is also included in the headers;
            # requests does the opposite (it overwrites the header with the 3rd
            # tuple element)
            headers["Content-Type"] = content_type

        if isinstance(fileobj, io.StringIO):
            raise TypeError(
                "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
            )
        if isinstance(fileobj, io.TextIOBase):
            raise TypeError(
                "Multipart file uploads must be opened in binary mode, not text mode."
            )

        self.filename = filename
        self.file = fileobj
        self.headers = headers

    def get_length(self) -> int | None:
        """
        Return the byte length of the rendered headers plus the file content,
        or `None` when the content length cannot be determined up front.
        """
        headers = self.render_headers()

        if isinstance(self.file, (str, bytes)):
            return len(headers) + len(to_bytes(self.file))

        file_length = peek_filelike_length(self.file)

        # If we can't determine the filesize without reading it into memory,
        # then return `None` here, to indicate an unknown file length.
        if file_length is None:
            return None

        return len(headers) + file_length

    def render_headers(self) -> bytes:
        """
        Return the part's header block, computed once and cached.

        It is a `Content-Disposition: form-data` line carrying `name` (and
        `filename` when set), one `\\r\\n{name}: {value}` line per entry of
        `self.headers`, and a terminating blank line.
        """
        if not hasattr(self, "_headers"):
            parts = [
                b"Content-Disposition: form-data; ",
                _format_form_param("name", self.name),
            ]
            if self.filename:
                filename = _format_form_param("filename", self.filename)
                parts.extend([b"; ", filename])
            for header_name, header_value in self.headers.items():
                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
                parts.extend([key, val])
            parts.append(b"\r\n\r\n")
            self._headers = b"".join(parts)

        return self._headers

    def render_data(self) -> typing.Iterator[bytes]:
        """Yield the file content as bytes, in chunks of at most `CHUNK_SIZE`."""
        if isinstance(self.file, (str, bytes)):
            yield to_bytes(self.file)
            return

        # Rewind seekable files so the content is read from the start; files
        # whose seek raises io.UnsupportedOperation are read from where they are.
        if hasattr(self.file, "seek"):
            try:
                self.file.seek(0)
            except io.UnsupportedOperation:
                pass

        chunk = self.file.read(self.CHUNK_SIZE)
        while chunk:
            yield to_bytes(chunk)
            chunk = self.file.read(self.CHUNK_SIZE)

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block, then the file content chunks."""
        yield self.render_headers()
        yield from self.render_data()
 | 
						|
 | 
						|
 | 
						|
class MultipartStream(SyncByteStream, AsyncByteStream):
    """
    A multipart/form-data request body, built from form data and files and
    produced as an iterator of byte chunks (sync or async).
    """

    def __init__(
        self,
        data: RequestData,
        files: RequestFiles,
        boundary: bytes | None = None,
    ) -> None:
        # Without an explicit boundary, use 16 random bytes as 32 hex characters.
        self.boundary = (
            os.urandom(16).hex().encode("ascii") if boundary is None else boundary
        )
        boundary_text = self.boundary.decode("ascii")
        self.content_type = f"multipart/form-data; boundary={boundary_text}"
        self.fields = list(self._iter_fields(data, files))

    def _iter_fields(
        self, data: RequestData, files: RequestFiles
    ) -> typing.Iterator[FileField | DataField]:
        # Data fields come first; a list/tuple value yields one field per item.
        for field_name, field_value in data.items():
            if isinstance(field_value, (tuple, list)):
                entries = field_value
            else:
                entries = [field_value]
            for entry in entries:
                yield DataField(name=field_name, value=entry)

        # `files` is either a mapping or an iterable of (name, value) pairs.
        if isinstance(files, typing.Mapping):
            file_pairs = files.items()
        else:
            file_pairs = files
        for field_name, field_value in file_pairs:
            yield FileField(name=field_name, value=field_value)

    def iter_chunks(self) -> typing.Iterator[bytes]:
        """Yield every part wrapped in boundary delimiters, then the close delimiter."""
        for part in self.fields:
            yield b"--%s\r\n" % self.boundary
            yield from part.render()
            yield b"\r\n"
        yield b"--%s--\r\n" % self.boundary

    def get_content_length(self) -> int | None:
        """
        Return the length of the multipart encoded content, or `None` as soon
        as any field reports a length that cannot be determined upfront.
        """
        opening_size = 2 + len(self.boundary) + 2  # b"--{boundary}\r\n"
        closing_size = 2 + len(self.boundary) + 4  # b"--{boundary}--\r\n"
        total = 0

        for part in self.fields:
            part_size = part.get_length()
            if part_size is None:
                return None
            total += opening_size + part_size + 2  # part followed by b"\r\n"

        return total + closing_size

    # Content stream interface.

    def get_headers(self) -> dict[str, str]:
        """Return Content-Length (or chunked Transfer-Encoding) plus Content-Type."""
        content_length = self.get_content_length()
        if content_length is not None:
            return {
                "Content-Length": str(content_length),
                "Content-Type": self.content_type,
            }
        return {"Transfer-Encoding": "chunked", "Content-Type": self.content_type}

    def __iter__(self) -> typing.Iterator[bytes]:
        for piece in self.iter_chunks():
            yield piece

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The chunks are produced by the synchronous `iter_chunks` generator.
        for piece in self.iter_chunks():
            yield piece
 |