From 40429b4c325b904835d8b2507e4e9667454df2be Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 29 Mar 2024 11:26:25 +0100
Subject: [PATCH] Extended zarr compression

---
 src/spikeinterface/core/base.py               | 17 ++++-
 .../core/tests/test_zarrextractors.py         | 65 ++++++++++++++-----
 src/spikeinterface/core/zarrextractors.py     | 30 +++++++--
 3 files changed, 89 insertions(+), 23 deletions(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 80341811b9..d25f1bf97b 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -967,9 +967,24 @@ def save_to_zarr(
             For cloud storage locations, this should not be None (in case of default values, use an empty dict)
         channel_chunk_size: int or None, default: None
             Channels per chunk (only for BaseRecording)
+        compressor: numcodecs.Codec or None, default: None
+            Global compressor. If None, Blosc-zstd level 5 is used.
+        filters: list[numcodecs.Codec] or None, default: None
+            Global filters for all zarr datasets
+        compressor_by_dataset: dict or None, default: None
+            Optional compressor per dataset:
+                - traces
+                - times
+            If None, the global compressor is used
+        filters_by_dataset: dict or None, default: None
+            Optional filters per dataset:
+                - traces
+                - times
+            If None, the global filters are used
         verbose: bool, default: True
             If True, the output is verbose
-        **save_kwargs: Keyword arguments for saving to zarr
+        auto_cast_uint: bool, default: True
+            If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording)

     Returns
     -------
diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py
index 72247cb42a..2fc1f42ec5 100644
--- a/src/spikeinterface/core/tests/test_zarrextractors.py
+++ b/src/spikeinterface/core/tests/test_zarrextractors.py
@@ -1,39 +1,72 @@
 import pytest
 from pathlib import Path

-import shutil
-
 import zarr

 from spikeinterface.core import (
     ZarrRecordingExtractor,
     ZarrSortingExtractor,
+    generate_recording,
     generate_sorting,
     load_extractor,
 )
-from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group
+from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor
+
+
+def test_zarr_compression_options(tmp_path):
+    from numcodecs import Blosc, Delta, FixedScaleOffset
+
+    recording = generate_recording(durations=[2])
+    recording.set_times(recording.get_times() + 100)
+
+    # compressors and filters used below
+    # default compressor
+    default_compressor = get_default_zarr_compressor()
+
+    # other compressors
+    other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE)
+    other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE)
+
+    # timestamps compressors / filters
+    default_filters = None
+    other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())]
+    other_filters2 = [Delta(dtype="float64")]
+
+    # default
+    ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr")
+    rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr")
+    assert rec_default._root["traces_seg0"].compressor == default_compressor
+    assert rec_default._root["traces_seg0"].filters == default_filters
+    assert rec_default._root["times_seg0"].compressor == default_compressor
+    assert rec_default._root["times_seg0"].filters == default_filters

-if hasattr(pytest, "global_test_folder"):
-    cache_folder = pytest.global_test_folder / "core"
-else:
-    cache_folder = Path("cache_folder") / "core"
"core" + # now with other compressor + ZarrRecordingExtractor.write_recording( + recording, + tmp_path / "rec_other.zarr", + compressor=defaut_compressor, + filters=default_filters, + compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2}, + filters_by_dataset={"traces": other_filters1, "times": other_filters2}, + ) + rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr") + assert rec_other._root["traces_seg0"].compressor == other_compressor1 + assert rec_other._root["traces_seg0"].filters == other_filters1 + assert rec_other._root["times_seg0"].compressor == other_compressor2 + assert rec_other._root["times_seg0"].filters == other_filters2 -def test_ZarrSortingExtractor(): +def test_ZarrSortingExtractor(tmp_path): np_sorting = generate_sorting() # store in root standard normal way - folder = cache_folder / "zarr_sorting" - if folder.is_dir(): - shutil.rmtree(folder) + folder = tmp_path / "zarr_sorting" ZarrSortingExtractor.write_sorting(np_sorting, folder) sorting = ZarrSortingExtractor(folder) sorting = load_extractor(sorting.to_dict()) # store the sorting in a sub group (for instance SortingResult) - folder = cache_folder / "zarr_sorting_sub_group" - if folder.is_dir(): - shutil.rmtree(folder) + folder = tmp_path / "zarr_sorting_sub_group" zarr_root = zarr.open(folder, mode="w") zarr_sorting_group = zarr_root.create_group("sorting") add_sorting_to_zarr_group(sorting, zarr_sorting_group) @@ -43,4 +76,6 @@ def test_ZarrSortingExtractor(): if __name__ == "__main__": - test_ZarrSortingExtractor() + tmp_path = Path("tmp") + test_zarr_compression_options(tmp_path) + test_ZarrSortingExtractor(tmp_path) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index a8a23b5863..47e2ea2849 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -366,7 +366,9 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G # Recording -def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hierarchy.Group, **kwargs): +def add_recording_to_zarr_group( + recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, **kwargs +): zarr_kwargs, job_kwargs = split_job_kwargs(kwargs) if recording.check_if_json_serializable(): @@ -380,15 +382,25 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None) dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] - zarr_kwargs["dtype"] = kwargs.get("dtype", None) or recording.get_dtype() - if "compressor" not in zarr_kwargs: - zarr_kwargs["compressor"] = get_default_zarr_compressor() + dtype = zarr_kwargs.get("dtype", None) or recording.get_dtype() + channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) + global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor()) + compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {}) + global_filters = zarr_kwargs.pop("filters", None) + filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {}) + compressor_traces = compressor_by_dataset.get("traces", global_compressor) + filters_traces = filters_by_dataset.get("traces", global_filters) add_traces_to_zarr( recording=recording, zarr_group=zarr_group, dataset_paths=dataset_paths, - **zarr_kwargs, + compressor=compressor_traces, + filters=filters_traces, + dtype=dtype, + 
+        auto_cast_uint=auto_cast_uint,
+        verbose=verbose,
         **job_kwargs,
     )

@@ -402,12 +414,16 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
     for segment_index, rs in enumerate(recording._recording_segments):
         d = rs.get_times_kwargs()
         time_vector = d["time_vector"]
+
+        compressor_times = compressor_by_dataset.get("times", global_compressor)
+        filters_times = filters_by_dataset.get("times", global_filters)
+
         if time_vector is not None:
             _ = zarr_group.create_dataset(
                 name=f"times_seg{segment_index}",
                 data=time_vector,
-                filters=filters_times,
-                compressor=zarr_kwargs["compressor"],
+                filters=filters_times,
+                compressor=compressor_times,
             )
         elif d["t_start"] is not None:
             t_starts[segment_index] = d["t_start"]
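
Usage note: a minimal sketch of how the new options combine from the user side,
mirroring test_zarr_compression_options above. The folder name and codec choices
are illustrative, and save_to_zarr is assumed to forward these keyword arguments
as documented in the base.py hunk:

    from numcodecs import Blosc, Delta
    from spikeinterface.core import generate_recording

    recording = generate_recording(durations=[2])
    recording.set_times(recording.get_times() + 100)  # ensure a times_seg0 dataset is written

    recording.save_to_zarr(
        folder="rec.zarr",  # illustrative output path
        # global compressor; if omitted, Blosc-zstd level 5 is used
        compressor=Blosc(cname="zstd", clevel=5),
        # per-dataset overrides take precedence over the global settings;
        # traces fall back to the global compressor and filters (None)
        compressor_by_dataset={"times": Blosc(cname="blosclz", clevel=8)},
        filters_by_dataset={"times": [Delta(dtype="float64")]},
    )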