# caching.py

from __future__ import annotations

import collections
import functools
import logging
import math
import os
import threading
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from itertools import groupby
from operator import itemgetter
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Generic,
    NamedTuple,
    Optional,
    OrderedDict,
    TypeVar,
)

if TYPE_CHECKING:
    import mmap

    from typing_extensions import ParamSpec

    P = ParamSpec("P")
else:
    P = TypeVar("P")

T = TypeVar("T")


logger = logging.getLogger("fsspec")

Fetcher = Callable[[int, int], bytes]  # Maps (start, end) to bytes
MultiFetcher = Callable[[list[tuple[int, int]]], bytes]  # Maps [(start, end)] to bytes

class BaseCache:
    """Pass-through cache: doesn't keep anything, calls every time

    Acts as base class for other cachers

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: func
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    """

    name: ClassVar[str] = "none"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        self.blocksize = blocksize
        self.nblocks = 0
        self.fetcher = fetcher
        self.size = size
        self.hit_count = 0
        self.miss_count = 0
        # the bytes that we actually requested
        self.total_requested_bytes = 0

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        if start is None:
            start = 0
        if stop is None:
            stop = self.size
        if start >= self.size or start >= stop:
            return b""
        return self.fetcher(start, stop)

    def _reset_stats(self) -> None:
        """Reset hit and miss counts for a more granular report, e.g. by file."""
        self.hit_count = 0
        self.miss_count = 0
        self.total_requested_bytes = 0

    def _log_stats(self) -> str:
        """Return a formatted string of the cache statistics."""
        if self.hit_count == 0 and self.miss_count == 0:
            # a cache that does nothing, this is for logs only
            return ""
        return f" , {self.name}: {self.hit_count} hits, {self.miss_count} misses, {self.total_requested_bytes} total requested bytes"

    def __repr__(self) -> str:
        # TODO: use rich for better formatting
        return f"""
        <{self.__class__.__name__}:
            block size  : {self.blocksize}
            block count : {self.nblocks}
            file size   : {self.size}
            cache hits  : {self.hit_count}
            cache misses: {self.miss_count}
            total requested bytes: {self.total_requested_bytes}>
        """
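

# Usage sketch (not part of the original module; the in-memory ``data``
# bytes stand in for remote storage). BaseCache is pass-through: every
# read calls the fetcher, and ``None`` bounds default to the whole file.
def _example_base_cache() -> None:
    data = b"0123456789" * 10
    cache = BaseCache(blocksize=20, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(5, 15) == data[5:15]
    assert cache._fetch(None, None) == data
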
class MMapCache(BaseCache):
    """memory-mapped sparse file cache

    Opens temporary file, which is filled block-wise when data is requested.
    Ensure there is enough disk space in the temporary location.

    This cache method may only work on POSIX systems.

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: Fetcher
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    location: str
        Where to create the temporary file. If None, a temporary file is
        created using tempfile.TemporaryFile().
    blocks: set[int]
        Set of block numbers that have already been fetched. If None, an empty
        set is created.
    multi_fetcher: MultiFetcher
        Function of the form f([(start, end)]) which gets bytes from remote
        as specified. This function is used to fetch multiple blocks at once.
        If not specified, the fetcher function is used instead.
    """

    name = "mmap"

    def __init__(
        self,
        blocksize: int,
        fetcher: Fetcher,
        size: int,
        location: str | None = None,
        blocks: set[int] | None = None,
        multi_fetcher: MultiFetcher | None = None,
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.blocks = set() if blocks is None else blocks
        self.location = location
        self.multi_fetcher = multi_fetcher
        self.cache = self._makefile()

    def _makefile(self) -> mmap.mmap | bytearray:
        import mmap
        import tempfile

        if self.size == 0:
            return bytearray()

        # posix version
        if self.location is None or not os.path.exists(self.location):
            if self.location is None:
                fd = tempfile.TemporaryFile()
                self.blocks = set()
            else:
                fd = open(self.location, "wb+")
            fd.seek(self.size - 1)
            fd.write(b"1")
            fd.flush()
        else:
            fd = open(self.location, "r+b")

        return mmap.mmap(fd.fileno(), self.size)

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        logger.debug(f"MMap cache fetching {start}-{end}")
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        start_block = start // self.blocksize
        end_block = end // self.blocksize
        block_range = range(start_block, end_block + 1)
        # Determine which blocks need to be fetched. This sequence is sorted by construction.
        need = (i for i in block_range if i not in self.blocks)
        # Count the number of blocks already cached
        self.hit_count += sum(1 for i in block_range if i in self.blocks)

        ranges = []

        # Consolidate needed blocks.
        # Algorithm adapted from Python 2.x itertools documentation.
        # We group an enumerated sequence of needed blocks; the key computes
        # the difference between the ascending index (provided by enumerate)
        # and the block number. Whenever the block numbers skip a value
        # (because some block was already cached), that difference changes
        # and a new group is started. In other words, this neatly groups
        # runs of consecutive block numbers so they can be fetched together.
        for _, _blocks in groupby(enumerate(need), key=lambda x: x[0] - x[1]):
            # Extract the blocks from the enumerated sequence
            _blocks = tuple(map(itemgetter(1), _blocks))

            # Compute start of first block
            sstart = _blocks[0] * self.blocksize
            # Compute the end of the last block. Last block may not be full size.
            send = min(_blocks[-1] * self.blocksize + self.blocksize, self.size)

            # Fetch bytes (could be multiple consecutive blocks)
            self.total_requested_bytes += send - sstart
            logger.debug(
                f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
            )
            ranges.append((sstart, send))

            # Update set of cached blocks
            self.blocks.update(_blocks)
            # Update cache statistics with number of blocks we had to cache
            self.miss_count += len(_blocks)

        if not ranges:
            return self.cache[start:end]

        if self.multi_fetcher:
            logger.debug(f"MMap get blocks {ranges}")
            for idx, r in enumerate(self.multi_fetcher(ranges)):
                (sstart, send) = ranges[idx]
                logger.debug(f"MMap copy block ({sstart}-{send})")
                self.cache[sstart:send] = r
        else:
            for sstart, send in ranges:
                logger.debug(f"MMap get block ({sstart}-{send})")
                self.cache[sstart:send] = self.fetcher(sstart, send)

        return self.cache[start:end]

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        del state["cache"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        # Restore instance attributes
        self.__dict__.update(state)
        self.cache = self._makefile()
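

# Usage sketch (not part of the original module; uses an in-memory fetcher
# and the default anonymous temporary file). Whole blocks land in the mmap
# on first access; a repeated read is then served from it and counted as hits.
def _example_mmap_cache() -> None:
    data = bytes(range(256)) * 16  # 4096 bytes
    cache = MMapCache(blocksize=512, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(100, 700) == data[100:700]  # fetches blocks 0 and 1
    assert cache._fetch(100, 700) == data[100:700]  # served from the mmap
    assert cache.miss_count == 2 and cache.hit_count == 2
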
class ReadAheadCache(BaseCache):
    """Cache which reads only when we get beyond a block of data

    This is a much simpler version of BytesCache, and does not attempt to
    fill holes in the cache or keep fragments alive. It is best suited to
    many small reads in a sequential order (e.g., reading lines from a file).
    """

    name = "readahead"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start = 0
        self.end = 0

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None or end > self.size:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        l = end - start
        if start >= self.start and end <= self.end:
            # cache hit
            self.hit_count += 1
            return self.cache[start - self.start : end - self.start]
        elif self.start <= start < self.end:
            # partial hit
            self.miss_count += 1
            part = self.cache[start - self.start :]
            l -= len(part)
            start = self.end
        else:
            # miss
            self.miss_count += 1
            part = b""
        end = min(self.size, end + self.blocksize)
        self.total_requested_bytes += end - start
        self.cache = self.fetcher(start, end)  # new block replaces old
        self.start = start
        self.end = self.start + len(self.cache)
        return part + self.cache[:l]
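

# Usage sketch (not part of the original module). Sequential reads inside
# the read-ahead window are served from the single buffered block.
def _example_readahead_cache() -> None:
    data = b"abcdefghij" * 100  # 1000 bytes
    cache = ReadAheadCache(blocksize=200, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 50) == data[:50]  # miss: fetches bytes 0-250
    assert cache._fetch(50, 100) == data[50:100]  # hit: inside the window
    assert cache.hit_count == 1 and cache.miss_count == 1
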
class FirstChunkCache(BaseCache):
    """Caches the first block of a file only

    This may be useful for file types where the metadata is stored in the header,
    but is randomly accessed.
    """

    name = "first"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        if blocksize > size:
            # this will buffer the whole thing
            blocksize = size
        super().__init__(blocksize, fetcher, size)
        self.cache: bytes | None = None

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        start = start or 0
        if start > self.size:
            logger.debug("FirstChunkCache: requested start > file size")
            return b""
        # guard against end=None before clamping to the file size
        end = self.size if end is None else min(end, self.size)
        if start < self.blocksize:
            if self.cache is None:
                self.miss_count += 1
                if end > self.blocksize:
                    self.total_requested_bytes += end
                    data = self.fetcher(0, end)
                    self.cache = data[: self.blocksize]
                    return data[start:]
                self.cache = self.fetcher(0, self.blocksize)
                self.total_requested_bytes += self.blocksize
            part = self.cache[start:end]
            if end > self.blocksize:
                self.total_requested_bytes += end - self.blocksize
                part += self.fetcher(self.blocksize, end)
            self.hit_count += 1
            return part
        else:
            self.miss_count += 1
            self.total_requested_bytes += end - start
            return self.fetcher(start, end)
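

# Usage sketch (not part of the original module). Only the first blocksize
# bytes are retained, so repeated header reads are cheap while reads past
# the chunk always go to the fetcher.
def _example_first_chunk_cache() -> None:
    data = b"HEADER" + b"x" * 994  # 1000 bytes
    cache = FirstChunkCache(blocksize=100, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 6) == b"HEADER"  # fills the first-chunk cache
    assert cache._fetch(0, 6) == b"HEADER"  # served from the cached chunk
    assert cache._fetch(500, 510) == data[500:510]  # beyond the chunk: fetched
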
class BlockCache(BaseCache):
    """
    Cache holding memory as a set of blocks.

    Requests are only ever made ``blocksize`` at a time, and are
    stored in an LRU cache. The least recently accessed block is
    discarded when more than ``maxblocks`` are stored.

    Parameters
    ----------
    blocksize : int
        The number of bytes to store in each block.
        Requests are only ever made for ``blocksize``, so this
        should balance the overhead of making a request against
        the granularity of the blocks.
    fetcher : Callable
    size : int
        The total size of the file being cached.
    maxblocks : int
        The maximum number of blocks to cache for. The maximum memory
        use for this cache is then ``blocksize * maxblocks``.
    """

    name = "blockcache"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.nblocks = math.ceil(size / blocksize)
        self.maxblocks = maxblocks
        self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)

    def cache_info(self):
        """
        The statistics on the block cache.

        Returns
        -------
        NamedTuple
            Returned directly from the LRU Cache used internally.
        """
        return self._fetch_block_cached.cache_info()

    def __getstate__(self) -> dict[str, Any]:
        # copy so that pickling does not mutate the live instance
        state = self.__dict__.copy()
        del state["_fetch_block_cached"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._fetch_block_cached = functools.lru_cache(state["maxblocks"])(
            self._fetch_block
        )

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        # byte position -> block numbers
        start_block_number = start // self.blocksize
        end_block_number = end // self.blocksize

        # these are cached, so safe to do multiple calls for the same start and end.
        for block_number in range(start_block_number, end_block_number + 1):
            self._fetch_block_cached(block_number)

        return self._read_cache(
            start,
            end,
            start_block_number=start_block_number,
            end_block_number=end_block_number,
        )

    def _fetch_block(self, block_number: int) -> bytes:
        """
        Fetch the block of data for `block_number`.
        """
        if block_number > self.nblocks:
            raise ValueError(
                f"'block_number={block_number}' is greater than "
                f"the number of blocks ({self.nblocks})"
            )

        start = block_number * self.blocksize
        end = start + self.blocksize
        self.total_requested_bytes += end - start
        self.miss_count += 1
        logger.info("BlockCache fetching block %d", block_number)
        block_contents = super()._fetch(start, end)
        return block_contents

    def _read_cache(
        self, start: int, end: int, start_block_number: int, end_block_number: int
    ) -> bytes:
        """
        Read from our block cache.

        Parameters
        ----------
        start, end : int
            The start and end byte positions.
        start_block_number, end_block_number : int
            The start and end block numbers.
        """
        start_pos = start % self.blocksize
        end_pos = end % self.blocksize

        self.hit_count += 1
        if start_block_number == end_block_number:
            block: bytes = self._fetch_block_cached(start_block_number)
            return block[start_pos:end_pos]
        else:
            # read from the initial
            out = [self._fetch_block_cached(start_block_number)[start_pos:]]

            # intermediate blocks
            # Note: it'd be nice to combine these into one big request. However
            # that doesn't play nicely with our LRU cache.
            out.extend(
                map(
                    self._fetch_block_cached,
                    range(start_block_number + 1, end_block_number),
                )
            )

            # final block
            out.append(self._fetch_block_cached(end_block_number)[:end_pos])

            return b"".join(out)
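

# Usage sketch (not part of the original module). Reads are rounded out to
# whole blocks, each fetched at most once and kept in an LRU of ``maxblocks``.
def _example_block_cache() -> None:
    data = bytes(1024)
    cache = BlockCache(blocksize=128, fetcher=lambda s, e: data[s:e], size=1024, maxblocks=4)
    assert cache._fetch(100, 300) == data[100:300]  # touches blocks 0, 1 and 2
    info = cache.cache_info()  # functools.lru_cache statistics
    assert info.misses == 3 and info.currsize == 3
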
class BytesCache(BaseCache):
    """Cache which holds data in an in-memory bytes object

    Implements read-ahead by the block size, for semi-random reads progressing
    through the file.

    Parameters
    ----------
    trim: bool
        As we read more data, whether to discard the start of the buffer when
        we are more than a blocksize ahead of it.
    """

    name: ClassVar[str] = "bytes"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, trim: bool = True
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start: int | None = None
        self.end: int | None = None
        self.trim = trim

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        # TODO: only set start/end after fetch, in case it fails?
        # is this where retry logic might go?
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        if (
            self.start is not None
            and start >= self.start
            and self.end is not None
            and end < self.end
        ):
            # cache hit: we have all the required data
            offset = start - self.start
            self.hit_count += 1
            return self.cache[offset : offset + end - start]

        if self.blocksize:
            bend = min(self.size, end + self.blocksize)
        else:
            bend = end

        if bend == start or start > self.size:
            return b""

        if (self.start is None or start < self.start) and (
            self.end is None or end > self.end
        ):
            # First read, or extending both before and after
            self.total_requested_bytes += bend - start
            self.miss_count += 1
            self.cache = self.fetcher(start, bend)
            self.start = start
        else:
            assert self.start is not None
            assert self.end is not None
            self.miss_count += 1

            if start < self.start:
                if self.end is None or self.end - end > self.blocksize:
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    self.total_requested_bytes += self.start - start
                    new = self.fetcher(start, self.start)
                    self.start = start
                    self.cache = new + self.cache
            elif self.end is not None and bend > self.end:
                if self.end > self.size:
                    pass
                elif end - self.end > self.blocksize:
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    self.total_requested_bytes += bend - self.end
                    new = self.fetcher(self.end, bend)
                    self.cache = self.cache + new

        self.end = self.start + len(self.cache)
        offset = start - self.start
        out = self.cache[offset : offset + end - start]
        if self.trim:
            num = (self.end - self.start) // (self.blocksize + 1)
            if num > 1:
                self.start += self.blocksize * num
                self.cache = self.cache[self.blocksize * num :]
        return out

    def __len__(self) -> int:
        return len(self.cache)
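

# Usage sketch (not part of the original module). The first read pulls an
# extra blocksize of read-ahead into the buffer, so a following read that
# stays inside the buffer is a hit.
def _example_bytes_cache() -> None:
    data = bytes(range(200)) * 5  # 1000 bytes
    cache = BytesCache(blocksize=100, fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(0, 50) == data[:50]  # miss: fetches bytes 0-150
    assert cache._fetch(50, 120) == data[50:120]  # hit: within the buffer
    assert cache.hit_count == 1 and cache.miss_count == 1
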
class AllBytes(BaseCache):
    """Cache entire contents of the file"""

    name: ClassVar[str] = "all"

    def __init__(
        self,
        blocksize: int | None = None,
        fetcher: Fetcher | None = None,
        size: int | None = None,
        data: bytes | None = None,
    ) -> None:
        super().__init__(blocksize, fetcher, size)  # type: ignore[arg-type]
        if data is None:
            self.miss_count += 1
            self.total_requested_bytes += self.size
            data = self.fetcher(0, self.size)
        self.data = data

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        self.hit_count += 1
        return self.data[start:stop]
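

# Usage sketch (not part of the original module). The whole file is fetched
# once, at construction time; every subsequent read is a slice of memory.
def _example_all_bytes() -> None:
    data = b"entire file contents"
    cache = AllBytes(fetcher=lambda s, e: data[s:e], size=len(data))
    assert cache._fetch(7, 11) == data[7:11]
    assert cache.miss_count == 1  # only the construction-time fetch
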
class KnownPartsOfAFile(BaseCache):
    """
    Cache holding known file parts.

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: func
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    data: dict
        A dictionary mapping explicit `(start, stop)` file-offset tuples
        to known bytes.
    strict: bool, default True
        Whether to fetch reads that go beyond a known byte-range boundary.
        If `False`, any read that ends outside a known part will be zero
        padded. Note that zero padding will not be used for reads that
        begin outside a known byte-range.
    """

    name: ClassVar[str] = "parts"

    def __init__(
        self,
        blocksize: int,
        fetcher: Fetcher,
        size: int,
        data: Optional[dict[tuple[int, int], bytes]] = None,
        strict: bool = True,
        **_: Any,
    ):
        super().__init__(blocksize, fetcher, size)
        self.strict = strict

        # simple consolidation of contiguous blocks
        if data:
            old_offsets = sorted(data.keys())
            offsets = [old_offsets[0]]
            blocks = [data.pop(old_offsets[0])]
            for start, stop in old_offsets[1:]:
                start0, stop0 = offsets[-1]
                if start == stop0:
                    offsets[-1] = (start0, stop)
                    blocks[-1] += data.pop((start, stop))
                else:
                    offsets.append((start, stop))
                    blocks.append(data.pop((start, stop)))

            self.data = dict(zip(offsets, blocks))
        else:
            self.data = {}

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        if start is None:
            start = 0
        if stop is None:
            stop = self.size

        out = b""
        for (loc0, loc1), data in self.data.items():
            # If self.strict=False, use zero-padded data
            # for reads beyond the end of a "known" buffer
            if loc0 <= start < loc1:
                off = start - loc0
                out = data[off : off + stop - start]
                if not self.strict or loc0 <= stop <= loc1:
                    # The request is within a known range, or
                    # it begins within a known range, and we
                    # are allowed to pad reads beyond the
                    # buffer with zero
                    out += b"\x00" * (stop - start - len(out))
                    self.hit_count += 1
                    return out
                else:
                    # The request ends outside a known range,
                    # and we are being "strict" about reads
                    # beyond the buffer
                    start = loc1
                    break

        # We only get here if there is a request outside the
        # known parts of the file. In an ideal world, this
        # should never happen
        if self.fetcher is None:
            # We cannot fetch the data, so raise an error
            raise ValueError(f"Read is outside the known file parts: {(start, stop)}. ")

        # We can fetch the data, but should warn the user
        # that this may be slow
        warnings.warn(
            f"Read is outside the known file parts: {(start, stop)}. "
            f"IO/caching performance may be poor!"
        )
        logger.debug(f"KnownPartsOfAFile cache fetching {start}-{stop}")
        self.total_requested_bytes += stop - start
        self.miss_count += 1
        return out + super()._fetch(start, stop)
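

# Usage sketch (not part of the original module). Contiguous parts are
# consolidated at construction; reads inside a known range never touch
# the fetcher, which may therefore be None.
def _example_known_parts() -> None:
    known = {(0, 5): b"01234", (5, 10): b"56789"}  # contiguous: merged
    cache = KnownPartsOfAFile(blocksize=0, fetcher=None, size=100, data=known)
    assert list(cache.data) == [(0, 10)]
    assert cache._fetch(2, 8) == b"234567"
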
class UpdatableLRU(Generic[P, T]):
    """
    Custom implementation of LRU cache that allows updating keys

    Used by BackgroundBlockCache
    """

    class CacheInfo(NamedTuple):
        hits: int
        misses: int
        maxsize: int
        currsize: int

    def __init__(self, func: Callable[P, T], max_size: int = 128) -> None:
        self._cache: OrderedDict[Any, T] = collections.OrderedDict()
        self._func = func
        self._max_size = max_size
        self._hits = 0
        self._misses = 0
        self._lock = threading.Lock()

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        if kwargs:
            raise TypeError(f"Got unexpected keyword argument {kwargs.keys()}")
        with self._lock:
            if args in self._cache:
                self._cache.move_to_end(args)
                self._hits += 1
                return self._cache[args]

        result = self._func(*args, **kwargs)

        with self._lock:
            self._cache[args] = result
            self._misses += 1
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)

        return result

    def is_key_cached(self, *args: Any) -> bool:
        with self._lock:
            return args in self._cache

    def add_key(self, result: T, *args: Any) -> None:
        with self._lock:
            self._cache[args] = result
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)

    def cache_info(self) -> UpdatableLRU.CacheInfo:
        with self._lock:
            return self.CacheInfo(
                maxsize=self._max_size,
                currsize=len(self._cache),
                hits=self._hits,
                misses=self._misses,
            )
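

# Usage sketch (not part of the original module). Unlike functools.lru_cache,
# a value can be inserted or overwritten from outside via add_key, which is
# how BackgroundBlockCache publishes blocks fetched in the background.
def _example_updatable_lru() -> None:
    lru = UpdatableLRU(lambda n: n * n, max_size=2)
    assert lru(3) == 9  # computed by the wrapped function: one miss
    lru.add_key(100, 3)  # overwrite the cached value for key (3,)
    assert lru(3) == 100  # now served from the cache: one hit
    assert lru.cache_info().hits == 1
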
class BackgroundBlockCache(BaseCache):
    """
    Cache holding memory as a set of blocks with pre-loading of
    the next block in the background.

    Requests are only ever made ``blocksize`` at a time, and are
    stored in an LRU cache. The least recently accessed block is
    discarded when more than ``maxblocks`` are stored. If the
    next block is not in cache, it is loaded in a separate thread
    in a non-blocking way.

    Parameters
    ----------
    blocksize : int
        The number of bytes to store in each block.
        Requests are only ever made for ``blocksize``, so this
        should balance the overhead of making a request against
        the granularity of the blocks.
    fetcher : Callable
    size : int
        The total size of the file being cached.
    maxblocks : int
        The maximum number of blocks to cache for. The maximum memory
        use for this cache is then ``blocksize * maxblocks``.
    """

    name: ClassVar[str] = "background"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.nblocks = math.ceil(size / blocksize)
        self.maxblocks = maxblocks
        self._fetch_block_cached = UpdatableLRU(self._fetch_block, maxblocks)

        self._thread_executor = ThreadPoolExecutor(max_workers=1)
        self._fetch_future_block_number: int | None = None
        self._fetch_future: Future[bytes] | None = None
        self._fetch_future_lock = threading.Lock()

    def cache_info(self) -> UpdatableLRU.CacheInfo:
        """
        The statistics on the block cache.

        Returns
        -------
        NamedTuple
            Returned directly from the LRU Cache used internally.
        """
        return self._fetch_block_cached.cache_info()

    def __getstate__(self) -> dict[str, Any]:
        # copy so that pickling does not mutate the live instance
        state = self.__dict__.copy()
        del state["_fetch_block_cached"]
        del state["_thread_executor"]
        del state["_fetch_future_block_number"]
        del state["_fetch_future"]
        del state["_fetch_future_lock"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._fetch_block_cached = UpdatableLRU(self._fetch_block, state["maxblocks"])
        self._thread_executor = ThreadPoolExecutor(max_workers=1)
        self._fetch_future_block_number = None
        self._fetch_future = None
        self._fetch_future_lock = threading.Lock()

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        # byte position -> block numbers
        start_block_number = start // self.blocksize
        end_block_number = end // self.blocksize

        fetch_future_block_number = None
        fetch_future = None
        with self._fetch_future_lock:
            # Background thread is running. Check whether we can or must join it.
            if self._fetch_future is not None:
                assert self._fetch_future_block_number is not None
                if self._fetch_future.done():
                    logger.info("BlockCache joined background fetch without waiting.")
                    self._fetch_block_cached.add_key(
                        self._fetch_future.result(), self._fetch_future_block_number
                    )

                    # Cleanup the fetch variables. Done with fetching the block.
                    self._fetch_future_block_number = None
                    self._fetch_future = None
                else:
                    # Must join if we need the block for the current fetch
                    must_join = bool(
                        start_block_number
                        <= self._fetch_future_block_number
                        <= end_block_number
                    )
                    if must_join:
                        # Copy to the local variables to release lock
                        # before waiting for result
                        fetch_future_block_number = self._fetch_future_block_number
                        fetch_future = self._fetch_future

                        # Cleanup the fetch variables. Have a local copy.
                        self._fetch_future_block_number = None
                        self._fetch_future = None

        # Need to wait for the future for the current read
        if fetch_future is not None:
            logger.info("BlockCache waiting for background fetch.")
            # Wait until result and put it in cache
            self._fetch_block_cached.add_key(
                fetch_future.result(), fetch_future_block_number
            )

        # these are cached, so safe to do multiple calls for the same start and end.
        for block_number in range(start_block_number, end_block_number + 1):
            self._fetch_block_cached(block_number)

        # fetch next block in the background if nothing is running in the background,
        # the block is within file and it is not already cached
        end_block_plus_1 = end_block_number + 1
        with self._fetch_future_lock:
            if (
                self._fetch_future is None
                and end_block_plus_1 <= self.nblocks
                and not self._fetch_block_cached.is_key_cached(end_block_plus_1)
            ):
                self._fetch_future_block_number = end_block_plus_1
                self._fetch_future = self._thread_executor.submit(
                    self._fetch_block, end_block_plus_1, "async"
                )

        return self._read_cache(
            start,
            end,
            start_block_number=start_block_number,
            end_block_number=end_block_number,
        )

    def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
        """
        Fetch the block of data for `block_number`.
        """
        if block_number > self.nblocks:
            raise ValueError(
                f"'block_number={block_number}' is greater than "
                f"the number of blocks ({self.nblocks})"
            )

        start = block_number * self.blocksize
        end = start + self.blocksize
        logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
        self.total_requested_bytes += end - start
        self.miss_count += 1
        block_contents = super()._fetch(start, end)
        return block_contents

    def _read_cache(
        self, start: int, end: int, start_block_number: int, end_block_number: int
    ) -> bytes:
        """
        Read from our block cache.

        Parameters
        ----------
        start, end : int
            The start and end byte positions.
        start_block_number, end_block_number : int
            The start and end block numbers.
        """
        start_pos = start % self.blocksize
        end_pos = end % self.blocksize

        # kind of pointless to count this as a hit, but it is
        self.hit_count += 1

        if start_block_number == end_block_number:
            block = self._fetch_block_cached(start_block_number)
            return block[start_pos:end_pos]
        else:
            # read from the initial
            out = [self._fetch_block_cached(start_block_number)[start_pos:]]

            # intermediate blocks
            # Note: it'd be nice to combine these into one big request. However
            # that doesn't play nicely with our LRU cache.
            out.extend(
                map(
                    self._fetch_block_cached,
                    range(start_block_number + 1, end_block_number),
                )
            )

            # final block
            out.append(self._fetch_block_cached(end_block_number)[:end_pos])

            return b"".join(out)
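

# Usage sketch (not part of the original module). Behaves like BlockCache,
# but after each read it prefetches the next block on a worker thread, so
# a sequential reader rarely blocks on the fetcher.
def _example_background_block_cache() -> None:
    data = bytes(1024)
    cache = BackgroundBlockCache(blocksize=128, fetcher=lambda s, e: data[s:e], size=1024)
    assert cache._fetch(0, 100) == data[:100]  # block 0 cached; block 1 prefetching
    assert cache._fetch(128, 200) == data[128:200]  # uses/joins the prefetched block
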
caches: dict[str | None, type[BaseCache]] = {
    # one custom case
    None: BaseCache,
}


def register_cache(cls: type[BaseCache], clobber: bool = False) -> None:
    """'Register' cache implementation.

    Parameters
    ----------
    clobber: bool, optional
        If set to True (default is False), allow overwriting an existing
        entry.

    Raises
    ------
    ValueError
    """
    name = cls.name
    if not clobber and name in caches:
        raise ValueError(f"Cache with name {name!r} is already known: {caches[name]}")
    caches[name] = cls


for c in (
    BaseCache,
    MMapCache,
    BytesCache,
    ReadAheadCache,
    BlockCache,
    FirstChunkCache,
    AllBytes,
    KnownPartsOfAFile,
    BackgroundBlockCache,
):
    register_cache(c)
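

# Usage sketch (not part of the original module). Third-party cache
# implementations plug into the same registry by name; pass clobber=True
# to register_cache to replace an existing entry.
def _example_register_cache() -> None:
    class ExampleNoopCache(BaseCache):
        """Pass-through cache registered under a hypothetical custom name."""

        name = "example-noop"

    register_cache(ExampleNoopCache)
    assert caches["example-noop"] is ExampleNoopCache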