Skip to content

Commit

Permalink
referencefs: add streaming _open support
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Jan 21, 2025
1 parent 216885a commit 666d9b5
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
83 changes: 79 additions & 4 deletions fsspec/implementations/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TYPE_CHECKING, Literal

import fsspec.core
from fsspec.spec import AbstractBufferedFile

try:
import ujson as json
Expand Down Expand Up @@ -595,8 +596,7 @@ class ReferenceFileSystem(AsyncFileSystem):
async, and must allow start and end args in _cat_file. Later versions
may allow multiple arbitrary URLs for the targets.
This FileSystem is read-only. It is designed to be used with async
targets (for now). This FileSystem only allows whole-file access, no
``open``. We do not get original file details from the target FS.
targets (for now). We do not get original file details from the target FS.
Configuration is by passing a dict of references at init, or a URL to
a JSON file containing the same; this dict
can also contain concrete data for some set of paths.
Expand Down Expand Up @@ -1100,8 +1100,28 @@ def _dircache_from_items(self):
self.dircache[par].append({"name": path, "type": "file", "size": size})

def _open(self, path, mode="rb", block_size=None, cache_options=None, **kwargs):
data = self.cat_file(path) # load whole chunk into memory
return io.BytesIO(data)
part_or_url, start0, end0 = self._cat_common(path)
if isinstance(part_or_url, bytes):
return io.BytesIO(part_or_url[start0:end0])

protocol, _ = split_protocol(part_or_url)
if start0 is None and end0 is None:
return self.fss[protocol]._open(
part_or_url,
mode,
block_size=block_size,
cache_options=cache_options,
**kwargs,
)

return ReferenceFile(
self,
path,
mode,
block_size=block_size,
cache_options=cache_options,
**kwargs,
)

def ls(self, path, detail=True, **kwargs):
path = self._strip_protocol(path)
Expand Down Expand Up @@ -1214,3 +1234,58 @@ def save_json(self, url, **storage_options):
out[k] = v
with fsspec.open(url, "wb", **storage_options) as f:
f.write(json.dumps({"version": 1, "refs": out}).encode())


class ReferenceFile(AbstractBufferedFile):
def __init__(
self,
fs,
path,
mode="rb",
block_size="default",
autocommit=True,
cache_type="readahead",
cache_options=None,
size=None,
**kwargs,
):
super().__init__(
fs,
path,
mode=mode,
block_size=block_size,
autocommit=autocommit,
size=size,
cache_type=cache_type,
cache_options=cache_options,
**kwargs,
)
part_or_url, self.start, self.end = self.fs._cat_common(self.path)
protocol, _ = split_protocol(part_or_url)
self.src_fs = self.fs.fss[protocol]
self.src_path = part_or_url
self._f = None

@property
def f(self):
if self._f is None or self._f.closed:
self._f = self.src_fs._open(
self.src_path,
mode=self.mode,
block_size=self.blocksize,
autocommit=self.autocommit,
cache_type="none",
**self.kwargs,
)
return self._f

def close(self):
if self._f is not None:
self._f.close()
return super().close()

def _fetch_range(self, start, end):
start = start + self.start
end = min(end + self.start, self.end)
self.f.seek(start)
return self.f.read(end - start)
46 changes: 46 additions & 0 deletions fsspec/implementations/tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,52 @@ def test_simple(server):
assert f.read(2) == "he"


def test_open(m):
from fsspec.implementations.reference import json as json_impl

m.pipe("/data/0", data)
refs = {
"a": b"data",
"b": ["memory://data/0"],
"c": ("memory://data/0", 0, 5),
"d": ("memory://data/0", 1, 5),
"e": b"base64:aGVsbG8=",
"f": {"key": "value"},
}
fs = fsspec.filesystem("reference", fo=refs, fs=m)

with fs.open("a", "rb") as f:
assert f.read() == b"data"

with fs.open("b", "rb") as f:
assert f.read() == data

with fs.open("c", "rb") as f:
assert f.read() == data[:5]
assert not f.read()

with fs.open("d", "rb") as f:
assert f.read() == data[1:6]
assert not f.read()

with fs.open("e", "rb") as f:
assert f.read() == b"hello"

with fs.open("f", "rb") as f:
assert f.read() == json_impl.dumps(refs["f"]).encode("utf-8")

# check partial reads
with fs.open("c", "rb") as f:
assert f.read(2) == data[:2]
f.seek(2, os.SEEK_CUR)
assert f.read() == data[4:5]

with fs.open("d", "rb") as f:
assert f.read(2) == data[1:3]
f.seek(1, os.SEEK_CUR)
assert f.read() == data[4:6]


def test_simple_ver1(server):
# The dictionary in refs may be dumped with a different separator
# depending on whether json or ujson is imported
Expand Down

0 comments on commit 666d9b5

Please sign in to comment.