123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- from __future__ import annotations
- import io
- import mimetypes
- import os
- import re
- import typing
- from pathlib import Path
- from ._types import (
- AsyncByteStream,
- FileContent,
- FileTypes,
- RequestData,
- RequestFiles,
- SyncByteStream,
- )
- from ._utils import (
- peek_filelike_length,
- primitive_value_to_str,
- to_bytes,
- )
- _HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
- _HTML5_FORM_ENCODING_REPLACEMENTS.update(
- {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
- )
- _HTML5_FORM_ENCODING_RE = re.compile(
- r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
- )
- def _format_form_param(name: str, value: str) -> bytes:
- """
- Encode a name/value pair within a multipart form.
- """
- def replacer(match: typing.Match[str]) -> str:
- return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
- value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
- return f'{name}="{value}"'.encode()
- def _guess_content_type(filename: str | None) -> str | None:
- """
- Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
- Returns `None` if `filename` is `None` or empty.
- """
- if filename:
- return mimetypes.guess_type(filename)[0] or "application/octet-stream"
- return None
def get_multipart_boundary_from_content_type(
    content_type: bytes | None,
) -> bytes | None:
    """
    Extract the `boundary` parameter from a `multipart/form-data` header value.

    Parameters:
        content_type: The raw `Content-Type` header value, or `None`.

    Returns:
        The boundary with any surrounding double quotes stripped, or `None`
        when the header is missing/empty, is not `multipart/form-data`, or
        carries no `boundary=` parameter.
    """
    if not content_type:
        return None
    # Media type names are matched case-insensitively (RFC 2045 section 5.1),
    # so compare against a lower-cased copy rather than the raw header.
    if not content_type.lower().startswith(b"multipart/form-data"):
        return None
    # parse boundary according to
    # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
    prefix = b"boundary="
    for section in content_type.split(b";"):
        parameter = section.strip()
        # The parameter name is case-insensitive too; the boundary value
        # itself is returned with its original casing.
        if parameter.lower().startswith(prefix):
            return parameter[len(prefix) :].strip(b'"')
    return None
class DataField:
    """
    One plain (non-file) part of a multipart form body.

    The field name must be a `str`. The value may be `str`, `bytes`, `int`,
    `float` or `None`; `bytes` are stored unchanged and everything else is
    converted with `primitive_value_to_str`.
    """

    def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
        # Validate the name first, then the value, so that the error raised
        # for a doubly-invalid call is always the one about the name.
        if not isinstance(name, str):
            raise TypeError(
                f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
            )
        value_is_primitive = value is None or isinstance(
            value, (str, bytes, int, float)
        )
        if not value_is_primitive:
            raise TypeError(
                "Invalid type for value. Expected primitive type,"
                f" got {type(value)}: {value!r}"
            )

        self.name = name
        if isinstance(value, bytes):
            self.value: str | bytes = value
        else:
            self.value = primitive_value_to_str(value)

    def render_headers(self) -> bytes:
        """Return the part headers, including the blank line that ends them."""
        # Built on first use and then reused from `self._headers`.
        if hasattr(self, "_headers"):
            return self._headers
        disposition_name = _format_form_param("name", self.name)
        self._headers = b"".join(
            [b"Content-Disposition: form-data; ", disposition_name, b"\r\n\r\n"]
        )
        return self._headers

    def render_data(self) -> bytes:
        """Return the field value as bytes."""
        # Converted on first use and then reused from `self._data`.
        if hasattr(self, "_data"):
            return self._data
        self._data = to_bytes(self.value)
        return self._data

    def get_length(self) -> int:
        """Return the combined byte length of the headers and the value."""
        return len(self.render_headers()) + len(self.render_data())

    def render(self) -> typing.Iterator[bytes]:
        """Yield the rendered headers, then the rendered value."""
        for produce in (self.render_headers, self.render_data):
            yield produce()
class FileField:
    """
    A single file field item, within a multipart form field.

    `value` is either the file content itself (anything other than a tuple),
    or a tuple of `(filename, content)`,
    `(filename, content, content_type)` or
    `(filename, content, content_type, headers)`.

    Attributes set by the constructor:
        name: The form field name.
        filename: The filename sent in `Content-Disposition`, if truthy.
        file: The content to upload.
        headers: Extra part headers; a private copy, never the caller's dict.

    Raises:
        TypeError: If the content is an `io.StringIO` or any other text-mode
            file object.
    """

    # Number of bytes requested per `read()` call when streaming a file.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, name: str, value: FileTypes) -> None:
        self.name = name

        fileobj: FileContent

        headers: dict[str, str] = {}
        content_type: str | None = None

        # This large tuple based API largely mirrors requests' API.
        # It would be good to think of better APIs for this that we could
        # include in httpx 2.0 since variable length tuples (especially of
        # 4 elements) are quite unwieldy.
        if isinstance(value, tuple):
            if len(value) == 2:
                # neither the 3rd parameter (content_type) nor the 4th (headers)
                # was included
                filename, fileobj = value
            elif len(value) == 3:
                filename, fileobj, content_type = value
            else:
                # all 4 parameters included
                filename, fileobj, content_type, headers = value  # type: ignore
        else:
            filename = Path(str(getattr(value, "name", "upload"))).name
            fileobj = value

        # Work on a private copy: a Content-Type entry may be added below and
        # that must not leak into a dict the caller still owns (and may reuse
        # for other uploads).
        headers = dict(headers)

        if content_type is None:
            content_type = _guess_content_type(filename)

        # Compare the whole header name. A substring test would treat headers
        # such as "X-Content-Type-Options" as if they were "Content-Type" and
        # silently drop the real content type.
        has_content_type_header = any(
            key.lower() == "content-type" for key in headers
        )
        if content_type is not None and not has_content_type_header:
            # Note that unlike requests, we ignore the content_type provided in
            # the 3rd tuple element if it is also included in the headers.
            # requests does the opposite (it overwrites the header with the
            # 3rd tuple element).
            headers["Content-Type"] = content_type

        if isinstance(fileobj, io.StringIO):
            raise TypeError(
                "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
            )
        if isinstance(fileobj, io.TextIOBase):
            raise TypeError(
                "Multipart file uploads must be opened in binary mode, not text mode."
            )

        self.filename = filename
        self.file = fileobj
        self.headers = headers

    def get_length(self) -> int | None:
        """
        Return the byte length of this part (headers plus content), or `None`
        when the content length cannot be determined without reading it.
        """
        headers = self.render_headers()

        if isinstance(self.file, (str, bytes)):
            return len(headers) + len(to_bytes(self.file))

        file_length = peek_filelike_length(self.file)

        # If we can't determine the filesize without reading it into memory,
        # then return `None` here, to indicate an unknown file length.
        if file_length is None:
            return None

        return len(headers) + file_length

    def render_headers(self) -> bytes:
        """
        Return the part headers, including the blank line that ends them.

        The result is built on first use and then reused from `self._headers`.
        """
        if not hasattr(self, "_headers"):
            parts = [
                b"Content-Disposition: form-data; ",
                _format_form_param("name", self.name),
            ]
            if self.filename:
                filename = _format_form_param("filename", self.filename)
                parts.extend([b"; ", filename])
            # Each extra header goes on its own line after Content-Disposition.
            for header_name, header_value in self.headers.items():
                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
                parts.extend([key, val])
            parts.append(b"\r\n\r\n")
            self._headers = b"".join(parts)

        return self._headers

    def render_data(self) -> typing.Iterator[bytes]:
        """
        Yield the content as bytes.

        `str`/`bytes` content is yielded in one piece. Anything else is read
        in `CHUNK_SIZE` pieces until `read()` returns an empty result.
        """
        if isinstance(self.file, (str, bytes)):
            yield to_bytes(self.file)
            return

        if hasattr(self.file, "seek"):
            # Rewind so the content is sent from the start; objects that expose
            # `seek` but do not support it are read from where they are.
            try:
                self.file.seek(0)
            except io.UnsupportedOperation:
                pass

        chunk = self.file.read(self.CHUNK_SIZE)
        while chunk:
            yield to_bytes(chunk)
            chunk = self.file.read(self.CHUNK_SIZE)

    def render(self) -> typing.Iterator[bytes]:
        """Yield the rendered headers followed by the content chunks."""
        yield self.render_headers()
        yield from self.render_data()
class MultipartStream(SyncByteStream, AsyncByteStream):
    """
    Multipart form data request content, usable as both a synchronous and an
    asynchronous byte stream.

    Data fields are emitted first, followed by file fields. When no boundary
    is supplied, a random one (16 random bytes, hex encoded) is generated.
    """

    def __init__(
        self,
        data: RequestData,
        files: RequestFiles,
        boundary: bytes | None = None,
    ) -> None:
        self.boundary = (
            os.urandom(16).hex().encode("ascii") if boundary is None else boundary
        )
        boundary_text = self.boundary.decode("ascii")
        self.content_type = f"multipart/form-data; boundary={boundary_text}"
        self.fields = list(self._iter_fields(data, files))

    def _iter_fields(
        self, data: RequestData, files: RequestFiles
    ) -> typing.Iterator[FileField | DataField]:
        """Yield a field object for every data item, then for every file."""
        for field_name, field_value in data.items():
            # A tuple or list value expands into one field per element, all
            # sharing the same name.
            if not isinstance(field_value, (tuple, list)):
                yield DataField(name=field_name, value=field_value)
                continue
            for element in field_value:
                yield DataField(name=field_name, value=element)

        # `files` is either a mapping or an iterable of (name, value) pairs.
        if isinstance(files, typing.Mapping):
            named_files = files.items()
        else:
            named_files = files
        for field_name, file_value in named_files:
            yield FileField(name=field_name, value=file_value)

    def iter_chunks(self) -> typing.Iterator[bytes]:
        """Yield the encoded body: each delimited part, then the terminator."""
        opening_delimiter = b"--%s\r\n" % self.boundary
        closing_delimiter = b"--%s--\r\n" % self.boundary
        for part in self.fields:
            yield opening_delimiter
            yield from part.render()
            yield b"\r\n"
        yield closing_delimiter

    def get_content_length(self) -> int | None:
        """
        Return the total size in bytes of the encoded body, or `None` as soon
        as a field reports a length that cannot be determined upfront.
        """
        boundary_size = len(self.boundary)
        # b"--{boundary}\r\n" before each part plus b"\r\n" after it.
        per_part_overhead = (2 + boundary_size + 2) + 2
        # Start from the size of the closing b"--{boundary}--\r\n".
        total = 2 + boundary_size + 4
        for part in self.fields:
            part_size = part.get_length()
            if part_size is None:
                return None
            total += per_part_overhead + part_size
        return total

    # Content stream interface.

    def get_headers(self) -> dict[str, str]:
        """
        Return the headers describing this body: `Content-Length` when the
        size is known, `Transfer-Encoding: chunked` otherwise, followed by
        `Content-Type`.
        """
        size = self.get_content_length()
        if size is not None:
            framing = {"Content-Length": str(size)}
        else:
            framing = {"Transfer-Encoding": "chunked"}
        framing["Content-Type"] = self.content_type
        return framing

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self.iter_chunks()

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The chunks are produced synchronously; this only exposes them
        # through the async iteration protocol.
        for piece in self.iter_chunks():
            yield piece
|