Skip to content

Commit

Permalink
Merge pull request #2423 from h-mayorquin/templates_to_zarr
Browse files Browse the repository at this point in the history
zarr IO for templates object
  • Loading branch information
alejoe91 authored Feb 5, 2024
2 parents 1cad1f8 + c6f21c3 commit bd317fe
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"numpy",
"threadpoolctl>=3.0.0",
"tqdm",
"zarr",
"zarr>=2.15",
"neo>=0.12.0",
"probeinterface>=0.2.19",
]
Expand Down
127 changes: 126 additions & 1 deletion src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from dataclasses import dataclass, field, astuple
from probeinterface import Probe
from pathlib import Path
from .sparsity import ChannelSparsity


Expand Down Expand Up @@ -168,6 +169,131 @@ def from_dict(cls, data):
probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
)

def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:
"""
Adds a serialized version of the object to a given Zarr group.
It is the inverse of the `from_zarr_group` method.
Parameters
----------
zarr_group : zarr.Group
The Zarr group to which the template object will be serialized.
Notes
-----
This method will create datasets within the Zarr group for `templates_array`,
`channel_ids`, and `unit_ids`. It will also add `sampling_frequency` and `nbefore`
as attributes to the group. If `sparsity_mask` and `probe` are not None, they will
be included as a dataset and a subgroup, respectively.
The `templates_array` dataset is saved with a chunk size that has a single unit per chunk
to optimize read/write operations for individual units.
"""

# Saves one chunk per unit
arrays_chunk = (1, None, None)
zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk)
zarr_group.create_dataset("channel_ids", data=self.channel_ids)
zarr_group.create_dataset("unit_ids", data=self.unit_ids)

zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
zarr_group.attrs["nbefore"] = self.nbefore

if self.sparsity_mask is not None:
zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)

if self.probe is not None:
probe_group = zarr_group.create_group("probe")
self.probe.add_probe_to_zarr_group(probe_group)

def to_zarr(self, folder_path: str | Path) -> None:
"""
Saves the object's data to a Zarr file in the specified folder.
Use the `add_templates_to_zarr_group` method to serialize the object to a Zarr group and then
save the group to a Zarr file.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr data will be saved.
"""
import zarr

zarr_group = zarr.open_group(folder_path, mode="w")

self.add_templates_to_zarr_group(zarr_group)

@classmethod
def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
"""
Loads an instance of the class from an open Zarr group.
This is the inverse of the `add_templates_to_zarr_group` method.
Parameters
----------
zarr_group : zarr.Group
The Zarr group from which to load the instance.
Returns
-------
Templates
An instance of Templates populated with the data from the Zarr group.
Notes
-----
This method assumes the Zarr group has the same structure as the one created by
the `add_templates_to_zarr_group` method.
"""
templates_array = zarr_group["templates_array"]
channel_ids = zarr_group["channel_ids"]
unit_ids = zarr_group["unit_ids"]
sampling_frequency = zarr_group.attrs["sampling_frequency"]
nbefore = zarr_group.attrs["nbefore"]

sparsity_mask = None
if "sparsity_mask" in zarr_group:
sparsity_mask = zarr_group["sparsity_mask"]

probe = None
if "probe" in zarr_group:
probe = Probe.from_zarr_group(zarr_group["probe"])

return cls(
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=sparsity_mask,
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
)

@staticmethod
def from_zarr(folder_path: str | Path) -> "Templates":
"""
Deserialize the Templates object from a Zarr file located at the given folder path.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr file is located.
Returns
-------
Templates
An instance of Templates initialized with data from the Zarr file.
"""
import zarr

zarr_group = zarr.open_group(folder_path, mode="r")

return Templates.from_zarr_group(zarr_group)

def to_json(self):
from spikeinterface.core.core_tools import SIJsonEncoder

Expand Down Expand Up @@ -209,7 +335,6 @@ def __eq__(self, other):
return False
if not np.array_equal(s_field.channel_ids, o_field.channel_ids):
return False

else:
if s_field != o_field:
return False
Expand Down
17 changes: 16 additions & 1 deletion src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def generate_test_template(template_type):
probe = generate_multi_columns_probe(num_columns=1, num_contact_per_column=[3])

if template_type == "dense":
return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
return Templates(
templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe
)
elif template_type == "sparse": # sparse with sparse templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])
sparsity = ChannelSparsity(
Expand Down Expand Up @@ -92,6 +94,19 @@ def test_initialization_fail_with_dense_templates():
template = generate_test_template(template_type="sparse_with_dense_templates")


@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_save_and_load_zarr(template_type, tmp_path):
original_template = generate_test_template(template_type)

zarr_path = tmp_path / "templates.zarr"
original_template.to_zarr(str(zarr_path))

# Load from the Zarr archive
loaded_template = Templates.from_zarr(str(zarr_path))

assert original_template == loaded_template


if __name__ == "__main__":
# test_json_serialization("sparse")
test_json_serialization("dense")

0 comments on commit bd317fe

Please sign in to comment.