Skip to content

Commit

Permalink
Merge pull request #2410 from samuelgarcia/estimate_sparsity
Browse files Browse the repository at this point in the history
Improvement to compute sparsity without WaveformsExtractor
  • Loading branch information
alejoe91 authored Jan 29, 2024
2 parents 34b9dbc + 4a7f6ef commit fcb7ce4
Show file tree
Hide file tree
Showing 11 changed files with 668 additions and 98 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@
)

# channel sparsity
from .sparsity import ChannelSparsity, compute_sparsity
from .sparsity import ChannelSparsity, compute_sparsity, estimate_sparsity

from .template import Templates
3 changes: 2 additions & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np

from .base import BaseExtractor, BaseSegment
from .sorting_tools import spike_vector_to_spike_trains
from .waveform_tools import has_exceeding_spikes


Expand Down Expand Up @@ -499,6 +498,8 @@ def precompute_spike_trains(self, from_spike_vector=None):
If True, will compute it from the spike vector.
If False, will call `get_unit_spike_train` for each segment for each unit.
"""
from .sorting_tools import spike_vector_to_spike_trains

unit_ids = self.unit_ids

if from_spike_vector is None:
Expand Down
74 changes: 74 additions & 0 deletions src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from .basesorting import BaseSorting
import numpy as np


Expand Down Expand Up @@ -90,3 +91,76 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
get_numba_vector_to_list_of_spiketrain._cached_numba_function = vector_to_list_of_spiketrain_numba

return vector_to_list_of_spiketrain_numba


# TODO later : implement other method like "maximum_rate", "by_percent", ...
def random_spikes_selection(
    sorting: BaseSorting,
    num_samples: list[int],
    method: str = "uniform",
    max_spikes_per_unit: int = 500,
    margin_size: int | None = None,
    seed: int | None = None,
):
    """
    Randomly select a bounded number of spikes for every unit.

    This replaces `select_random_spikes_uniformly()`.
    Spikes that are too close to a segment border can optionally be
    discarded by giving a `margin_size`.

    Parameters
    ----------
    sorting: BaseSorting
        The sorting object
    num_samples: list of int
        The number of samples per segment.
        Can be retrieved from recording with
        num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
    method: "uniform", default: "uniform"
        The method to use. Only "uniform" is implemented for now
    max_spikes_per_unit: int, default: 500
        The maximum number of spikes drawn per unit (before margin removal)
    margin_size: None | int, default: None
        A margin (in samples) on each border of segments to avoid spikes
    seed: None | int, default: None
        A seed for random generator

    Returns
    -------
    random_spikes_indices: np.array
        Sorted indices, into the sorting spike vector, of the selected spikes.

    Raises
    ------
    ValueError
        If `method` is not "uniform".
    """
    # Validate once up front instead of inside the per-unit loop.
    if method != "uniform":
        raise ValueError(f"random_spikes_selection wrong method {method}, currently only 'uniform' can be used.")

    rng = np.random.default_rng(seed=seed)
    spikes = sorting.to_spike_vector()

    if margin_size is not None:
        margin_size = int(margin_size)
        segment_lengths = np.asarray(num_samples)

    random_spikes_indices = []
    for unit_index, _unit_id in enumerate(sorting.unit_ids):
        all_unit_indices = np.flatnonzero(unit_index == spikes["unit_index"])

        selected_unit_indices = rng.choice(
            all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False
        )

        if margin_size is not None:
            selected_spikes = spikes[selected_unit_indices]
            sample_index = selected_spikes["sample_index"]
            # The margins are expressed in samples, so they must be compared with the
            # spike times ("sample_index") and not with positions in the spike vector.
            # Each spike is checked against the length of its own segment.
            right_limit = segment_lengths[selected_spikes["segment_index"]] - margin_size
            keep = (sample_index >= margin_size) & (sample_index < right_limit)
            selected_unit_indices = selected_unit_indices[keep]

        random_spikes_indices.append(selected_unit_indices)

    if not random_spikes_indices:
        # No unit at all: np.concatenate() would fail on an empty list.
        return np.array([], dtype=np.intp)

    random_spikes_indices = np.concatenate(random_spikes_indices)
    random_spikes_indices = np.sort(random_spikes_indices)

    return random_spikes_indices
182 changes: 152 additions & 30 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import numpy as np

from .recording_tools import get_channel_distances, get_noise_levels
from .basesorting import BaseSorting
from .baserecording import BaseRecording
from .recording_tools import get_noise_levels
from .sorting_tools import random_spikes_selection
from .job_tools import _shared_job_kwargs_doc
from .waveform_tools import estimate_templates


_sparsity_doc = """
Expand Down Expand Up @@ -263,38 +268,38 @@ def from_dict(cls, dictionary: dict):

