# compression_utils.py
import asyncio
import zlib
from concurrent.futures import Executor
from typing import Optional, cast

# Prefer the CFFI-based `brotlicffi` binding, falling back to the
# `brotli` package (both expose the same import name downstream code
# uses).  HAS_BROTLI records whether either import succeeded so callers
# can raise a clear error instead of an ImportError at use time.
try:
    try:
        import brotlicffi as brotli
    except ImportError:
        import brotli

    HAS_BROTLI = True
except ImportError:  # pragma: no cover
    HAS_BROTLI = False

# Payloads larger than this many bytes are (de)compressed in an executor
# rather than inline on the event loop (see ZlibBaseHandler subclasses).
MAX_SYNC_CHUNK_SIZE = 1024
  14. def encoding_to_mode(
  15. encoding: Optional[str] = None,
  16. suppress_deflate_header: bool = False,
  17. ) -> int:
  18. if encoding == "gzip":
  19. return 16 + zlib.MAX_WBITS
  20. return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS
  21. class ZlibBaseHandler:
  22. def __init__(
  23. self,
  24. mode: int,
  25. executor: Optional[Executor] = None,
  26. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  27. ):
  28. self._mode = mode
  29. self._executor = executor
  30. self._max_sync_chunk_size = max_sync_chunk_size
  31. class ZLibCompressor(ZlibBaseHandler):
  32. def __init__(
  33. self,
  34. encoding: Optional[str] = None,
  35. suppress_deflate_header: bool = False,
  36. level: Optional[int] = None,
  37. wbits: Optional[int] = None,
  38. strategy: int = zlib.Z_DEFAULT_STRATEGY,
  39. executor: Optional[Executor] = None,
  40. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  41. ):
  42. super().__init__(
  43. mode=(
  44. encoding_to_mode(encoding, suppress_deflate_header)
  45. if wbits is None
  46. else wbits
  47. ),
  48. executor=executor,
  49. max_sync_chunk_size=max_sync_chunk_size,
  50. )
  51. if level is None:
  52. self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
  53. else:
  54. self._compressor = zlib.compressobj(
  55. wbits=self._mode, strategy=strategy, level=level
  56. )
  57. self._compress_lock = asyncio.Lock()
  58. def compress_sync(self, data: bytes) -> bytes:
  59. return self._compressor.compress(data)
  60. async def compress(self, data: bytes) -> bytes:
  61. """Compress the data and returned the compressed bytes.
  62. Note that flush() must be called after the last call to compress()
  63. If the data size is large than the max_sync_chunk_size, the compression
  64. will be done in the executor. Otherwise, the compression will be done
  65. in the event loop.
  66. """
  67. async with self._compress_lock:
  68. # To ensure the stream is consistent in the event
  69. # there are multiple writers, we need to lock
  70. # the compressor so that only one writer can
  71. # compress at a time.
  72. if (
  73. self._max_sync_chunk_size is not None
  74. and len(data) > self._max_sync_chunk_size
  75. ):
  76. return await asyncio.get_running_loop().run_in_executor(
  77. self._executor, self._compressor.compress, data
  78. )
  79. return self.compress_sync(data)
  80. def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
  81. return self._compressor.flush(mode)
  82. class ZLibDecompressor(ZlibBaseHandler):
  83. def __init__(
  84. self,
  85. encoding: Optional[str] = None,
  86. suppress_deflate_header: bool = False,
  87. executor: Optional[Executor] = None,
  88. max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
  89. ):
  90. super().__init__(
  91. mode=encoding_to_mode(encoding, suppress_deflate_header),
  92. executor=executor,
  93. max_sync_chunk_size=max_sync_chunk_size,
  94. )
  95. self._decompressor = zlib.decompressobj(wbits=self._mode)
  96. def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
  97. return self._decompressor.decompress(data, max_length)
  98. async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
  99. """Decompress the data and return the decompressed bytes.
  100. If the data size is large than the max_sync_chunk_size, the decompression
  101. will be done in the executor. Otherwise, the decompression will be done
  102. in the event loop.
  103. """
  104. if (
  105. self._max_sync_chunk_size is not None
  106. and len(data) > self._max_sync_chunk_size
  107. ):
  108. return await asyncio.get_running_loop().run_in_executor(
  109. self._executor, self._decompressor.decompress, data, max_length
  110. )
  111. return self.decompress_sync(data, max_length)
  112. def flush(self, length: int = 0) -> bytes:
  113. return (
  114. self._decompressor.flush(length)
  115. if length > 0
  116. else self._decompressor.flush()
  117. )
  118. @property
  119. def eof(self) -> bool:
  120. return self._decompressor.eof
  121. @property
  122. def unconsumed_tail(self) -> bytes:
  123. return self._decompressor.unconsumed_tail
  124. @property
  125. def unused_data(self) -> bytes:
  126. return self._decompressor.unused_data
  127. class BrotliDecompressor:
  128. # Supports both 'brotlipy' and 'Brotli' packages
  129. # since they share an import name. The top branches
  130. # are for 'brotlipy' and bottom branches for 'Brotli'
  131. def __init__(self) -> None:
  132. if not HAS_BROTLI:
  133. raise RuntimeError(
  134. "The brotli decompression is not available. "
  135. "Please install `Brotli` module"
  136. )
  137. self._obj = brotli.Decompressor()
  138. def decompress_sync(self, data: bytes) -> bytes:
  139. if hasattr(self._obj, "decompress"):
  140. return cast(bytes, self._obj.decompress(data))
  141. return cast(bytes, self._obj.process(data))
  142. def flush(self) -> bytes:
  143. if hasattr(self._obj, "flush"):
  144. return cast(bytes, self._obj.flush())
  145. return b""