"""Helper functions for a standard streaming compression API"""
- from zipfile import ZipFile
- import fsspec.utils
- from fsspec.spec import AbstractBufferedFile
def noop_file(file, mode, **kwargs):
    """Return ``file`` unchanged; the opener used when no compression applies.

    ``mode`` and any extra keyword arguments are accepted only so the
    signature matches the other compression openers; they are ignored.
    """
    return file
# Registry of compression openers, keyed by compression name.
# Every value should be a function of the form
# func(infile, mode=, **kwargs) -> file-like.
# The ``None`` key stands for "no compression" and maps to the pass-through
# opener.
# TODO: files should also be available as contexts
compr = {None: noop_file}
def register_compression(name, callback, extensions, force=False):
    """Register an "inferable" file compression type.

    Registers transparent file compression type for use with fsspec.open.
    Compression can be specified by name in open, or "infer"-ed for any files
    ending with the given extensions.

    Args:
        name: (str) The compression type name. Eg. "gzip".
        callback: A callable of form (infile, mode, **kwargs) -> file-like.
            Accepts an input file-like object, the target mode and kwargs.
            Returns a wrapped file-like object.
        extensions: (str, Iterable[str]) A file extension, or list of file
            extensions for which to infer this compression scheme. Eg. "gz".
        force: (bool) Force re-registration of compression type or extensions.

    Raises:
        ValueError: If name or extensions already registered, and not force.
    """
    if isinstance(extensions, str):
        extensions = [extensions]
    else:
        # The extensions are walked twice below (validate, then register).
        # Materialise them once so that a one-shot iterable such as a
        # generator is not exhausted by the validation pass, which would
        # otherwise leave no extensions registered.
        extensions = list(extensions)

    # Validate everything before mutating either registry, so a rejected
    # registration leaves both untouched.
    if not force:
        if name in compr:
            raise ValueError(f"Duplicate compression registration: {name}")
        for ext in extensions:
            if ext in fsspec.utils.compressions:
                raise ValueError(f"Duplicate compression file extension: {ext} ({name})")

    compr[name] = callback
    for ext in extensions:
        fsspec.utils.compressions[ext] = name
def unzip(infile, mode="rb", filename=None, **kwargs):
    """Open a single member of a zip archive held in ``infile``.

    Args:
        infile: File-like object containing (or receiving) the zip archive.
        mode: Any mode containing "r" reads; every other mode writes.
        filename: Name of the archive member. When reading, defaults to the
            first name listed in the archive; when writing, defaults to
            "file".
        **kwargs: Passed to ``ZipFile.open`` when reading, and to the
            ``ZipFile`` constructor when writing.

    Returns:
        A file-like object for the member. In write mode, closing it also
        closes the enclosing archive.
    """
    if "r" in mode:
        archive = ZipFile(infile)
        member = archive.namelist()[0] if filename is None else filename
        return archive.open(member, mode="r", **kwargs)

    member = filename or "file"
    archive = ZipFile(infile, mode="w", **kwargs)
    handle = archive.open(member, mode="w")
    close_member = handle.close

    def close_member_then_archive():
        # Close the member first, then the archive that contains it.
        return close_member() or archive.close()

    handle.close = close_member_then_archive
    return handle
# Zip support is registered unconditionally.
register_compression("zip", unzip, "zip")

# bz2 is registered only when the bz2 module can be imported.
try:
    from bz2 import BZ2File
except ImportError:
    pass
else:
    register_compression("bz2", BZ2File, "bz2")
# Use the isal implementation of gzip when it is installed, otherwise fall
# back to the standard-library gzip module.
try:  # pragma: no cover
    from isal import igzip

    def isal(infile, mode="rb", **kwargs):
        """Wrap ``infile`` in an ``igzip.IGzipFile`` opened with ``mode``."""
        return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)

    register_compression("gzip", isal, "gz")