## Some convenient functions to compute sparsity from several strategies
@classmethod
def from_best_channels(cls, we, num_channels, peak_sign="neg"):
def from_best_channels(cls, templates_or_we, num_channels, peak_sign="neg"):
"""
Construct sparsity from N best channels with the largest amplitude.
Use the "num_channels" argument to specify the number of channels.
"""
from .template_tools import get_template_amplitudes

mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool")
peak_values = get_template_amplitudes(we, peak_sign=peak_sign)
for unit_ind, unit_id in enumerate(we.unit_ids):
mask = np.zeros((templates_or_we.unit_ids.size, templates_or_we.channel_ids.size), dtype="bool")
peak_values = get_template_amplitudes(templates_or_we, peak_sign=peak_sign)
for unit_ind, unit_id in enumerate(templates_or_we.unit_ids):
chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1]
chan_inds = chan_inds[:num_channels]
mask[unit_ind, chan_inds] = True
return cls(mask, we.unit_ids, we.channel_ids)
return cls(mask, templates_or_we.unit_ids, templates_or_we.channel_ids)

@classmethod
def from_radius(cls, we, radius_um, peak_sign="neg"):
def from_radius(cls, templates_or_we, radius_um, peak_sign="neg"):
"""
Construct sparsity from a radius around the best channel.
Use the "radius_um" argument to specify the radius in um
"""
from .template_tools import get_template_extremum_channel

mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool")
locations = we.get_channel_locations()
distances = np.linalg.norm(locations[:, np.newaxis] - locations[np.newaxis, :], axis=2)
best_chan = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index")
for unit_ind, unit_id in enumerate(we.unit_ids):
mask = np.zeros((templates_or_we.unit_ids.size, templates_or_we.channel_ids.size), dtype="bool")
channel_locations = templates_or_we.get_channel_locations()
distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2)
best_chan = get_template_extremum_channel(templates_or_we, peak_sign=peak_sign, outputs="index")
for unit_ind, unit_id in enumerate(templates_or_we.unit_ids):
chan_ind = best_chan[unit_id]
(chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
mask[unit_ind, chan_inds] = True
return cls(mask, we.unit_ids, we.channel_ids)
return cls(mask, templates_or_we.unit_ids, templates_or_we.channel_ids)

@classmethod
def from_snr(cls, we, threshold, peak_sign="neg"):
Expand Down Expand Up @@ -374,7 +379,7 @@ def create_dense(cls, we):


def compute_sparsity(
waveform_extractor,
templates_or_waveform_extractor,
method="radius",
peak_sign="neg",
num_channels=5,
Expand All @@ -383,42 +388,159 @@ def compute_sparsity(
by_property=None,
):
"""
Get channel sparsity (subset of channels) for each template with several methods.
Get channel sparsity (subset of channels) for each template with several methods.
Parameters
----------
waveform_extractor: WaveformExtractor
The waveform extractor
Parameters
----------
templates_or_waveform_extractor: Templates | WaveformExtractor
A Templates or a WaveformExtractor object.
Some methods accept both objects (e.g. "best_channels", "radius", )
Other methods need WaveformExtractor because internally the recording is needed.
{}
Returns
-------
sparsity: ChannelSparsity
The estimated sparsity
Returns
-------
sparsity: ChannelSparsity
The estimated sparsity
"""

# Can't be done at module because this is a cyclic import, too bad
from .template import Templates
from .waveform_extractor import WaveformExtractor

if method in ("best_channels", "radius"):
assert isinstance(
templates_or_waveform_extractor, (Templates, WaveformExtractor)
), f"compute_sparsity() requires either a Templates or WaveformExtractor, not a type: {type(templates_or_waveform_extractor)}"
else:
assert isinstance(
templates_or_waveform_extractor, WaveformExtractor
), f"compute_sparsity(method='{method}') requires a WaveformExtractor"

if method == "best_channels":
assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given"
sparsity = ChannelSparsity.from_best_channels(waveform_extractor, num_channels, peak_sign=peak_sign)
sparsity = ChannelSparsity.from_best_channels(
templates_or_waveform_extractor, num_channels, peak_sign=peak_sign
)
elif method == "radius":
assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given"
sparsity = ChannelSparsity.from_radius(waveform_extractor, radius_um, peak_sign=peak_sign)
sparsity = ChannelSparsity.from_radius(templates_or_waveform_extractor, radius_um, peak_sign=peak_sign)
elif method == "snr":
assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_snr(waveform_extractor, threshold, peak_sign=peak_sign)
sparsity = ChannelSparsity.from_snr(templates_or_waveform_extractor, threshold, peak_sign=peak_sign)
elif method == "energy":
assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_energy(waveform_extractor, threshold)
sparsity = ChannelSparsity.from_energy(templates_or_waveform_extractor, threshold)
elif method == "ptp":
assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
sparsity = ChannelSparsity.from_ptp(waveform_extractor, threshold)
sparsity = ChannelSparsity.from_ptp(templates_or_waveform_extractor, threshold)
elif method == "by_property":
assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given"
sparsity = ChannelSparsity.from_property(waveform_extractor, by_property)
sparsity = ChannelSparsity.from_property(templates_or_waveform_extractor, by_property)
else:
raise ValueError(f"compute_sparsity() method={method} do not exists")
raise ValueError(f"compute_sparsity() method={method} does not exists")

return sparsity


compute_sparsity.__doc__ = compute_sparsity.__doc__.format(_sparsity_doc)


def estimate_sparsity(
    recording: BaseRecording,
    sorting: BaseSorting,
    num_spikes_for_sparsity: int = 100,
    ms_before: float = 1.0,
    ms_after: float = 2.5,
    method: str = "radius",
    peak_sign: str = "neg",
    radius_um: float = 100.0,
    num_channels: int = 5,
    **job_kwargs,
):
    """
    Estimate the sparsity without needing a WaveformExtractor.

    This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it
    traverses the recording to compute the average templates for each unit.

    Contrary to the previous implementation:
      * all units are computed in one read of recording
      * it doesn't require a folder
      * it doesn't consume too much memory
      * it uses internally the `estimate_templates()` which is fast and parallel

    Parameters
    ----------
    recording: BaseRecording
        The recording
    sorting: BaseSorting
        The sorting
    num_spikes_for_sparsity: int, default: 100
        How many spikes per units to compute the sparsity
    ms_before: float, default: 1.0
        Cut out in ms before spike time
    ms_after: float, default: 2.5
        Cut out in ms after spike time
    method: "radius" | "best_channels", default: "radius"
        Sparsity method propagated to the `compute_sparsity()` function.
        Only "radius" or "best_channels" are implemented
    peak_sign: "neg" | "pos" | "both", default: "neg"
        Sign of the template to compute best channels
    radius_um: float, default: 100.0
        Used for "radius" method
    num_channels: int, default: 5
        Used for "best_channels" method

    {}

    Returns
    -------
    sparsity: ChannelSparsity
        The estimated sparsity
    """
    # Validate the method first so that a wrong value fails before any other work.
    assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channels'"

    # Can't be done at module because this is a cyclic import, too bad
    from .template import Templates

    if method == "radius":
        assert (
            len(recording.get_probes()) == 1
        ), "The 'radius' method of `estimate_sparsity()` can handle only one probe"

    # Waveform window expressed in samples.
    nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
    nafter = int(ms_after * recording.sampling_frequency / 1000.0)

    # Draw a bounded number of spikes per unit, away from the segment borders so
    # that a full window can be cut around every selected spike.
    num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
    random_spikes_indices = random_spikes_selection(
        sorting,
        num_samples,
        method="uniform",
        max_spikes_per_unit=num_spikes_for_sparsity,
        margin_size=max(nbefore, nafter),
        seed=2205,  # fixed seed: the same spikes are selected on every call
    )
    spikes = sorting.to_spike_vector()
    spikes = spikes[random_spikes_indices]

    templates_array = estimate_templates(
        recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=False, **job_kwargs
    )
    # NOTE(review): get_probe() is called for both methods, while the single-probe
    # check above only runs for "radius" - confirm the behavior of "best_channels"
    # on recordings with several probes.
    templates = Templates(
        templates_array=templates_array,
        sampling_frequency=recording.sampling_frequency,
        nbefore=nbefore,
        sparsity_mask=None,
        channel_ids=recording.channel_ids,
        unit_ids=sorting.unit_ids,
        probe=recording.get_probe(),
    )

    sparsity = compute_sparsity(
        templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um
    )

    return sparsity


estimate_sparsity.__doc__ = estimate_sparsity.__doc__.format(_shared_job_kwargs_doc)
8 changes: 8 additions & 0 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __post_init__(self):
else:
self.num_channels = self.sparsity_mask.shape[1]

if self.probe is not None:
assert isinstance(self.probe, Probe), "'probe' must be a probeinterface.Probe object"

# Time and frames domain information
self.nafter = self.num_samples - self.nbefore
self.ms_before = self.nbefore / self.sampling_frequency * 1000
Expand Down Expand Up @@ -214,3 +217,8 @@ def __eq__(self, other):
return False

return True

def get_channel_locations(self):
    """
    Return the contact positions of the probe attached to these templates.

    Fails with an AssertionError when no probe is set.
    """
    probe = self.probe
    assert probe is not None, "Templates.get_channel_locations() needs a probe to be set"
    return probe.contact_positions
Loading

0 comments on commit fcb7ce4

Please sign in to comment.