compression.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. """Helper functions for a standard streaming compression API"""
  2. from zipfile import ZipFile
  3. import fsspec.utils
  4. from fsspec.spec import AbstractBufferedFile
  5. def noop_file(file, mode, **kwargs):
  6. return file
  7. # TODO: files should also be available as contexts
  8. # should be functions of the form func(infile, mode=, **kwargs) -> file-like
  9. compr = {None: noop_file}
  10. def register_compression(name, callback, extensions, force=False):
  11. """Register an "inferable" file compression type.
  12. Registers transparent file compression type for use with fsspec.open.
  13. Compression can be specified by name in open, or "infer"-ed for any files
  14. ending with the given extensions.
  15. Args:
  16. name: (str) The compression type name. Eg. "gzip".
  17. callback: A callable of form (infile, mode, **kwargs) -> file-like.
  18. Accepts an input file-like object, the target mode and kwargs.
  19. Returns a wrapped file-like object.
  20. extensions: (str, Iterable[str]) A file extension, or list of file
  21. extensions for which to infer this compression scheme. Eg. "gz".
  22. force: (bool) Force re-registration of compression type or extensions.
  23. Raises:
  24. ValueError: If name or extensions already registered, and not force.
  25. """
  26. if isinstance(extensions, str):
  27. extensions = [extensions]
  28. # Validate registration
  29. if name in compr and not force:
  30. raise ValueError(f"Duplicate compression registration: {name}")
  31. for ext in extensions:
  32. if ext in fsspec.utils.compressions and not force:
  33. raise ValueError(f"Duplicate compression file extension: {ext} ({name})")
  34. compr[name] = callback
  35. for ext in extensions:
  36. fsspec.utils.compressions[ext] = name
  37. def unzip(infile, mode="rb", filename=None, **kwargs):
  38. if "r" not in mode:
  39. filename = filename or "file"
  40. z = ZipFile(infile, mode="w", **kwargs)
  41. fo = z.open(filename, mode="w")
  42. fo.close = lambda closer=fo.close: closer() or z.close()
  43. return fo
  44. z = ZipFile(infile)
  45. if filename is None:
  46. filename = z.namelist()[0]
  47. return z.open(filename, mode="r", **kwargs)
  48. register_compression("zip", unzip, "zip")
  49. try:
  50. from bz2 import BZ2File
  51. except ImportError:
  52. pass
  53. else:
  54. register_compression("bz2", BZ2File, "bz2")
  55. try: # pragma: no cover
  56. from isal import igzip
  57. def isal(infile, mode="rb", **kwargs):
  58. return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)
  59. register_compression("gzip", isal, "gz")
  60. except ImportError:
  61. from gzip import GzipFile
  62. register_compression(
  63. "gzip", lambda f, **kwargs: GzipFile(fileobj=f, **kwargs), "gz"
  64. )
  65. try:
  66. from lzma import LZMAFile
  67. register_compression("lzma", LZMAFile, "lzma")
  68. register_compression("xz", LZMAFile, "xz")
  69. except ImportError:
  70. pass
  71. try:
  72. import lzmaffi
  73. register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
  74. register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
  75. except ImportError:
  76. pass
  77. class SnappyFile(AbstractBufferedFile):
  78. def __init__(self, infile, mode, **kwargs):
  79. import snappy
  80. super().__init__(
  81. fs=None, path="snappy", mode=mode.strip("b") + "b", size=999999999, **kwargs
  82. )
  83. self.infile = infile
  84. if "r" in mode:
  85. self.codec = snappy.StreamDecompressor()
  86. else:
  87. self.codec = snappy.StreamCompressor()
  88. def _upload_chunk(self, final=False):
  89. self.buffer.seek(0)
  90. out = self.codec.add_chunk(self.buffer.read())
  91. self.infile.write(out)
  92. return True
  93. def seek(self, loc, whence=0):
  94. raise NotImplementedError("SnappyFile is not seekable")
  95. def seekable(self):
  96. return False
  97. def _fetch_range(self, start, end):
  98. """Get the specified set of bytes from remote"""
  99. data = self.infile.read(end - start)
  100. return self.codec.decompress(data)
  101. try:
  102. import snappy
  103. snappy.compress(b"")
  104. # Snappy may use the .sz file extension, but this is not part of the
  105. # standard implementation.
  106. register_compression("snappy", SnappyFile, [])
  107. except (ImportError, NameError, AttributeError):
  108. pass
  109. try:
  110. import lz4.frame
  111. register_compression("lz4", lz4.frame.open, "lz4")
  112. except ImportError:
  113. pass
  114. try:
  115. import zstandard as zstd
  116. def zstandard_file(infile, mode="rb"):
  117. if "r" in mode:
  118. cctx = zstd.ZstdDecompressor()
  119. return cctx.stream_reader(infile)
  120. else:
  121. cctx = zstd.ZstdCompressor(level=10)
  122. return cctx.stream_writer(infile)
  123. register_compression("zstd", zstandard_file, "zst")
  124. except ImportError:
  125. pass
  126. def available_compressions():
  127. """Return a list of the implemented compressions."""
  128. return list(compr)