arrow.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import errno
  2. import io
  3. import os
  4. import secrets
  5. import shutil
  6. from contextlib import suppress
  7. from functools import cached_property, wraps
  8. from urllib.parse import parse_qs
  9. from fsspec.spec import AbstractFileSystem
  10. from fsspec.utils import (
  11. get_package_version_without_import,
  12. infer_storage_options,
  13. mirror_from,
  14. tokenize,
  15. )
  16. def wrap_exceptions(func):
  17. @wraps(func)
  18. def wrapper(*args, **kwargs):
  19. try:
  20. return func(*args, **kwargs)
  21. except OSError as exception:
  22. if not exception.args:
  23. raise
  24. message, *args = exception.args
  25. if isinstance(message, str) and "does not exist" in message:
  26. raise FileNotFoundError(errno.ENOENT, message) from exception
  27. else:
  28. raise
  29. return wrapper
# Version string of the installed pyarrow package. It stays None until an
# ArrowFSWrapper is constructed (its __init__ assigns it); ArrowFSWrapper._open
# reads the major version to decide whether to pass ``compression=None``.
PYARROW_VERSION = None
  31. class ArrowFSWrapper(AbstractFileSystem):
  32. """FSSpec-compatible wrapper of pyarrow.fs.FileSystem.
  33. Parameters
  34. ----------
  35. fs : pyarrow.fs.FileSystem
  36. """
  37. root_marker = "/"
  38. def __init__(self, fs, **kwargs):
  39. global PYARROW_VERSION
  40. PYARROW_VERSION = get_package_version_without_import("pyarrow")
  41. self.fs = fs
  42. super().__init__(**kwargs)
  43. @property
  44. def protocol(self):
  45. return self.fs.type_name
  46. @cached_property
  47. def fsid(self):
  48. return "hdfs_" + tokenize(self.fs.host, self.fs.port)
  49. @classmethod
  50. def _strip_protocol(cls, path):
  51. ops = infer_storage_options(path)
  52. path = ops["path"]
  53. if path.startswith("//"):
  54. # special case for "hdfs://path" (without the triple slash)
  55. path = path[1:]
  56. return path
  57. def ls(self, path, detail=False, **kwargs):
  58. path = self._strip_protocol(path)
  59. from pyarrow.fs import FileSelector
  60. entries = [
  61. self._make_entry(entry)
  62. for entry in self.fs.get_file_info(FileSelector(path))
  63. ]
  64. if detail:
  65. return entries
  66. else:
  67. return [entry["name"] for entry in entries]
  68. def info(self, path, **kwargs):
  69. path = self._strip_protocol(path)
  70. [info] = self.fs.get_file_info([path])
  71. return self._make_entry(info)
  72. def exists(self, path):
  73. path = self._strip_protocol(path)
  74. try:
  75. self.info(path)
  76. except FileNotFoundError:
  77. return False
  78. else:
  79. return True
  80. def _make_entry(self, info):
  81. from pyarrow.fs import FileType
  82. if info.type is FileType.Directory:
  83. kind = "directory"
  84. elif info.type is FileType.File:
  85. kind = "file"
  86. elif info.type is FileType.NotFound:
  87. raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), info.path)
  88. else:
  89. kind = "other"
  90. return {
  91. "name": info.path,
  92. "size": info.size,
  93. "type": kind,
  94. "mtime": info.mtime,
  95. }
  96. @wrap_exceptions
  97. def cp_file(self, path1, path2, **kwargs):
  98. path1 = self._strip_protocol(path1).rstrip("/")
  99. path2 = self._strip_protocol(path2).rstrip("/")
  100. with self._open(path1, "rb") as lstream:
  101. tmp_fname = f"{path2}.tmp.{secrets.token_hex(6)}"
  102. try:
  103. with self.open(tmp_fname, "wb") as rstream:
  104. shutil.copyfileobj(lstream, rstream)
  105. self.fs.move(tmp_fname, path2)
  106. except BaseException:
  107. with suppress(FileNotFoundError):
  108. self.fs.delete_file(tmp_fname)
  109. raise
  110. @wrap_exceptions
  111. def mv(self, path1, path2, **kwargs):
  112. path1 = self._strip_protocol(path1).rstrip("/")
  113. path2 = self._strip_protocol(path2).rstrip("/")
  114. self.fs.move(path1, path2)
  115. @wrap_exceptions
  116. def rm_file(self, path):
  117. path = self._strip_protocol(path)
  118. self.fs.delete_file(path)
  119. @wrap_exceptions
  120. def rm(self, path, recursive=False, maxdepth=None):
  121. path = self._strip_protocol(path).rstrip("/")
  122. if self.isdir(path):
  123. if recursive:
  124. self.fs.delete_dir(path)
  125. else:
  126. raise ValueError("Can't delete directories without recursive=False")
  127. else:
  128. self.fs.delete_file(path)
  129. @wrap_exceptions
  130. def _open(self, path, mode="rb", block_size=None, seekable=True, **kwargs):
  131. if mode == "rb":
  132. if seekable:
  133. method = self.fs.open_input_file
  134. else:
  135. method = self.fs.open_input_stream
  136. elif mode == "wb":
  137. method = self.fs.open_output_stream
  138. elif mode == "ab":
  139. method = self.fs.open_append_stream
  140. else:
  141. raise ValueError(f"unsupported mode for Arrow filesystem: {mode!r}")
  142. _kwargs = {}
  143. if mode != "rb" or not seekable:
  144. if int(PYARROW_VERSION.split(".")[0]) >= 4:
  145. # disable compression auto-detection
  146. _kwargs["compression"] = None
  147. stream = method(path, **_kwargs)
  148. return ArrowFile(self, stream, path, mode, block_size, **kwargs)
  149. @wrap_exceptions
  150. def mkdir(self, path, create_parents=True, **kwargs):
  151. path = self._strip_protocol(path)
  152. if create_parents:
  153. self.makedirs(path, exist_ok=True)
  154. else:
  155. self.fs.create_dir(path, recursive=False)
  156. @wrap_exceptions
  157. def makedirs(self, path, exist_ok=False):
  158. path = self._strip_protocol(path)
  159. self.fs.create_dir(path, recursive=True)
  160. @wrap_exceptions
  161. def rmdir(self, path):
  162. path = self._strip_protocol(path)
  163. self.fs.delete_dir(path)
  164. @wrap_exceptions
  165. def modified(self, path):
  166. path = self._strip_protocol(path)
  167. return self.fs.get_file_info(path).mtime
  168. def cat_file(self, path, start=None, end=None, **kwargs):
  169. kwargs["seekable"] = start not in [None, 0]
  170. return super().cat_file(path, start=None, end=None, **kwargs)
  171. def get_file(self, rpath, lpath, **kwargs):
  172. kwargs["seekable"] = False
  173. super().get_file(rpath, lpath, **kwargs)
  174. @mirror_from(
  175. "stream",
  176. [
  177. "read",
  178. "seek",
  179. "tell",
  180. "write",
  181. "readable",
  182. "writable",
  183. "close",
  184. "size",
  185. "seekable",
  186. ],
  187. )
  188. class ArrowFile(io.IOBase):
  189. def __init__(self, fs, stream, path, mode, block_size=None, **kwargs):
  190. self.path = path
  191. self.mode = mode
  192. self.fs = fs
  193. self.stream = stream
  194. self.blocksize = self.block_size = block_size
  195. self.kwargs = kwargs
  196. def __enter__(self):
  197. return self
  198. def __exit__(self, *args):
  199. return self.close()
  200. class HadoopFileSystem(ArrowFSWrapper):
  201. """A wrapper on top of the pyarrow.fs.HadoopFileSystem
  202. to connect it's interface with fsspec"""
  203. protocol = "hdfs"
  204. def __init__(
  205. self,
  206. host="default",
  207. port=0,
  208. user=None,
  209. kerb_ticket=None,
  210. replication=3,
  211. extra_conf=None,
  212. **kwargs,
  213. ):
  214. """
  215. Parameters
  216. ----------
  217. host: str
  218. Hostname, IP or "default" to try to read from Hadoop config
  219. port: int
  220. Port to connect on, or default from Hadoop config if 0
  221. user: str or None
  222. If given, connect as this username
  223. kerb_ticket: str or None
  224. If given, use this ticket for authentication
  225. replication: int
  226. set replication factor of file for write operations. default value is 3.
  227. extra_conf: None or dict
  228. Passed on to HadoopFileSystem
  229. """
  230. from pyarrow.fs import HadoopFileSystem
  231. fs = HadoopFileSystem(
  232. host=host,
  233. port=port,
  234. user=user,
  235. kerb_ticket=kerb_ticket,
  236. replication=replication,
  237. extra_conf=extra_conf,
  238. )
  239. super().__init__(fs=fs, **kwargs)
  240. @staticmethod
  241. def _get_kwargs_from_urls(path):
  242. ops = infer_storage_options(path)
  243. out = {}
  244. if ops.get("host", None):
  245. out["host"] = ops["host"]
  246. if ops.get("username", None):
  247. out["user"] = ops["username"]
  248. if ops.get("port", None):
  249. out["port"] = ops["port"]
  250. if ops.get("url_query", None):
  251. queries = parse_qs(ops["url_query"])
  252. if queries.get("replication", None):
  253. out["replication"] = int(queries["replication"][0])
  254. return out