
Extended zarr compression
alejoe91 committed Mar 29, 2024
1 parent 22808ca commit 40429b4
Showing 3 changed files with 89 additions and 23 deletions.
17 changes: 16 additions & 1 deletion src/spikeinterface/core/base.py
@@ -967,9 +967,24 @@ def save_to_zarr(
For cloud storage locations, this should not be None (in case of default values, use an empty dict)
channel_chunk_size: int or None, default: None
Channels per chunk (only for BaseRecording)
compressor: numcodecs.Codec or None, default: None
Global compressor. If None, Blosc-zstd level 5 is used.
filters: list[numcodecs.Codec] or None, default: None
Global filters for zarr
compressor_by_dataset: dict or None, default: None
Optional compressor per dataset:
- traces
- times
If None, the global compressor is used
filters_by_dataset: dict or None, default: None
Optional filters per dataset:
- traces
- times
If None, the global filters are used
verbose: bool, default: True
If True, the output is verbose
**save_kwargs: Keyword arguments for saving to zarr
auto_cast_uint: bool, default: True
If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording)
Returns
-------
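As a usage illustration of the parameters documented above, the following sketch saves a recording with a global compressor plus per-dataset overrides. The `folder` argument and the specific codec choices are assumptions made for the example, not values taken from this commit.

```python
# Hedged usage sketch: the folder path and codec choices are illustrative only.
from numcodecs import Blosc, Delta

from spikeinterface.core import generate_recording

recording = generate_recording(durations=[2])

recording.save_to_zarr(
    folder="rec_custom.zarr",                                        # assumed output location
    compressor=Blosc(cname="zstd", clevel=5),                        # global compressor (the documented default)
    compressor_by_dataset={"times": Blosc(cname="lz4", clevel=3)},   # override for the "times" dataset
    filters_by_dataset={"times": [Delta(dtype="float64")]},          # filters applied only to "times"
)
```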
65 changes: 50 additions & 15 deletions src/spikeinterface/core/tests/test_zarrextractors.py
@@ -1,39 +1,72 @@
import pytest
from pathlib import Path

import shutil

import zarr

from spikeinterface.core import (
ZarrRecordingExtractor,
ZarrSortingExtractor,
generate_recording,
generate_sorting,
load_extractor,
)
from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group
from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor


def test_zarr_compression_options(tmp_path):
from numcodecs import Blosc, Delta, FixedScaleOffset

recording = generate_recording(durations=[2])
recording.set_times(recording.get_times() + 100)

# store in root standard normal way
# default compressor
default_compressor = get_default_zarr_compressor()

# other compressor
other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE)
other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE)

# timestamps compressors / filters
default_filters = None
other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())]
other_filters2 = [Delta(dtype="float64")]

# default
ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr")
rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr")
assert rec_default._root["traces_seg0"].compressor == default_compressor
assert rec_default._root["traces_seg0"].filters == default_filters
assert rec_default._root["times_seg0"].compressor == default_compressor
assert rec_default._root["times_seg0"].filters == default_filters

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"
# now with other compressor
ZarrRecordingExtractor.write_recording(
recording,
tmp_path / "rec_other.zarr",
compressor=default_compressor,
filters=default_filters,
compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2},
filters_by_dataset={"traces": other_filters1, "times": other_filters2},
)
rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr")
assert rec_other._root["traces_seg0"].compressor == other_compressor1
assert rec_other._root["traces_seg0"].filters == other_filters1
assert rec_other._root["times_seg0"].compressor == other_compressor2
assert rec_other._root["times_seg0"].filters == other_filters2


def test_ZarrSortingExtractor():
def test_ZarrSortingExtractor(tmp_path):
np_sorting = generate_sorting()

# store in root standard normal way
folder = cache_folder / "zarr_sorting"
if folder.is_dir():
shutil.rmtree(folder)
folder = tmp_path / "zarr_sorting"
ZarrSortingExtractor.write_sorting(np_sorting, folder)
sorting = ZarrSortingExtractor(folder)
sorting = load_extractor(sorting.to_dict())

# store the sorting in a sub group (for instance SortingResult)
folder = cache_folder / "zarr_sorting_sub_group"
if folder.is_dir():
shutil.rmtree(folder)
folder = tmp_path / "zarr_sorting_sub_group"
zarr_root = zarr.open(folder, mode="w")
zarr_sorting_group = zarr_root.create_group("sorting")
add_sorting_to_zarr_group(sorting, zarr_sorting_group)
@@ -43,4 +76,6 @@ def test_ZarrSortingExtractor():


if __name__ == "__main__":
test_ZarrSortingExtractor()
tmp_path = Path("tmp")
test_zarr_compression_options(tmp_path)
test_ZarrSortingExtractor(tmp_path)
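To cross-check what the test asserts, one can also open the written store with plain zarr and inspect the codecs directly. This is a sketch assuming the `rec_other.zarr` folder created in the test above is still on disk.

```python
# Sketch: inspect per-dataset codecs of a written store with zarr directly.
import zarr

root = zarr.open("rec_other.zarr", mode="r")
print(root["traces_seg0"].compressor)  # expected: the traces override (Blosc zlib)
print(root["traces_seg0"].filters)     # expected: the FixedScaleOffset filter
print(root["times_seg0"].compressor)   # expected: the times override (Blosc blosclz)
print(root["times_seg0"].filters)      # expected: the Delta filter
```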
30 changes: 23 additions & 7 deletions src/spikeinterface/core/zarrextractors.py
@@ -366,7 +366,9 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G


# Recording
def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hierarchy.Group, **kwargs):
def add_recording_to_zarr_group(
recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, **kwargs
):
zarr_kwargs, job_kwargs = split_job_kwargs(kwargs)

if recording.check_if_json_serializable():
@@ -380,15 +382,25 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None)
dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())]

zarr_kwargs["dtype"] = kwargs.get("dtype", None) or recording.get_dtype()
if "compressor" not in zarr_kwargs:
zarr_kwargs["compressor"] = get_default_zarr_compressor()
dtype = zarr_kwargs.get("dtype", None) or recording.get_dtype()
channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None)
global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor())
compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {})
global_filters = zarr_kwargs.pop("filters", None)
filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {})

compressor_traces = compressor_by_dataset.get("traces", global_compressor)
filters_traces = filters_by_dataset.get("traces", global_filters)
add_traces_to_zarr(
recording=recording,
zarr_group=zarr_group,
dataset_paths=dataset_paths,
**zarr_kwargs,
compressor=compressor_traces,
filters=filters_traces,
dtype=dtype,
channel_chunk_size=channel_chunk_size,
auto_cast_uint=auto_cast_uint,
verbose=verbose,
**job_kwargs,
)

@@ -402,12 +414,16 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
for segment_index, rs in enumerate(recording._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]

compressor_times = compressor_by_dataset.get("times", global_compressor)
filters_times = filters_by_dataset.get("times", global_filters)

if time_vector is not None:
_ = zarr_group.create_dataset(
name=f"times_seg{segment_index}",
data=time_vector,
filters=zarr_kwargs.get("filters", None),
compressor=zarr_kwargs["compressor"],
filters=filters_times,
compressor=compressor_times,
)
elif d["t_start"] is not None:
t_starts[segment_index] = d["t_start"]
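The per-dataset resolution above follows a simple dictionary-fallback pattern: each dataset takes its override when one is given, otherwise the global setting applies. A minimal standalone sketch of that logic (the specific codecs are illustrative):

```python
# Minimal sketch of the fallback logic used in add_recording_to_zarr_group.
from numcodecs import Blosc

global_compressor = Blosc(cname="zstd", clevel=5)
compressor_by_dataset = {"times": Blosc(cname="lz4", clevel=3)}  # illustrative override

for dataset in ("traces", "times"):
    compressor = compressor_by_dataset.get(dataset, global_compressor)
    print(dataset, "->", compressor)
```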
