payload.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. import asyncio
  2. import enum
  3. import io
  4. import json
  5. import mimetypes
  6. import os
  7. import sys
  8. import warnings
  9. from abc import ABC, abstractmethod
  10. from itertools import chain
  11. from typing import (
  12. IO,
  13. TYPE_CHECKING,
  14. Any,
  15. Dict,
  16. Final,
  17. Iterable,
  18. Optional,
  19. TextIO,
  20. Tuple,
  21. Type,
  22. Union,
  23. )
  24. from multidict import CIMultiDict
  25. from . import hdrs
  26. from .abc import AbstractStreamWriter
  27. from .helpers import (
  28. _SENTINEL,
  29. content_disposition_header,
  30. guess_filename,
  31. parse_mimetype,
  32. sentinel,
  33. )
  34. from .streams import StreamReader
  35. from .typedefs import JSONEncoder, _CIMultiDict
# Public names exported by a star-import of this module.
__all__ = (
    "PAYLOAD_REGISTRY",
    "get_payload",
    "payload_type",
    "Payload",
    "BytesPayload",
    "StringPayload",
    "IOBasePayload",
    "BytesIOPayload",
    "BufferedReaderPayload",
    "TextIOPayload",
    "StringIOPayload",
    "JsonPayload",
    "AsyncIterablePayload",
)

# Size threshold in bytes: BytesPayload emits a ResourceWarning when the
# wrapped value is larger than this.
TOO_LARGE_BYTES_BODY: Final[int] = 2**20  # 1 MB

if TYPE_CHECKING:
    # Only referenced from annotations, so it is not imported at runtime.
    from typing import List
class LookupError(Exception):
    """Raised by ``PayloadRegistry.get`` when no factory matches the data.

    NOTE(review): inside this module the name shadows the builtin
    ``LookupError``; this class derives from ``Exception`` directly, so
    ``except KeyError``/``except IndexError`` style handlers of the builtin
    hierarchy are unrelated to it.
    """

    pass
class Order(str, enum.Enum):
    """Priority bucket used when registering a payload factory.

    ``PayloadRegistry.get`` consults ``try_first`` registrations before
    ``normal`` ones, and ``try_last`` registrations after them.
    """

    # Default bucket; also feeds the exact-type lookup table of the registry.
    normal = "normal"
    # Checked before everything else.
    try_first = "try_first"
    # Checked only after all normal registrations failed to match.
    try_last = "try_last"
def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Return a payload for *data* using the module-level ``PAYLOAD_REGISTRY``.

    Extra positional and keyword arguments are forwarded unchanged to
    ``PAYLOAD_REGISTRY.get``, which raises this module's ``LookupError``
    when no registered factory matches.
    """
    return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* for *type* in the module-level ``PAYLOAD_REGISTRY``.

    *order* selects the priority bucket (see ``Order``); it is keyword-only.
    """
    PAYLOAD_REGISTRY.register(factory, type, order=order)