except ImportError:
    from gzip import GzipFile

    def _gzip_file(f, mode=None, **kwargs):
        """Wrap ``f`` in a standard-library ``GzipFile``.

        ``mode`` is an explicit parameter so the opener can be called as
        ``callback(infile, mode, **kwargs)``, like every other registered
        opener. Its default of ``None`` is ``GzipFile``'s own default, so
        calls that omit ``mode`` behave as they did before.
        """
        return GzipFile(fileobj=f, mode=mode, **kwargs)

    register_compression("gzip", _gzip_file, "gz")
# Register the lzma and xz formats, both handled by LZMAFile, when the lzma
# module can be imported.
try:
    from lzma import LZMAFile
except ImportError:
    pass
else:
    register_compression("lzma", LZMAFile, "lzma")
    register_compression("xz", LZMAFile, "xz")
# When lzmaffi is installed, its LZMAFile replaces any earlier lzma/xz
# registration (force=True allows the re-registration).
try:
    import lzmaffi
except ImportError:
    pass
else:
    register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
    register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
class SnappyFile(AbstractBufferedFile):
    """Non-seekable buffered file applying snappy stream (de)compression.

    In a mode containing "r", bytes read from ``infile`` are passed through a
    snappy stream decompressor. In any other mode, buffered bytes are passed
    through a snappy stream compressor and written to ``infile``.
    """

    def __init__(self, infile, mode, **kwargs):
        """Wrap ``infile`` for reading or writing, depending on ``mode``.

        Args:
            infile: Underlying file-like object holding the compressed bytes.
            mode: File mode; any "b" is stripped and re-appended so the base
                class always receives a binary mode.
            **kwargs: Forwarded to the base class constructor.
        """
        import snappy

        binary_mode = mode.strip("b") + "b"
        # No filesystem backs this object, so fs is None and the path is a
        # fixed label.
        # NOTE(review): size=999999999 looks like a placeholder for
        # "unknown length" — confirm.
        super().__init__(
            fs=None, path="snappy", mode=binary_mode, size=999999999, **kwargs
        )
        self.infile = infile
        self.codec = (
            snappy.StreamDecompressor() if "r" in mode else snappy.StreamCompressor()
        )

    def _upload_chunk(self, final=False):
        """Compress the whole write buffer and write the result to ``infile``.

        ``final`` is accepted but not used. Always returns True.
        """
        self.buffer.seek(0)
        pending = self.buffer.read()
        self.infile.write(self.codec.add_chunk(pending))
        return True

    def seek(self, loc, whence=0):
        """Always raise: the stream cannot be repositioned."""
        raise NotImplementedError("SnappyFile is not seekable")

    def seekable(self):
        """Report that this file does not support seeking."""
        return False

    def _fetch_range(self, start, end):
        """Read ``end - start`` bytes from ``infile`` and decompress them.

        Only the length of the range is used; ``start`` does not reposition
        the underlying file.
        """
        compressed = self.infile.read(end - start)
        return self.codec.decompress(compressed)
# Register snappy only when a usable snappy module is importable. The
# compress() call checks that the imported module offers the expected API.
try:
    import snappy

    snappy.compress(b"")
    # Snappy may use the .sz file extension, but this is not part of the
    # standard implementation, so no extension is registered for inference.
    register_compression("snappy", SnappyFile, [])
except (ImportError, NameError, AttributeError):
    pass
# Register lz4 frame compression when the lz4 package is installed.
try:
    import lz4.frame
except ImportError:
    pass
else:
    register_compression("lz4", lz4.frame.open, "lz4")
# Register zstd compression when the zstandard package is installed.
try:
    import zstandard as zstd

    def zstandard_file(infile, mode="rb"):
        """Wrap ``infile`` in a zstandard stream reader or writer.

        A mode containing "r" yields a decompressing reader; any other mode
        yields a compressing writer created at compression level 10.
        """
        if "r" not in mode:
            compressor = zstd.ZstdCompressor(level=10)
            return compressor.stream_writer(infile)
        decompressor = zstd.ZstdDecompressor()
        return decompressor.stream_reader(infile)

    register_compression("zstd", zstandard_file, "zst")
except ImportError:
    pass
def available_compressions():
    """Return a list of the implemented compressions.

    The list holds the keys of the ``compr`` registry, including the ``None``
    entry that stands for "no compression".
    """
    return [*compr]
|