Skip to content

Commit

Permalink
Call cat_ranges in blockcache for async filesystems
Browse files Browse the repository at this point in the history
  • Loading branch information
monken committed Sep 1, 2023
1 parent 2fbe8de commit 46ea642
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 6 deletions.
45 changes: 41 additions & 4 deletions fsspec/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,37 @@ class MMapCache(BaseCache):
Ensure there is enough disc space in the temporary location.
This cache method might only work on posix
Parameters
----------
blocksize: int
How far to read ahead, in number of bytes
fetcher: func
Function of the form f(start, end) which gets bytes from remote as
specified
size: int
How big this file is
location: str
Where to create the temporary file. If None, a temporary file is
created using tempfile.TemporaryFile().
blocks: set
Set of block numbers that have already been fetched. If None, an empty
set is created.
multi_fetcher: func
Function of the form f([(start, end)]) which gets bytes from remote
as specified. This function is used to fetch multiple blocks at once.
If not specified, the fetcher function is used instead.
"""

name = "mmap"

def __init__(self, blocksize, fetcher, size, location=None, blocks=None):
def __init__(
    self, blocksize, fetcher, size, location=None, blocks=None, multi_fetcher=None
):
    """Initialise the cache state and create its backing store.

    ``blocksize``, ``fetcher`` and ``size`` are forwarded to the parent
    initialiser. ``blocks`` defaults to a fresh empty set when not given;
    ``location`` and ``multi_fetcher`` are stored unchanged.
    """
    super().__init__(blocksize, fetcher, size)
    if blocks is None:
        blocks = set()
    self.blocks = blocks
    self.location = location
    self.multi_fetcher = multi_fetcher
    # Build the backing store last, once the attributes above are in place.
    self.cache = self._makefile()

def _makefile(self):
Expand Down Expand Up @@ -93,16 +116,30 @@ def _fetch(self, start, end):
start_block = start // self.blocksize
end_block = end // self.blocksize
need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
ranges = []
while need:
# TODO: not a for loop so we can consolidate blocks later to
# make fewer fetch calls; this could be parallel
# make fewer fetch calls
i = need.pop(0)
sstart = i * self.blocksize
send = min(sstart + self.blocksize, self.size)
logger.debug(f"MMap get block #{i} ({sstart}-{send}")
self.cache[sstart:send] = self.fetcher(sstart, send)
ranges.append((sstart, send))
self.blocks.add(i)

if not ranges:
return self.cache[start:end]

if self.multi_fetcher:
logger.debug(f"MMap get blocks {ranges}")
for idx, r in enumerate(self.multi_fetcher(ranges)):
(sstart, send) = ranges[idx]
logger.debug(f"MMap copy block ({sstart}-{send}")
self.cache[sstart:send] = r
else:
for (sstart, send) in ranges:
logger.debug(f"MMap get block ({sstart}-{send}")
self.cache[sstart:send] = self.fetcher(sstart, send)

return self.cache[start:end]

def __getstate__(self):
Expand Down
16 changes: 15 additions & 1 deletion fsspec/implementations/cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,21 @@ def _open(
)
else:
detail["blocksize"] = f.blocksize
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)

def _fetch_ranges(ranges):
    """Fetch several byte ranges of ``path`` with a single ``cat_ranges`` call.

    ``ranges`` is a sequence whose items carry the range start at index 0
    and the range end at index 1. ``path`` is repeated once per range, and
    the enclosing call's ``kwargs`` are forwarded unchanged.
    """
    paths = [path] * len(ranges)
    starts = [rng[0] for rng in ranges]
    ends = [rng[1] for rng in ranges]
    return self.fs.cat_ranges(paths, starts, ends, **kwargs)

multi_fetcher = (
None if not self.fs.async_impl or self.compression else _fetch_ranges
)
f.cache = MMapCache(
f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
)
close = f.close
f.close = lambda: self.close_and_update(f, close)
self.save_cache()
Expand Down
35 changes: 34 additions & 1 deletion fsspec/tests/test_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

import pytest

from fsspec.caching import BlockCache, FirstChunkCache, caches, register_cache
from fsspec.caching import (
BlockCache,
FirstChunkCache,
MMapCache,
caches,
register_cache,
)


def test_cache_getitem(Cache_imp):
Expand Down Expand Up @@ -44,6 +50,11 @@ def letters_fetcher(start, end):
return string.ascii_letters[start:end].encode()


def multi_letters_fetcher(ranges):
    """Return the ASCII-letter bytes for each ``(start, end)`` pair in *ranges*.

    Multi-range counterpart of a single-range letters fetcher: every pair
    is sliced out of ``string.ascii_letters`` and encoded to bytes, and the
    results are returned in the same order as the input ranges.
    """
    # The leftover debugging ``print(ranges)`` was dropped; it only added
    # noise to the test output.
    return [string.ascii_letters[start:end].encode() for start, end in ranges]


not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}


Expand Down Expand Up @@ -82,6 +93,28 @@ def test_first_cache():
assert c._fetch(1, 4) == letters_fetcher(1, 4)


def test_mmap_cache(mocker):
    """Check MMapCache fetch counts with and without a ``multi_fetcher``.

    With a block size of 5, each request only triggers fetches for blocks
    that have not been cached yet; the expected call counts below are
    cumulative.
    """
    fetcher = mocker.Mock(wraps=letters_fetcher)

    # (start, end, cumulative single-range fetcher calls after the request)
    single_steps = [(12, 15, 2), (3, 10, 4), (1, 4, 4)]
    plain_cache = MMapCache(5, fetcher, 52)
    for lo, hi, expected_calls in single_steps:
        assert plain_cache._fetch(lo, hi) == letters_fetcher(lo, hi)
        assert fetcher.call_count == expected_calls

    multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
    # (start, end, cumulative multi_fetcher calls after the request)
    multi_steps = [(12, 15, 1), (3, 10, 2), (1, 4, 2)]
    multi_cache = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
    for lo, hi, expected_calls in multi_steps:
        assert multi_cache._fetch(lo, hi) == letters_fetcher(lo, hi)
        assert multi_fetcher.call_count == expected_calls

    # The single-range fetcher must not have been used by the second cache.
    assert fetcher.call_count == 4


@pytest.mark.parametrize(
"size_requests",
[[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],
Expand Down

0 comments on commit 46ea642

Please sign in to comment.