Skip to content

Commit

Permalink
Merge pull request #2403 from samuelgarcia/improve_zar_sorting
Browse files Browse the repository at this point in the history
Improve ZarrSortingExtractor
  • Loading branch information
samuelgarcia authored Jan 12, 2024
2 parents d3a7e6f + 8566cfb commit 522fd2b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
46 changes: 46 additions & 0 deletions src/spikeinterface/core/tests/test_zarrextractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
from pathlib import Path

import shutil

import zarr

from spikeinterface.core import (
ZarrRecordingExtractor,
ZarrSortingExtractor,
generate_sorting,
load_extractor,
)
from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"


def test_ZarrSortingExtractor():
np_sorting = generate_sorting()

# store in root standard normal way
folder = cache_folder / "zarr_sorting"
if folder.is_dir():
shutil.rmtree(folder)
ZarrSortingExtractor.write_sorting(np_sorting, folder)
sorting = ZarrSortingExtractor(folder)
sorting = load_extractor(sorting.to_dict())

# store the sorting in a sub group (for instance SortingResult)
folder = cache_folder / "zarr_sorting_sub_group"
if folder.is_dir():
shutil.rmtree(folder)
zarr_root = zarr.open(folder, mode="w")
zarr_sorting_group = zarr_root.create_group("sorting")
add_sorting_to_zarr_group(sorting, zarr_sorting_group)
sorting = ZarrSortingExtractor(folder, zarr_group="sorting")
# and reaload
sorting = load_extractor(sorting.to_dict())


if __name__ == "__main__":
test_ZarrSortingExtractor()
15 changes: 10 additions & 5 deletions src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,12 @@ class ZarrSortingExtractor(BaseSorting):
Parameters
----------
root_path: str or Path
folder_path: str or Path
Path to the zarr root file
storage_options: dict or None
Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc.
zarr_group: str or None, default: None
Optional zarr group path to load the sorting from. This can be used when the sorting is not stored at the root, but in sub group.
Returns
-------
sorting: ZarrSortingExtractor
Expand All @@ -172,12 +173,16 @@ class ZarrSortingExtractor(BaseSorting):
installation_mesg = ""
name = "zarr"

def __init__(self, folder_path: Path | str, storage_options: dict | None = None):
def __init__(self, folder_path: Path | str, storage_options: dict | None = None, zarr_group: str | None = None):
assert self.installed, self.installation_mesg

folder_path, folder_path_kwarg = resolve_zarr_path(folder_path)

self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
zarr_root = self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
if zarr_group is None:
self._root = zarr_root
else:
self._root = zarr_root[zarr_group]

sampling_frequency = self._root.attrs.get("sampling_frequency", None)
num_segments = self._root.attrs.get("num_segments", None)
Expand Down Expand Up @@ -216,7 +221,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None)
if annotations is not None:
self.annotate(**annotations)

self._kwargs = {"root_path": folder_path_kwarg, "storage_options": storage_options}
self._kwargs = {"folder_path": folder_path_kwarg, "storage_options": storage_options, "zarr_group": zarr_group}

@staticmethod
def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options: dict | None = None, **kwargs):
Expand Down

0 comments on commit 522fd2b

Please sign in to comment.