diff --git a/pyproject.toml b/pyproject.toml index efb92d30a2..f2829f90ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "numpy", "threadpoolctl>=3.0.0", "tqdm", - "zarr", + "zarr>=2.15", "neo>=0.12.0", "probeinterface>=0.2.19", ] diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3f42182101..af7ae544ef 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -2,6 +2,7 @@ import json from dataclasses import dataclass, field, astuple from probeinterface import Probe +from pathlib import Path from .sparsity import ChannelSparsity @@ -168,6 +169,131 @@ def from_dict(cls, data): probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]), ) + def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: + """ + Adds a serialized version of the object to a given Zarr group. + + It is the inverse of the `from_zarr_group` method. + + Parameters + ---------- + zarr_group : zarr.Group + The Zarr group to which the template object will be serialized. + + Notes + ----- + This method will create datasets within the Zarr group for `templates_array`, + `channel_ids`, and `unit_ids`. It will also add `sampling_frequency` and `nbefore` + as attributes to the group. If `sparsity_mask` and `probe` are not None, they will + be included as a dataset and a subgroup, respectively. + + The `templates_array` dataset is saved with a chunk size that has a single unit per chunk + to optimize read/write operations for individual units. + """ + + # Saves one chunk per unit + arrays_chunk = (1, None, None) + zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk) + zarr_group.create_dataset("channel_ids", data=self.channel_ids) + zarr_group.create_dataset("unit_ids", data=self.unit_ids) + + zarr_group.attrs["sampling_frequency"] = self.sampling_frequency + zarr_group.attrs["nbefore"] = self.nbefore + + if self.sparsity_mask is not None: + zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask) + + if self.probe is not None: + probe_group = zarr_group.create_group("probe") + self.probe.add_probe_to_zarr_group(probe_group) + + def to_zarr(self, folder_path: str | Path) -> None: + """ + Saves the object's data to a Zarr file in the specified folder. + + Use the `add_templates_to_zarr_group` method to serialize the object to a Zarr group and then + save the group to a Zarr file. + + Parameters + ---------- + folder_path : str | Path + The path to the folder where the Zarr data will be saved. + + """ + import zarr + + zarr_group = zarr.open_group(folder_path, mode="w") + + self.add_templates_to_zarr_group(zarr_group) + + @classmethod + def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates": + """ + Loads an instance of the class from an open Zarr group. + + This is the inverse of the `add_templates_to_zarr_group` method. + + Parameters + ---------- + zarr_group : zarr.Group + The Zarr group from which to load the instance. + + Returns + ------- + Templates + An instance of Templates populated with the data from the Zarr group. + + Notes + ----- + This method assumes the Zarr group has the same structure as the one created by + the `add_templates_to_zarr_group` method. + + """ + templates_array = zarr_group["templates_array"] + channel_ids = zarr_group["channel_ids"] + unit_ids = zarr_group["unit_ids"] + sampling_frequency = zarr_group.attrs["sampling_frequency"] + nbefore = zarr_group.attrs["nbefore"] + + sparsity_mask = None + if "sparsity_mask" in zarr_group: + sparsity_mask = zarr_group["sparsity_mask"] + + probe = None + if "probe" in zarr_group: + probe = Probe.from_zarr_group(zarr_group["probe"]) + + return cls( + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + sparsity_mask=sparsity_mask, + channel_ids=channel_ids, + unit_ids=unit_ids, + probe=probe, + ) + + @staticmethod + def from_zarr(folder_path: str | Path) -> "Templates": + """ + Deserialize the Templates object from a Zarr file located at the given folder path. + + Parameters + ---------- + folder_path : str | Path + The path to the folder where the Zarr file is located. + + Returns + ------- + Templates + An instance of Templates initialized with data from the Zarr file. + """ + import zarr + + zarr_group = zarr.open_group(folder_path, mode="r") + + return Templates.from_zarr_group(zarr_group) + def to_json(self): from spikeinterface.core.core_tools import SIJsonEncoder @@ -209,7 +335,6 @@ def __eq__(self, other): return False if not np.array_equal(s_field.channel_ids, o_field.channel_ids): return False - else: if s_field != o_field: return False diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index db04a90d81..0ed6bc2e3e 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -20,7 +20,9 @@ def generate_test_template(template_type): probe = generate_multi_columns_probe(num_columns=1, num_contact_per_column=[3]) if template_type == "dense": - return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) + return Templates( + templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe + ) elif template_type == "sparse": # sparse with sparse templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) sparsity = ChannelSparsity( @@ -92,6 +94,19 @@ def test_initialization_fail_with_dense_templates(): template = generate_test_template(template_type="sparse_with_dense_templates") +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_save_and_load_zarr(template_type, tmp_path): + original_template = generate_test_template(template_type) + + zarr_path = tmp_path / "templates.zarr" + original_template.to_zarr(str(zarr_path)) + + # Load from the Zarr archive + loaded_template = Templates.from_zarr(str(zarr_path)) + + assert original_template == loaded_template + + if __name__ == "__main__": # test_json_serialization("sparse") test_json_serialization("dense")