_multipart.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. from __future__ import annotations
  2. import io
  3. import mimetypes
  4. import os
  5. import re
  6. import typing
  7. from pathlib import Path
  8. from ._types import (
  9. AsyncByteStream,
  10. FileContent,
  11. FileTypes,
  12. RequestData,
  13. RequestFiles,
  14. SyncByteStream,
  15. )
  16. from ._utils import (
  17. peek_filelike_length,
  18. primitive_value_to_str,
  19. to_bytes,
  20. )
  21. _HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
  22. _HTML5_FORM_ENCODING_REPLACEMENTS.update(
  23. {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
  24. )
  25. _HTML5_FORM_ENCODING_RE = re.compile(
  26. r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
  27. )
  28. def _format_form_param(name: str, value: str) -> bytes:
  29. """
  30. Encode a name/value pair within a multipart form.
  31. """
  32. def replacer(match: typing.Match[str]) -> str:
  33. return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]
  34. value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
  35. return f'{name}="{value}"'.encode()
  36. def _guess_content_type(filename: str | None) -> str | None:
  37. """
  38. Guesses the mimetype based on a filename. Defaults to `application/octet-stream`.
  39. Returns `None` if `filename` is `None` or empty.
  40. """
  41. if filename:
  42. return mimetypes.guess_type(filename)[0] or "application/octet-stream"
  43. return None
  44. def get_multipart_boundary_from_content_type(
  45. content_type: bytes | None,
  46. ) -> bytes | None:
  47. if not content_type or not content_type.startswith(b"multipart/form-data"):
  48. return None
  49. # parse boundary according to
  50. # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
  51. if b";" in content_type:
  52. for section in content_type.split(b";"):
  53. if section.strip().lower().startswith(b"boundary="):
  54. return section.strip()[len(b"boundary=") :].strip(b'"')
  55. return None
  56. class DataField:
  57. """
  58. A single form field item, within a multipart form field.
  59. """
  60. def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
  61. if not isinstance(name, str):
  62. raise TypeError(
  63. f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
  64. )
  65. if value is not None and not isinstance(value, (str, bytes, int, float)):
  66. raise TypeError(
  67. "Invalid type for value. Expected primitive type,"
  68. f" got {type(value)}: {value!r}"
  69. )
  70. self.name = name
  71. self.value: str | bytes = (
  72. value if isinstance(value, bytes) else primitive_value_to_str(value)
  73. )
  74. def render_headers(self) -> bytes:
  75. if not hasattr(self, "_headers"):
  76. name = _format_form_param("name", self.name)
  77. self._headers = b"".join(
  78. [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
  79. )
  80. return self._headers
  81. def render_data(self) -> bytes:
  82. if not hasattr(self, "_data"):
  83. self._data = to_bytes(self.value)
  84. return self._data
  85. def get_length(self) -> int:
  86. headers = self.render_headers()
  87. data = self.render_data()
  88. return len(headers) + len(data)
  89. def render(self) -> typing.Iterator[bytes]:
  90. yield self.render_headers()
  91. yield self.render_data()
  92. class FileField:
  93. """
  94. A single file field item, within a multipart form field.
  95. """
  96. CHUNK_SIZE = 64 * 1024
  97. def __init__(self, name: str, value: FileTypes) -> None:
  98. self.name = name
  99. fileobj: FileContent
  100. headers: dict[str, str] = {}
  101. content_type: str | None = None
  102. # This large tuple based API largely mirror's requests' API
  103. # It would be good to think of better APIs for this that we could
  104. # include in httpx 2.0 since variable length tuples(especially of 4 elements)
  105. # are quite unwieldly
  106. if isinstance(value, tuple):
  107. if len(value) == 2:
  108. # neither the 3rd parameter (content_type) nor the 4th (headers)
  109. # was included
  110. filename, fileobj = value
  111. elif len(value) == 3:
  112. filename, fileobj, content_type = value
  113. else:
  114. # all 4 parameters included
  115. filename, fileobj, content_type, headers = value # type: ignore
  116. else:
  117. filename = Path(str(getattr(value, "name", "upload"))).name
  118. fileobj = value
  119. if content_type is None:
  120. content_type = _guess_content_type(filename)
  121. has_content_type_header = any("content-type" in key.lower() for key in headers)
  122. if content_type is not None and not has_content_type_header:
  123. # note that unlike requests, we ignore the content_type provided in the 3rd
  124. # tuple element if it is also included in the headers requests does
  125. # the opposite (it overwrites the headerwith the 3rd tuple element)
  126. headers["Content-Type"] = content_type
  127. if isinstance(fileobj, io.StringIO):
  128. raise TypeError(
  129. "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
  130. )
  131. if isinstance(fileobj, io.TextIOBase):
  132. raise TypeError(
  133. "Multipart file uploads must be opened in binary mode, not text mode."
  134. )
  135. self.filename = filename
  136. self.file = fileobj
  137. self.headers = headers
  138. def get_length(self) -> int | None:
  139. headers = self.render_headers()
  140. if isinstance(self.file, (str, bytes)):
  141. return len(headers) + len(to_bytes(self.file))
  142. file_length = peek_filelike_length(self.file)
  143. # If we can't determine the filesize without reading it into memory,
  144. # then return `None` here, to indicate an unknown file length.
  145. if file_length is None:
  146. return None
  147. return len(headers) + file_length
  148. def render_headers(self) -> bytes:
  149. if not hasattr(self, "_headers"):
  150. parts = [
  151. b"Content-Disposition: form-data; ",
  152. _format_form_param("name", self.name),
  153. ]
  154. if self.filename:
  155. filename = _format_form_param("filename", self.filename)
  156. parts.extend([b"; ", filename])
  157. for header_name, header_value in self.headers.items():
  158. key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
  159. parts.extend([key, val])
  160. parts.append(b"\r\n\r\n")
  161. self._headers = b"".join(parts)
  162. return self._headers
  163. def render_data(self) -> typing.Iterator[bytes]:
  164. if isinstance(self.file, (str, bytes)):
  165. yield to_bytes(self.file)
  166. return
  167. if hasattr(self.file, "seek"):
  168. try:
  169. self.file.seek(0)
  170. except io.UnsupportedOperation:
  171. pass
  172. chunk = self.file.read(self.CHUNK_SIZE)
  173. while chunk:
  174. yield to_bytes(chunk)
  175. chunk = self.file.read(self.CHUNK_SIZE)
  176. def render(self) -> typing.Iterator[bytes]:
  177. yield self.render_headers()
  178. yield from self.render_data()
  179. class MultipartStream(SyncByteStream, AsyncByteStream):
  180. """
  181. Request content as streaming multipart encoded form data.
  182. """
  183. def __init__(
  184. self,
  185. data: RequestData,
  186. files: RequestFiles,
  187. boundary: bytes | None = None,
  188. ) -> None:
  189. if boundary is None:
  190. boundary = os.urandom(16).hex().encode("ascii")
  191. self.boundary = boundary
  192. self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
  193. "ascii"
  194. )
  195. self.fields = list(self._iter_fields(data, files))
  196. def _iter_fields(
  197. self, data: RequestData, files: RequestFiles
  198. ) -> typing.Iterator[FileField | DataField]:
  199. for name, value in data.items():
  200. if isinstance(value, (tuple, list)):
  201. for item in value:
  202. yield DataField(name=name, value=item)
  203. else:
  204. yield DataField(name=name, value=value)
  205. file_items = files.items() if isinstance(files, typing.Mapping) else files
  206. for name, value in file_items:
  207. yield FileField(name=name, value=value)
  208. def iter_chunks(self) -> typing.Iterator[bytes]:
  209. for field in self.fields:
  210. yield b"--%s\r\n" % self.boundary
  211. yield from field.render()
  212. yield b"\r\n"
  213. yield b"--%s--\r\n" % self.boundary
  214. def get_content_length(self) -> int | None:
  215. """
  216. Return the length of the multipart encoded content, or `None` if
  217. any of the files have a length that cannot be determined upfront.
  218. """
  219. boundary_length = len(self.boundary)
  220. length = 0
  221. for field in self.fields:
  222. field_length = field.get_length()
  223. if field_length is None:
  224. return None
  225. length += 2 + boundary_length + 2 # b"--{boundary}\r\n"
  226. length += field_length
  227. length += 2 # b"\r\n"
  228. length += 2 + boundary_length + 4 # b"--{boundary}--\r\n"
  229. return length
  230. # Content stream interface.
  231. def get_headers(self) -> dict[str, str]:
  232. content_length = self.get_content_length()
  233. content_type = self.content_type
  234. if content_length is None:
  235. return {"Transfer-Encoding": "chunked", "Content-Type": content_type}
  236. return {"Content-Length": str(content_length), "Content-Type": content_type}
  237. def __iter__(self) -> typing.Iterator[bytes]:
  238. for chunk in self.iter_chunks():
  239. yield chunk
  240. async def __aiter__(self) -> typing.AsyncIterator[bytes]:
  241. for chunk in self.iter_chunks():
  242. yield chunk