class payload_type:
    """Class decorator that registers a payload factory for a data type.

    The decorator is configured with the data *type* and an optional *order*;
    applying it to a factory class registers that class via
    ``register_payload`` and returns the class unchanged.
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        # Keep the registration arguments until the decorator is applied.
        self.type = type
        self.order = order

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        """Register *factory* and hand it back so the class stays usable."""
        register_payload(factory, self.type, order=self.order)
        return factory
# A payload factory: a Payload subclass (forward reference, the class is
# defined further down in this module).
PayloadType = Type["Payload"]
# One registry entry: (factory, type-or-types passed to isinstance()).
_PayloadRegistryItem = Tuple[PayloadType, Any]
  75. class PayloadRegistry:
  76. """Payload registry.
  77. note: we need zope.interface for more efficient adapter search
  78. """
  79. __slots__ = ("_first", "_normal", "_last", "_normal_lookup")
  80. def __init__(self) -> None:
  81. self._first: List[_PayloadRegistryItem] = []
  82. self._normal: List[_PayloadRegistryItem] = []
  83. self._last: List[_PayloadRegistryItem] = []
  84. self._normal_lookup: Dict[Any, PayloadType] = {}
  85. def get(
  86. self,
  87. data: Any,
  88. *args: Any,
  89. _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
  90. **kwargs: Any,
  91. ) -> "Payload":
  92. if self._first:
  93. for factory, type_ in self._first:
  94. if isinstance(data, type_):
  95. return factory(data, *args, **kwargs)
  96. # Try the fast lookup first
  97. if lookup_factory := self._normal_lookup.get(type(data)):
  98. return lookup_factory(data, *args, **kwargs)
  99. # Bail early if its already a Payload
  100. if isinstance(data, Payload):
  101. return data
  102. # Fallback to the slower linear search
  103. for factory, type_ in _CHAIN(self._normal, self._last):
  104. if isinstance(data, type_):
  105. return factory(data, *args, **kwargs)
  106. raise LookupError()
  107. def register(
  108. self, factory: PayloadType, type: Any, *, order: Order = Order.normal
  109. ) -> None:
  110. if order is Order.try_first:
  111. self._first.append((factory, type))
  112. elif order is Order.normal:
  113. self._normal.append((factory, type))
  114. if isinstance(type, Iterable):
  115. for t in type:
  116. self._normal_lookup[t] = factory
  117. else:
  118. self._normal_lookup[type] = factory
  119. elif order is Order.try_last:
  120. self._last.append((factory, type))
  121. else:
  122. raise ValueError(f"Unsupported order {order!r}")
  123. class Payload(ABC):
  124. _default_content_type: str = "application/octet-stream"
  125. _size: Optional[int] = None
  126. def __init__(
  127. self,
  128. value: Any,
  129. headers: Optional[
  130. Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
  131. ] = None,
  132. content_type: Union[str, None, _SENTINEL] = sentinel,
  133. filename: Optional[str] = None,
  134. encoding: Optional[str] = None,
  135. **kwargs: Any,
  136. ) -> None:
  137. self._encoding = encoding
  138. self._filename = filename
  139. self._headers: _CIMultiDict = CIMultiDict()
  140. self._value = value
  141. if content_type is not sentinel and content_type is not None:
  142. self._headers[hdrs.CONTENT_TYPE] = content_type
  143. elif self._filename is not None:
  144. if sys.version_info >= (3, 13):
  145. guesser = mimetypes.guess_file_type
  146. else:
  147. guesser = mimetypes.guess_type
  148. content_type = guesser(self._filename)[0]
  149. if content_type is None:
  150. content_type = self._default_content_type
  151. self._headers[hdrs.CONTENT_TYPE] = content_type
  152. else:
  153. self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
  154. if headers:
  155. self._headers.update(headers)
  156. @property
  157. def size(self) -> Optional[int]:
  158. """Size of the payload."""
  159. return self._size
  160. @property
  161. def filename(self) -> Optional[str]:
  162. """Filename of the payload."""
  163. return self._filename
  164. @property
  165. def headers(self) -> _CIMultiDict:
  166. """Custom item headers"""
  167. return self._headers
  168. @property
  169. def _binary_headers(self) -> bytes:
  170. return (
  171. "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
  172. "utf-8"
  173. )
  174. + b"\r\n"
  175. )
  176. @property
  177. def encoding(self) -> Optional[str]:
  178. """Payload encoding"""
  179. return self._encoding
  180. @property
  181. def content_type(self) -> str:
  182. """Content type"""
  183. return self._headers[hdrs.CONTENT_TYPE]
  184. def set_content_disposition(
  185. self,
  186. disptype: str,
  187. quote_fields: bool = True,
  188. _charset: str = "utf-8",
  189. **params: Any,
  190. ) -> None:
  191. """Sets ``Content-Disposition`` header."""
  192. self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
  193. disptype, quote_fields=quote_fields, _charset=_charset, **params
  194. )
  195. @abstractmethod
  196. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  197. """Return string representation of the value.
  198. This is named decode() to allow compatibility with bytes objects.
  199. """
  200. @abstractmethod
  201. async def write(self, writer: AbstractStreamWriter) -> None:
  202. """Write payload.
  203. writer is an AbstractStreamWriter instance:
  204. """
  205. class BytesPayload(Payload):
  206. _value: bytes
  207. def __init__(
  208. self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
  209. ) -> None:
  210. if "content_type" not in kwargs:
  211. kwargs["content_type"] = "application/octet-stream"
  212. super().__init__(value, *args, **kwargs)
  213. if isinstance(value, memoryview):
  214. self._size = value.nbytes
  215. elif isinstance(value, (bytes, bytearray)):
  216. self._size = len(value)
  217. else:
  218. raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
  219. if self._size > TOO_LARGE_BYTES_BODY:
  220. kwargs = {"source": self}
  221. warnings.warn(
  222. "Sending a large body directly with raw bytes might"
  223. " lock the event loop. You should probably pass an "
  224. "io.BytesIO object instead",
  225. ResourceWarning,
  226. **kwargs,
  227. )
  228. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  229. return self._value.decode(encoding, errors)
  230. async def write(self, writer: AbstractStreamWriter) -> None:
  231. await writer.write(self._value)
  232. class StringPayload(BytesPayload):
  233. def __init__(
  234. self,
  235. value: str,
  236. *args: Any,
  237. encoding: Optional[str] = None,
  238. content_type: Optional[str] = None,
  239. **kwargs: Any,
  240. ) -> None:
  241. if encoding is None:
  242. if content_type is None:
  243. real_encoding = "utf-8"
  244. content_type = "text/plain; charset=utf-8"
  245. else:
  246. mimetype = parse_mimetype(content_type)
  247. real_encoding = mimetype.parameters.get("charset", "utf-8")
  248. else:
  249. if content_type is None:
  250. content_type = "text/plain; charset=%s" % encoding
  251. real_encoding = encoding
  252. super().__init__(
  253. value.encode(real_encoding),
  254. encoding=real_encoding,
  255. content_type=content_type,
  256. *args,
  257. **kwargs,
  258. )
class StringIOPayload(StringPayload):
    """String payload built from a text stream.

    The stream is read to the end once, in the constructor; the resulting
    string is then handled exactly like a ``StringPayload`` value.
    """

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        # read() consumes from the stream's current position.
        super().__init__(value.read(), *args, **kwargs)
  262. class IOBasePayload(Payload):
  263. _value: io.IOBase
  264. def __init__(
  265. self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
  266. ) -> None:
  267. if "filename" not in kwargs:
  268. kwargs["filename"] = guess_filename(value)
  269. super().__init__(value, *args, **kwargs)
  270. if self._filename is not None and disposition is not None:
  271. if hdrs.CONTENT_DISPOSITION not in self.headers:
  272. self.set_content_disposition(disposition, filename=self._filename)
  273. async def write(self, writer: AbstractStreamWriter) -> None:
  274. loop = asyncio.get_event_loop()
  275. try:
  276. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  277. while chunk:
  278. await writer.write(chunk)
  279. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  280. finally:
  281. await loop.run_in_executor(None, self._value.close)
  282. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  283. return "".join(r.decode(encoding, errors) for r in self._value.readlines())
  284. class TextIOPayload(IOBasePayload):
  285. _value: io.TextIOBase
  286. def __init__(
  287. self,
  288. value: TextIO,
  289. *args: Any,
  290. encoding: Optional[str] = None,
  291. content_type: Optional[str] = None,
  292. **kwargs: Any,
  293. ) -> None:
  294. if encoding is None:
  295. if content_type is None:
  296. encoding = "utf-8"
  297. content_type = "text/plain; charset=utf-8"
  298. else:
  299. mimetype = parse_mimetype(content_type)
  300. encoding = mimetype.parameters.get("charset", "utf-8")
  301. else:
  302. if content_type is None:
  303. content_type = "text/plain; charset=%s" % encoding
  304. super().__init__(
  305. value,
  306. content_type=content_type,
  307. encoding=encoding,
  308. *args,
  309. **kwargs,
  310. )
  311. @property
  312. def size(self) -> Optional[int]:
  313. try:
  314. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  315. except OSError:
  316. return None
  317. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  318. return self._value.read()
  319. async def write(self, writer: AbstractStreamWriter) -> None:
  320. loop = asyncio.get_event_loop()
  321. try:
  322. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  323. while chunk:
  324. data = (
  325. chunk.encode(encoding=self._encoding)
  326. if self._encoding
  327. else chunk.encode()
  328. )
  329. await writer.write(data)
  330. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  331. finally:
  332. await loop.run_in_executor(None, self._value.close)
  333. class BytesIOPayload(IOBasePayload):
  334. _value: io.BytesIO
  335. @property
  336. def size(self) -> int:
  337. position = self._value.tell()
  338. end = self._value.seek(0, os.SEEK_END)
  339. self._value.seek(position)
  340. return end - position
  341. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  342. return self._value.read().decode(encoding, errors)
  343. class BufferedReaderPayload(IOBasePayload):
  344. _value: io.BufferedIOBase
  345. @property
  346. def size(self) -> Optional[int]:
  347. try:
  348. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  349. except (OSError, AttributeError):
  350. # data.fileno() is not supported, e.g.
  351. # io.BufferedReader(io.BytesIO(b'data'))
  352. # For some file-like objects (e.g. tarfile), the fileno() attribute may
  353. # not exist at all, and will instead raise an AttributeError.
  354. return None
  355. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  356. return self._value.read().decode(encoding, errors)
  357. class JsonPayload(BytesPayload):
  358. def __init__(
  359. self,
  360. value: Any,
  361. encoding: str = "utf-8",
  362. content_type: str = "application/json",
  363. dumps: JSONEncoder = json.dumps,
  364. *args: Any,
  365. **kwargs: Any,
  366. ) -> None:
  367. super().__init__(
  368. dumps(value).encode(encoding),
  369. content_type=content_type,
  370. encoding=encoding,
  371. *args,
  372. **kwargs,
  373. )
# Type checkers see the parametrised typing generics (bytes chunks); at
# runtime the collections.abc classes are used so that isinstance() checks
# and registry registration work on them.
if TYPE_CHECKING:
    from typing import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]
else:
    from collections.abc import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
  382. class AsyncIterablePayload(Payload):
  383. _iter: Optional[_AsyncIterator] = None
  384. _value: _AsyncIterable
  385. def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
  386. if not isinstance(value, AsyncIterable):
  387. raise TypeError(
  388. "value argument must support "
  389. "collections.abc.AsyncIterable interface, "
  390. "got {!r}".format(type(value))
  391. )
  392. if "content_type" not in kwargs:
  393. kwargs["content_type"] = "application/octet-stream"
  394. super().__init__(value, *args, **kwargs)
  395. self._iter = value.__aiter__()
  396. async def write(self, writer: AbstractStreamWriter) -> None:
  397. if self._iter:
  398. try:
  399. # iter is not None check prevents rare cases
  400. # when the case iterable is used twice
  401. while True:
  402. chunk = await self._iter.__anext__()
  403. await writer.write(chunk)
  404. except StopAsyncIteration:
  405. self._iter = None
  406. def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
  407. raise TypeError("Unable to decode.")
class StreamReaderPayload(AsyncIterablePayload):
    """Async-iterable payload fed from a ``StreamReader``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        # iter_any() supplies the async iterable that the parent class
        # iterates; all other arguments are passed through unchanged.
        super().__init__(value.iter_any(), *args, **kwargs)
# Module-level registry used by get_payload() and register_payload().
PAYLOAD_REGISTRY = PayloadRegistry()

# Default factories.  Exact types are resolved through the registry's lookup
# table; subclasses fall back to an isinstance() scan in registration order,
# so more specific types are listed before their bases (e.g. io.StringIO
# before io.TextIOBase, and io.IOBase last among the file-like types).
PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
PAYLOAD_REGISTRY.register(StringPayload, str)
PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
# try_last gives more specialized async iterables, like
# multidict.BodyPartReaderPayload, a chance to override the default.
# NOTE(review): the module name "multidict" above is kept from the original
# comment; it presumably refers to a multipart payload class - confirm.
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)