Skip to content

Commit

Permalink
Adding SharedMemoryDriftingTemplates
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Jan 16, 2025
1 parent 4a099e5 commit 2d087f2
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/spikeinterface/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
move_dense_templates,
interpolate_templates,
DriftingTemplates,
SharedMemoryDriftingTemplates,
InjectDriftingTemplatesRecording,
make_linear_displacement,
)
Expand Down
79 changes: 78 additions & 1 deletion src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from numpy.typing import ArrayLike
from probeinterface import Probe
from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, Templates

from multiprocessing.shared_memory import SharedMemory
from spikeinterface.core.core_tools import make_shared_array

def interpolate_templates(templates_array, source_locations, dest_locations, interpolation_method="cubic"):
"""
Expand Down Expand Up @@ -258,6 +259,82 @@ def precompute_displacements(self, displacements, **interpolation_kwargs):
self.displacements = displacements


class SharedMemoryDriftingTemplates(DriftingTemplates):

def __init__(
self,
shm_name,
shape,
dtype,
templates_array_moved=None,
displacements=None,
main_shm_owner=True,
**static_kwargs
):

assert len(shape) == 4
assert shape[0] > 0, "SharedMemoryTemplates only supported with no empty templates"

self.shm = SharedMemory(shm_name, create=False)
templates_array_moved = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)
self.static_kwargs = static_kwargs
DriftingTemplates.__init__(
self,
templates_array_moved=templates_array_moved,
displacements=displacements,
**self.static_kwargs
)

# this is very important for the shm.unlink()
# only the main instance need to call it
# all other instances that are loaded from dict are not the main owner
self.main_shm_owner = main_shm_owner

self._kwargs = dict(
shm_name=shm_name,
shape=shape,
displacements=self.displacements,
static_kwargs=self.static_kwargs,
channel_ids=self.channel_ids,
# this ensure that all dump/load will not be main shm owner
main_shm_owner=False,
)

def __del__(self):
self.shm.close()
if self.main_shm_owner:
self.shm.unlink()

@staticmethod
def from_drifting_templates(drifting_templates):
assert drifting_templates.templates_array_moved is not None, "drifting_templates must have precomputed displacements"
data = drifting_templates.templates_array_moved
shm_templates, shm = make_shared_array(data.shape, data.dtype)
shm_templates[:] = data
static_kwargs = drifting_templates.to_dict()
init_kwargs = {
"templates_array": np.asarray(static_kwargs["templates_array"]),
"sparsity_mask": None if static_kwargs["sparsity_mask"] is None else np.asarray(static_kwargs["sparsity_mask"]),
"channel_ids": np.asarray(static_kwargs["channel_ids"]),
"unit_ids": np.asarray(static_kwargs["unit_ids"]),
"sampling_frequency": static_kwargs["sampling_frequency"],
"nbefore": static_kwargs["nbefore"],
"is_scaled": static_kwargs["is_scaled"],
"probe": static_kwargs["probe"] if static_kwargs["probe"] is None else Probe.from_dict(static_kwargs["probe"]),
}
shared_drifting_templates = SharedMemoryDriftingTemplates(
shm.name,
data.shape,
data.dtype,
shm_templates,
drifting_templates.displacements,
main_shm_owner=True,
**init_kwargs
)
shm.close()
return shared_drifting_templates


def make_linear_displacement(start, stop, num_step=10):
"""
Generates 2D linear displacements between `start` and `stop` positions (included in returned displacements).
Expand Down
13 changes: 13 additions & 0 deletions src/spikeinterface/generation/tests/test_drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
move_dense_templates,
make_linear_displacement,
DriftingTemplates,
SharedMemoryDriftingTemplates,
InjectDriftingTemplatesRecording,
)
from spikeinterface.core.generate import generate_templates, generate_sorting, NoiseGeneratorRecording
Expand Down Expand Up @@ -133,6 +134,18 @@ def test_DriftingTemplates():
assert np.array_equal(drifting_templates_from_precomputed.displacements, drifting_templates.displacements)


def test_SharedMemoryDriftingTemplates():
static_templates = make_some_templates()
drifting_templates = DriftingTemplates.from_static_templates(static_templates)
displacement = np.array([[5.0, 10.0]])
drifting_templates.precompute_displacements(displacement)
shm_drifting_templates = SharedMemoryDriftingTemplates.from_drifting_templates(drifting_templates)

assert np.array_equal(
shm_drifting_templates.templates_array_moved, drifting_templates.templates_array_moved
)
assert np.array_equal(shm_drifting_templates.displacements, drifting_templates.displacements)

def test_InjectDriftingTemplatesRecording(create_cache_folder):
cache_folder = create_cache_folder
templates = make_some_templates()
Expand Down

0 comments on commit 2d087f2

Please sign in to comment.