diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 1eb01d27ce..fce3892071 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -137,6 +137,6 @@ ) # channel sparsity -from .sparsity import ChannelSparsity, compute_sparsity +from .sparsity import ChannelSparsity, compute_sparsity, estimate_sparsity from .template import Templates diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index eab0137add..4e3551e290 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -6,7 +6,6 @@ import numpy as np from .base import BaseExtractor, BaseSegment -from .sorting_tools import spike_vector_to_spike_trains from .waveform_tools import has_exceeding_spikes @@ -499,6 +498,8 @@ def precompute_spike_trains(self, from_spike_vector=None): If True, will compute it from the spike vector. If False, will call `get_unit_spike_train` for each segment for each unit. """ + from .sorting_tools import spike_vector_to_spike_trains + unit_ids = self.unit_ids if from_spike_vector is None: diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index aea247e909..df0af26fc0 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -1,4 +1,5 @@ from __future__ import annotations +from .basesorting import BaseSorting import numpy as np @@ -90,3 +91,76 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): get_numba_vector_to_list_of_spiketrain._cached_numba_function = vector_to_list_of_spiketrain_numba return vector_to_list_of_spiketrain_numba + + +# TODO later : implement other method like "maximum_rate", "by_percent", ... +def random_spikes_selection( + sorting: BaseSorting, + num_samples: int, + method: str = "uniform", + max_spikes_per_unit: int = 500, + margin_size: int | None = None, + seed: int | None = None, +): + """ + This replaces `select_random_spikes_uniformly()`. + Random spikes selection of spike across per units. + Can optionally avoid spikes on segment borders if + margin_size is not None. + + Parameters + ---------- + sorting: BaseSorting + The sorting object + num_samples: list of int + The number of samples per segment. + Can be retrieved from recording with + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + method: "uniform", default: "uniform" + The method to use. Only "uniform" is implemented for now + max_spikes_per_unit: int, default: 500 + The number of spikes per units + margin_size: None | int, default: None + A margin on each border of segments to avoid spikes + seed: None | int, default: None + A seed for random generator + + Returns + ------- + random_spikes_indices: np.array + Selected spike indices coresponding to the sorting spike vector. + """ + rng = np.random.default_rng(seed=seed) + spikes = sorting.to_spike_vector() + + random_spikes_indices = [] + for unit_index, unit_id in enumerate(sorting.unit_ids): + all_unit_indices = np.flatnonzero(unit_index == spikes["unit_index"]) + + if method == "uniform": + selected_unit_indices = rng.choice( + all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False + ) + else: + raise ValueError(f"random_spikes_selection wrong method {method}, currently only 'uniform' can be used.") + + if margin_size is not None: + margin_size = int(margin_size) + keep = np.ones(selected_unit_indices.size, dtype=bool) + # left margin + keep[selected_unit_indices < margin_size] = False + # right margin + for segment_index in range(sorting.get_num_segments()): + remove_mask = np.flatnonzero( + (spikes[selected_unit_indices]["segment_index"] == segment_index) + & (spikes[selected_unit_indices]["sample_index"] >= (num_samples[segment_index] - margin_size)) + ) + keep[remove_mask] = False + selected_unit_indices = selected_unit_indices[keep] + + random_spikes_indices.append(selected_unit_indices) + + random_spikes_indices = np.concatenate(random_spikes_indices) + random_spikes_indices = np.sort(random_spikes_indices) + + return random_spikes_indices diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 3b8b6025ca..ec7f52527e 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -2,7 +2,12 @@ import numpy as np -from .recording_tools import get_channel_distances, get_noise_levels +from .basesorting import BaseSorting +from .baserecording import BaseRecording +from .recording_tools import get_noise_levels +from .sorting_tools import random_spikes_selection +from .job_tools import _shared_job_kwargs_doc +from .waveform_tools import estimate_templates _sparsity_doc = """ @@ -263,38 +268,38 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, we, num_channels, peak_sign="neg"): + def from_best_channels(cls, templates_or_we, num_channels, peak_sign="neg"): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. """ from .template_tools import get_template_amplitudes - mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") - peak_values = get_template_amplitudes(we, peak_sign=peak_sign) - for unit_ind, unit_id in enumerate(we.unit_ids): + mask = np.zeros((templates_or_we.unit_ids.size, templates_or_we.channel_ids.size), dtype="bool") + peak_values = get_template_amplitudes(templates_or_we, peak_sign=peak_sign) + for unit_ind, unit_id in enumerate(templates_or_we.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] mask[unit_ind, chan_inds] = True - return cls(mask, we.unit_ids, we.channel_ids) + return cls(mask, templates_or_we.unit_ids, templates_or_we.channel_ids) @classmethod - def from_radius(cls, we, radius_um, peak_sign="neg"): + def from_radius(cls, templates_or_we, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. Use the "radius_um" argument to specify the radius in um """ from .template_tools import get_template_extremum_channel - mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") - locations = we.get_channel_locations() - distances = np.linalg.norm(locations[:, np.newaxis] - locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index") - for unit_ind, unit_id in enumerate(we.unit_ids): + mask = np.zeros((templates_or_we.unit_ids.size, templates_or_we.channel_ids.size), dtype="bool") + channel_locations = templates_or_we.get_channel_locations() + distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) + best_chan = get_template_extremum_channel(templates_or_we, peak_sign=peak_sign, outputs="index") + for unit_ind, unit_id in enumerate(templates_or_we.unit_ids): chan_ind = best_chan[unit_id] (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) mask[unit_ind, chan_inds] = True - return cls(mask, we.unit_ids, we.channel_ids) + return cls(mask, templates_or_we.unit_ids, templates_or_we.channel_ids) @classmethod def from_snr(cls, we, threshold, peak_sign="neg"): @@ -374,7 +379,7 @@ def create_dense(cls, we): def compute_sparsity( - waveform_extractor, + templates_or_waveform_extractor, method="radius", peak_sign="neg", num_channels=5, @@ -383,42 +388,159 @@ def compute_sparsity( by_property=None, ): """ - Get channel sparsity (subset of channels) for each template with several methods. + Get channel sparsity (subset of channels) for each template with several methods. - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + Parameters + ---------- + templates_or_waveform_extractor: Templates | WaveformExtractor + A Templates or a WaveformExtractor object. + Some methods accept both objects (e.g. "best_channels", "radius", ) + Other methods need WaveformExtractor because internally the recording is needed. {} - Returns - ------- - sparsity: ChannelSparsity - The estimated sparsity + Returns + ------- + sparsity: ChannelSparsity + The estimated sparsity """ + + # Can't be done at module because this is a cyclic import, too bad + from .template import Templates + from .waveform_extractor import WaveformExtractor + + if method in ("best_channels", "radius"): + assert isinstance( + templates_or_waveform_extractor, (Templates, WaveformExtractor) + ), f"compute_sparsity() requires either a Templates or WaveformExtractor, not a type: {type(templates_or_waveform_extractor)}" + else: + assert isinstance( + templates_or_waveform_extractor, WaveformExtractor + ), f"compute_sparsity(method='{method}') requires a WaveformExtractor" + if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_best_channels(waveform_extractor, num_channels, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_best_channels( + templates_or_waveform_extractor, num_channels, peak_sign=peak_sign + ) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(waveform_extractor, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates_or_waveform_extractor, radius_um, peak_sign=peak_sign) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_snr(waveform_extractor, threshold, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_snr(templates_or_waveform_extractor, threshold, peak_sign=peak_sign) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_energy(waveform_extractor, threshold) + sparsity = ChannelSparsity.from_energy(templates_or_waveform_extractor, threshold) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp(waveform_extractor, threshold) + sparsity = ChannelSparsity.from_ptp(templates_or_waveform_extractor, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(waveform_extractor, by_property) + sparsity = ChannelSparsity.from_property(templates_or_waveform_extractor, by_property) else: - raise ValueError(f"compute_sparsity() method={method} do not exists") + raise ValueError(f"compute_sparsity() method={method} does not exists") return sparsity compute_sparsity.__doc__ = compute_sparsity.__doc__.format(_sparsity_doc) + + +def estimate_sparsity( + recording: BaseRecording, + sorting: BaseSorting, + num_spikes_for_sparsity: int = 100, + ms_before: float = 1.0, + ms_after: float = 2.5, + method: "radius" | "best_channels" = "radius", + peak_sign: str = "neg", + radius_um: float = 100.0, + num_channels: int = 5, + **job_kwargs, +): + """ + Estimate the sparsity without needing a WaveformExtractor. + This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it + traverses the recording to compute the average templates for each unit. + + Contrary to the previous implementation: + * all units are computed in one read of recording + * it doesn't require a folder + * it doesn't consume too much memory + * it uses internally the `estimate_templates()` which is fast and parallel + + Parameters + ---------- + recording: BaseRecording + The recording + sorting: BaseSorting + The sorting + num_spikes_for_sparsity: int, default: 100 + How many spikes per units to compute the sparsity + ms_before: float, default: 1.0 + Cut out in ms before spike time + ms_after: float, default: 2.5 + Cut out in ms after spike time + method: "radius" | "best_channels", default: "radius" + Sparsity method propagated to the `compute_sparsity()` function. + Only "radius" or "best_channels" are implemented + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + radius_um: float, default: 100.0 + Used for "radius" method + num_channels: int, default: 5 + Used for "best_channels" method + + {} + + Returns + ------- + sparsity: ChannelSparsity + The estimated sparsity + """ + # Can't be done at module because this is a cyclic import, too bad + from .template import Templates + + assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" + if method == "radius": + assert ( + len(recording.get_probes()) == 1 + ), "The 'radius' method of `estimate_sparsity()` can handle only one probe" + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_sparsity, + margin_size=max(nbefore, nafter), + seed=2205, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates( + recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=False, **job_kwargs + ) + templates = Templates( + templates_array=templates_array, + sampling_frequency=recording.sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=sorting.unit_ids, + probe=recording.get_probe(), + ) + + sparsity = compute_sparsity( + templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um + ) + + return sparsity + + +estimate_sparsity.__doc__ = estimate_sparsity.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 9a2a868f98..8e486c8e9e 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -80,6 +80,9 @@ def __post_init__(self): else: self.num_channels = self.sparsity_mask.shape[1] + if self.probe is not None: + assert isinstance(self.probe, Probe), "'probe' must be a probeinterface.Probe object" + # Time and frames domain information self.nafter = self.num_samples - self.nbefore self.ms_before = self.nbefore / self.sampling_frequency * 1000 @@ -214,3 +217,8 @@ def __eq__(self, other): return False return True + + def get_channel_locations(self): + assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" + channel_locations = self.probe.contact_positions + return channel_locations diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index a6de2de2fa..735ce2cbdc 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -2,20 +2,34 @@ import numpy as np import warnings +from .template import Templates +from .waveform_extractor import WaveformExtractor from .sparsity import compute_sparsity, _sparsity_doc from .recording_tools import get_channel_distances, get_noise_levels +def _get_dense_templates_array(templates_or_waveform_extractor): + if isinstance(templates_or_waveform_extractor, Templates): + templates_array = templates_or_waveform_extractor.get_dense_templates() + elif isinstance(templates_or_waveform_extractor, WaveformExtractor): + templates_array = templates_or_waveform_extractor.get_all_templates(mode="average") + else: + raise ValueError("templates_or_waveform_extractor should be Templates or WaveformExtractor") + return templates_array + + def get_template_amplitudes( - waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" + templates_or_waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", ): """ Get amplitude per channel for each unit. Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + templates_or_waveform_extractor: Templates | WaveformExtractor + A Templates or a WaveformExtractor object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" @@ -29,15 +43,16 @@ def get_template_amplitudes( """ assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" - unit_ids = waveform_extractor.sorting.unit_ids - before = waveform_extractor.nbefore + unit_ids = templates_or_waveform_extractor.unit_ids + before = templates_or_waveform_extractor.nbefore peak_values = {} - templates = waveform_extractor.get_all_templates(mode="average") + templates_array = _get_dense_templates_array(templates_or_waveform_extractor) + for unit_ind, unit_id in enumerate(unit_ids): - template = templates[unit_ind, :, :] + template = templates_array[unit_ind, :, :] if mode == "extremum": if peak_sign == "both": @@ -60,7 +75,7 @@ def get_template_amplitudes( def get_template_extremum_channel( - waveform_extractor, + templates_or_waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id", @@ -70,8 +85,8 @@ def get_template_extremum_channel( Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + templates_or_waveform_extractor: Templates | WaveformExtractor + A Templates or a WaveformExtractor object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" @@ -91,10 +106,10 @@ def get_template_extremum_channel( assert mode in ("extremum", "at_index") assert outputs in ("id", "index") - unit_ids = waveform_extractor.sorting.unit_ids - channel_ids = waveform_extractor.channel_ids + unit_ids = templates_or_waveform_extractor.unit_ids + channel_ids = templates_or_waveform_extractor.channel_ids - peak_values = get_template_amplitudes(waveform_extractor, peak_sign=peak_sign, mode=mode) + peak_values = get_template_amplitudes(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) extremum_channels_id = {} extremum_channels_index = {} for unit_id in unit_ids: @@ -109,7 +124,7 @@ def get_template_extremum_channel( def get_template_channel_sparsity( - waveform_extractor, + templates_or_waveform_extractor, method="radius", peak_sign="neg", num_channels=5, @@ -119,22 +134,24 @@ def get_template_channel_sparsity( outputs="id", ): """ - Get channel sparsity (subset of channels) for each template with several methods. + Get channel sparsity (subset of channels) for each template with several methods. + + Parameters + ---------- + templates_or_waveform_extractor: Templates | WaveformExtractor + A Templates or a WaveformExtractor object - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor {} - outputs: str - * "id": channel id - * "index": channel index - - Returns - ------- - sparsity: dict - Dictionary with unit ids as keys and sparse channel ids or indices (id or index based on "outputs") - as values + + outputs: str + * "id": channel id + * "index": channel index + + Returns + ------- + sparsity: dict + Dictionary with unit ids as keys and sparse channel ids or indices (id or index based on "outputs") + as values """ from spikeinterface.core.sparsity import compute_sparsity @@ -146,7 +163,7 @@ def get_template_channel_sparsity( assert outputs in ("id", "index"), "'outputs' can either be 'id' or 'index'" sparsity = compute_sparsity( - waveform_extractor, + templates_or_waveform_extractor, method=method, peak_sign=peak_sign, num_channels=num_channels, @@ -165,7 +182,9 @@ def get_template_channel_sparsity( get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) -def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"): +def get_template_extremum_channel_peak_shift( + templates_or_waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg" +): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -173,8 +192,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + templates_or_waveform_extractor: Templates | WaveformExtractor + A Templates or a WaveformExtractor object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels @@ -183,19 +202,21 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg shifts: dict Dictionary with unit ids as keys and shifts as values """ - sorting = waveform_extractor.sorting - unit_ids = sorting.unit_ids + unit_ids = templates_or_waveform_extractor.unit_ids + channel_ids = templates_or_waveform_extractor.channel_ids + nbefore = templates_or_waveform_extractor.nbefore - extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) + extremum_channels_ids = get_template_extremum_channel(templates_or_waveform_extractor, peak_sign=peak_sign) shifts = {} - templates = waveform_extractor.get_all_templates(mode="average") + templates_array = _get_dense_templates_array(templates_or_waveform_extractor) + for unit_ind, unit_id in enumerate(unit_ids): - template = templates[unit_ind, :, :] + template = templates_array[unit_ind, :, :] chan_id = extremum_channels_ids[unit_id] - chan_ind = waveform_extractor.channel_ids_to_indices([chan_id])[0] + chan_ind = list(channel_ids).index(chan_id) if peak_sign == "both": peak_pos = np.argmax(np.abs(template[:, chan_ind])) @@ -203,22 +224,24 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg peak_pos = np.argmin(template[:, chan_ind]) elif peak_sign == "pos": peak_pos = np.argmax(template[:, chan_ind]) - shift = peak_pos - waveform_extractor.nbefore + shift = peak_pos - nbefore shifts[unit_id] = shift return shifts def get_template_extremum_amplitude( - waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index" + templates_or_waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "at_index", ): """ Computes amplitudes on the best channel. Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + templates_or_waveform_extractor: Templates | WaveformExtractor + A Templates or a WaveformExtractor object peak_sign: "neg" | "pos" | "both" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "at_index" @@ -233,18 +256,21 @@ def get_template_extremum_amplitude( """ assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" - unit_ids = waveform_extractor.sorting.unit_ids + unit_ids = templates_or_waveform_extractor.unit_ids + channel_ids = templates_or_waveform_extractor.channel_ids - before = waveform_extractor.nbefore + before = templates_or_waveform_extractor.nbefore - extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign, mode=mode) + extremum_channels_ids = get_template_extremum_channel( + templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode + ) - extremum_amplitudes = get_template_amplitudes(waveform_extractor, peak_sign=peak_sign, mode=mode) + extremum_amplitudes = get_template_amplitudes(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) unit_amplitudes = {} for unit_id in unit_ids: channel_id = extremum_channels_ids[unit_id] - best_channel = waveform_extractor.channel_ids_to_indices([channel_id])[0] + best_channel = list(channel_ids).index(channel_id) unit_amplitudes[unit_id] = extremum_amplitudes[unit_id][best_channel] return unit_amplitudes diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index ceaa8006ee..63e4e7d6b9 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -4,7 +4,8 @@ from spikeinterface.core import NumpySorting -from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains +from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains, random_spikes_selection @pytest.mark.skipif( @@ -20,5 +21,36 @@ def test_spike_vector_to_spike_trains(): assert np.array_equal(spike_trains[0][unit_id], sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)) +def test_random_spikes_selection(): + recording, sorting = generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=16000.0, + num_channels=10, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + max_spikes_per_unit = 12 + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + + random_spikes_indices = random_spikes_selection( + sorting, num_samples, method="uniform", max_spikes_per_unit=max_spikes_per_unit, margin_size=None, seed=2205 + ) + spikes = sorting.to_spike_vector() + some_spikes = spikes[random_spikes_indices] + for unit_index, unit_id in enumerate(sorting.unit_ids): + spike_slected_unit = some_spikes[some_spikes["unit_index"] == unit_index] + assert spike_slected_unit.size == max_spikes_per_unit + + # with margin + random_spikes_indices = random_spikes_selection( + sorting, num_samples, method="uniform", max_spikes_per_unit=max_spikes_per_unit, margin_size=25, seed=2205 + ) + # in that case the number is not garanty so it can be a bit less + assert random_spikes_indices.size >= (0.9 * sorting.unit_ids.size * max_spikes_per_unit) + + if __name__ == "__main__": - test_spike_vector_to_spike_trains() + # test_spike_vector_to_spike_trains() + test_random_spikes_selection() diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index ac114ac161..65d850ae1c 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,8 +3,9 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity +from spikeinterface.core import ChannelSparsity, estimate_sparsity from spikeinterface.core.core_tools import check_json +from spikeinterface.core import generate_ground_truth_recording def test_ChannelSparsity(): @@ -143,5 +144,50 @@ def test_densify_waveforms(): assert np.array_equal(template_sparse, template_sparse2) +def test_estimate_sparsity(): + num_units = 5 + recording, sorting = generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=16000.0, + num_channels=10, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=1.0, strategy="tile_pregenerated"), + seed=2205, + ) + + # small radius should give a very sparse = one channel per unit + sparsity = estimate_sparsity( + recording, + sorting, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="radius", + radius_um=1.0, + chunk_duration="1s", + progress_bar=True, + n_jobs=2, + ) + # print(sparsity) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units)) + + # best_channel : the mask should exactly 3 channels per units + sparsity = estimate_sparsity( + recording, + sorting, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="best_channels", + num_channels=3, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + + if __name__ == "__main__": test_ChannelSparsity() + test_estimate_sparsity() diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 17bfb81d49..eaa7712fcb 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -3,7 +3,7 @@ from pathlib import Path from spikeinterface import load_extractor, extract_waveforms, load_waveforms, generate_recording, generate_sorting - +from spikeinterface import Templates from spikeinterface.core import ( get_template_amplitudes, get_template_extremum_channel, @@ -36,22 +36,42 @@ def setup_module(): we = extract_waveforms(recording, sorting, cache_folder / "toy_waveforms") +def _get_templates_object_from_waveform_extractor(we): + templates = Templates( + templates_array=we.get_all_templates(mode="average"), + sampling_frequency=we.sampling_frequency, + nbefore=we.nbefore, + sparsity_mask=None, + channel_ids=we.channel_ids, + unit_ids=we.unit_ids, + ) + return templates + + def test_get_template_amplitudes(): we = load_waveforms(cache_folder / "toy_waveforms") peak_values = get_template_amplitudes(we) print(peak_values) + templates = _get_templates_object_from_waveform_extractor(we) + peak_values = get_template_amplitudes(templates) + print(peak_values) def test_get_template_extremum_channel(): we = load_waveforms(cache_folder / "toy_waveforms") extremum_channels_ids = get_template_extremum_channel(we, peak_sign="both") print(extremum_channels_ids) + templates = _get_templates_object_from_waveform_extractor(we) + extremum_channels_ids = get_template_extremum_channel(templates, peak_sign="both") + print(extremum_channels_ids) def test_get_template_extremum_channel_peak_shift(): we = load_waveforms(cache_folder / "toy_waveforms") shifts = get_template_extremum_channel_peak_shift(we, peak_sign="neg") print(shifts) + templates = _get_templates_object_from_waveform_extractor(we) + shifts = get_template_extremum_channel_peak_shift(templates, peak_sign="neg") # DEBUG # import matplotlib.pyplot as plt @@ -75,6 +95,9 @@ def test_get_template_extremum_amplitude(): extremum_channels_ids = get_template_extremum_amplitude(we, peak_sign="both") print(extremum_channels_ids) + templates = _get_templates_object_from_waveform_extractor(we) + extremum_channels_ids = get_template_extremum_amplitude(templates, peak_sign="both") + if __name__ == "__main__": setup_module() diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index e9cf1bfb5f..71d30495d8 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -5,11 +5,12 @@ import numpy as np -from spikeinterface.core import generate_recording, generate_sorting +from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, extract_waveforms_to_single_buffer, split_waveforms_by_units, + estimate_templates, ) @@ -26,22 +27,38 @@ def _check_all_wf_equal(list_wfs_arrays): assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) -def test_waveform_tools(): - durations = [30, 40] - sampling_frequency = 30000.0 - - # 2 segments - num_channels = 2 - recording = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency +def get_dataset(): + recording, sorting = generate_ground_truth_recording( + durations=[30.0, 40.0], + sampling_frequency=30000.0, + num_channels=4, + num_units=7, + generate_sorting_kwargs=dict(firing_rates=5.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=1.0, strategy="tile_pregenerated"), + seed=2205, ) - recording.annotate(is_filtered=True) - num_units = 15 - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) + return recording, sorting + + +def test_waveform_tools(): + # durations = [30, 40] + # sampling_frequency = 30000.0 + + # # 2 segments + # num_channels = 2 + # recording = generate_recording( + # num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency + # ) + # recording.annotate(is_filtered=True) + # num_units = 15 + # sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) # test with dump !!!! - recording = recording.save() - sorting = sorting.save() + # recording = recording.save() + # sorting = sorting.save() + + recording, sorting = get_dataset() + sampling_frequency = recording.sampling_frequency nbefore = int(3.0 * sampling_frequency / 1000.0) nafter = int(4.0 * sampling_frequency / 1000.0) @@ -145,5 +162,38 @@ def test_waveform_tools(): _check_all_wf_equal(list_wfs_sparse) +def test_estimate_templates(): + recording, sorting = get_dataset() + + ms_before = 1.0 + ms_after = 1.5 + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + spikes = sorting.to_spike_vector() + # take one spikes every 10 + spikes = spikes[::10] + + job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + templates = estimate_templates( + recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs + ) + print(templates.shape) + assert templates.shape[0] == sorting.unit_ids.size + assert templates.shape[1] == nbefore + nafter + assert templates.shape[2] == recording.get_num_channels() + + assert np.any(templates != 0) + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # ax.plot(templates[unit_index, :, :].T.flatten()) + # plt.show() + + if __name__ == "__main__": - test_waveform_tools() + # test_waveform_tools() + test_estimate_templates() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 87cfc0e9e6..e1c4f91c5b 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -12,7 +12,11 @@ from pathlib import Path import numpy as np +import multiprocessing +from spikeinterface.core.baserecording import BaseRecording + +from .baserecording import BaseRecording from .job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc from .core_tools import make_shared_array from .job_tools import fix_job_kwargs @@ -692,3 +696,187 @@ def has_exceeding_spikes(recording, sorting): if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: return True return False + + +def estimate_templates( + recording: BaseRecording, + spikes: np.ndarray, + unit_ids: list | np.ndarray, + nbefore: int, + nafter: int, + return_scaled: bool = True, + **job_kwargs, +): + """ + This is a fast implementation to compute average templates. + This is useful to estimate sparsity without the need to allocate large waveform buffers. + The mechanism is pretty simple: it accumulates and sums spike waveforms in-place per worker and per unit. + Note that std, median and percentiles can't be computed with this method. + + Parameters + ---------- + recording: BaseRecording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + Number of samples to cut out before a spike + nafter: int + Number of samples to cut out after a spike + return_scaled: bool, default: True + If True, the traces are scaled before averaging + + Returns + ------- + templates_array: np.array + The average templates with shape (num_units, nbefore + nafter, num_channels) + """ + + assert spikes.size > 0, "estimate_templates() need non empty sorting" + + job_kwargs = fix_job_kwargs(job_kwargs) + num_worker = job_kwargs["n_jobs"] + + num_chans = recording.get_num_channels() + num_units = len(unit_ids) + + shape = (num_worker, num_units, nbefore + nafter, num_chans) + dtype = np.dtype("float32") + waveforms_per_worker, shm = make_shared_array(shape, dtype) + shm_name = shm.name + + # trick to get the work_index given pid arrays + lock = multiprocessing.Lock() + array_pid = multiprocessing.Array("i", num_worker) + for i in range(num_worker): + array_pid[i] = -1 + + func = _worker_estimate_templates + init_func = _init_worker_estimate_templates + + init_args = ( + recording, + spikes, + shm_name, + shape, + dtype, + nbefore, + nafter, + return_scaled, + lock, + array_pid, + ) + + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name="estimate_templates", **job_kwargs + ) + processor.run() + + # average + templates_array = np.sum(waveforms_per_worker, axis=0) + unit_indices, spike_count = np.unique(spikes["unit_index"], return_counts=True) + templates_array[unit_indices, :, :] /= spike_count[:, np.newaxis, np.newaxis] + + # important : release the sharemem + del waveforms_per_worker + shm.unlink() + shm.close() + + return templates_array + + +def _init_worker_estimate_templates( + recording, + spikes, + shm_name, + shape, + dtype, + nbefore, + nafter, + return_scaled, + lock, + array_pid, +): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["spikes"] = spikes + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + + from multiprocessing.shared_memory import SharedMemory + import multiprocessing + + shm = SharedMemory(shm_name) + waveforms_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + worker_ctx["shm"] = shm + worker_ctx["waveforms_per_worker"] = waveforms_per_worker + + # prepare segment slices + segment_slices = [] + for segment_index in range(recording.get_num_segments()): + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) + segment_slices.append((s0, s1)) + worker_ctx["segment_slices"] = segment_slices + + child_process = multiprocessing.current_process() + + lock.acquire() + num_worker = None + for i in range(len(array_pid)): + if array_pid[i] == -1: + num_worker = i + array_pid[i] = child_process.ident + break + worker_ctx["worker_index"] = num_worker + lock.release() + + return worker_ctx + + +# used by ChunkRecordingExecutor +def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + segment_slices = worker_ctx["segment_slices"] + spikes = worker_ctx["spikes"] + nbefore = worker_ctx["nbefore"] + nafter = worker_ctx["nafter"] + waveforms_per_worker = worker_ctx["waveforms_per_worker"] + worker_index = worker_ctx["worker_index"] + return_scaled = worker_ctx["return_scaled"] + + seg_size = recording.get_num_samples(segment_index=segment_index) + + s0, s1 = segment_slices[segment_index] + in_seg_spikes = spikes[s0:s1] + + # take only spikes in range [start_frame, end_frame] + # this is a slice so no copy!! + # the border of segment are protected by nbefore on left an nafter on the right + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) + + # slice in absolut in spikes vector + l0 = i0 + s0 + l1 = i1 + s0 + + if l1 > l0: + start = spikes[l0]["sample_index"] - nbefore + end = spikes[l1 - 1]["sample_index"] + nafter + + # load trace in memory + traces = recording.get_traces( + start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled + ) + + for spike_index in range(l0, l1): + sample_index = spikes[spike_index]["sample_index"] + unit_index = spikes[spike_index]["unit_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + + waveforms_per_worker[worker_index, unit_index, :, :] += wf