diff --git a/pyproject.toml b/pyproject.toml
index 8fc4c449f7..cb71bb499b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,9 +23,10 @@ dependencies = [
     "numpy",
     "threadpoolctl>=3.0.0",
     "tqdm",
-    "zarr>=0.2.16",
+    "zarr>=2.16,<2.18",
     "neo>=0.13.0",
     "probeinterface>=0.2.21",
+    "packaging",
 ]
 
 [build-system]
diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 3a4561822e..e5691603ac 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -499,6 +499,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
                 channel_ids=self.sorting_analyzer.channel_ids,
                 unit_ids=unit_ids,
                 probe=self.sorting_analyzer.get_probe(),
+                is_scaled=self.sorting_analyzer.return_scaled,
             )
         else:
             raise ValueError("outputs must be numpy or Templates")
diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py
index 9ec7ca8d4d..b8735dff3c 100644
--- a/src/spikeinterface/core/channelsaggregationrecording.py
+++ b/src/spikeinterface/core/channelsaggregationrecording.py
@@ -14,18 +14,31 @@ class ChannelsAggregationRecording(BaseRecording):
     """
 
     def __init__(self, recording_list, renamed_channel_ids=None):
+
         channel_map = {}
+        num_all_channels = sum(rec.get_num_channels() for rec in recording_list)
 
-        num_all_channels = sum([rec.get_num_channels() for rec in recording_list])
         if renamed_channel_ids is not None:
-            assert len(np.unique(renamed_channel_ids)) == num_all_channels, (
-                "'renamed_channel_ids' doesn't have the " "right size or has duplicates!"
-            )
+            assert (
+                len(np.unique(renamed_channel_ids)) == num_all_channels
+            ), "'renamed_channel_ids' doesn't have the right size or has duplicates!"
            channel_ids = list(renamed_channel_ids)
         else:
-            channel_ids = list(np.arange(num_all_channels))
+            # Collect channel IDs from all recordings
+            all_channels_have_same_type = np.unique([rec.channel_ids.dtype for rec in recording_list]).size == 1
+            all_channel_ids_are_unique = False
+            if all_channels_have_same_type:
+                combined_ids = np.concatenate([rec.channel_ids for rec in recording_list])
+                all_channel_ids_are_unique = np.unique(combined_ids).size == num_all_channels
+
+            if all_channels_have_same_type and all_channel_ids_are_unique:
+                channel_ids = combined_ids
+            else:
+                # IDs are not unique across recordings or not all the same dtype:
+                # fall back to default ids, i.e. unique, consecutive integers rendered as strings.
+                default_channel_ids = [str(i) for i in range(num_all_channels)]
+                channel_ids = default_channel_ids
 
-        # channel map maps channel indices that are used to get traces
         ch_id = 0
         for r_i, recording in enumerate(recording_list):
             single_channel_ids = recording.get_channel_ids()
@@ -49,7 +62,9 @@ def __init__(self, recording_list, renamed_channel_ids=None):
                     break
 
         if not (ok1 and ok2 and ok3 and ok4):
-            raise ValueError("Sortings don't have the same sampling_frequency/num_segments/dtype/num samples")
+            raise ValueError(
+                "Recordings do not have consistent sampling frequency, number of segments, data type, or number of samples."
+            )
 
         BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)
 
@@ -91,7 +106,7 @@ def __init__(self, recording_list, renamed_channel_ids=None):
             self.add_recording_segment(sub_segment)
 
         self._recordings = recording_list
-        self._kwargs = {"recording_list": [rec for rec in recording_list], "renamed_channel_ids": renamed_channel_ids}
+        self._kwargs = {"recording_list": recording_list, "renamed_channel_ids": renamed_channel_ids}
 
     @property
     def recordings(self):
@@ -173,11 +188,12 @@ def aggregate_channels(recording_list, renamed_channel_ids=None):
     recording_list: list
         List of BaseRecording objects to aggregate
     renamed_channel_ids: array-like
-        If given, channel ids are renamed as provided. If None, unit ids are sequential integers.
+        If given, channel ids are renamed as provided. If None, the original channel ids are kept when they are
+        unique across recordings and share a dtype; otherwise consecutive integers rendered as strings are used.
 
     Returns
     -------
-    aggregate_recording: UnitsAggregationSorting
-        The aggregated sorting object
+    aggregate_recording: ChannelsAggregationRecording
+        The aggregated recording object
     """
     return ChannelsAggregationRecording(recording_list, renamed_channel_ids)
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py
index ec76fcbaa9..05c1ebc7ed 100644
--- a/src/spikeinterface/core/generate.py
+++ b/src/spikeinterface/core/generate.py
@@ -1448,7 +1448,7 @@ def generate_templates(
     mode="ellipsoid",
 ):
     """
-    Generate some templates from the given channel positions and neuron position.s
+    Generate some templates from the given channel positions and neuron positions.
 
     The implementation is very naive : it generates a mono channel waveform using generate_single_fake_waveform()
     and duplicates this same waveform on all channel given a simple decay law per unit.
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 3ce20a0209..d57df2b5ae 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -195,7 +195,7 @@ def __init__(
     ):
         # very fast init because checks are done in load and create
         self.sorting = sorting
-        # self.recorsding will be a property
+        # self.recording will be a property
         self._recording = recording
         self.rec_attributes = rec_attributes
         self.format = format
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 51688709b2..4eb82be2d6 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -30,9 +30,11 @@ class Templates:
         Array of unit IDs. If `None`, defaults to an array of increasing integers.
     probe: Probe, default: None
         A `probeinterface.Probe` object
+    is_scaled : bool, optional, default: True
+        If True, the templates are in uV; otherwise they are in raw ADC values.
     check_for_consistent_sparsity : bool, optional default: None
         When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
-        structure fo the sparsity_masl.
+        structure of the sparsity_mask. If False, this check is skipped.
 
     The following attributes are available after construction:
 
@@ -58,6 +60,7 @@ class Templates:
     templates_array: np.ndarray
     sampling_frequency: float
     nbefore: int
+    is_scaled: bool = True
 
     sparsity_mask: np.ndarray = None
     channel_ids: np.ndarray = None
@@ -193,6 +196,7 @@ def to_dict(self):
             "unit_ids": self.unit_ids,
             "sampling_frequency": self.sampling_frequency,
             "nbefore": self.nbefore,
+            "is_scaled": self.is_scaled,
             "probe": self.probe.to_dict() if self.probe is not None else None,
         }
 
@@ -205,6 +209,7 @@ def from_dict(cls, data):
             unit_ids=np.asarray(data["unit_ids"]),
             sampling_frequency=data["sampling_frequency"],
             nbefore=data["nbefore"],
+            is_scaled=data["is_scaled"],
             probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
         )
 
@@ -238,6 +243,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:
 
         zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
         zarr_group.attrs["nbefore"] = self.nbefore
+        zarr_group.attrs["is_scaled"] = self.is_scaled
 
         if self.sparsity_mask is not None:
             zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py
index 11581a49ce..fead805ab5 100644
--- a/src/spikeinterface/core/template_tools.py
+++ b/src/spikeinterface/core/template_tools.py
@@ -58,6 +58,7 @@ def get_template_amplitudes(
     peak_sign: "neg" | "pos" | "both" = "neg",
     mode: "extremum" | "at_index" = "extremum",
     return_scaled: bool = True,
+    abs_value: bool = True,
 ):
     """
     Get amplitude per channel for each unit.
@@ -73,6 +74,8 @@ def get_template_amplitudes(
         "at_index": take value at spike index
     return_scaled: bool, default True
         The amplitude is scaled or not.
+    abs_value: bool, default: True
+        Whether to return the amplitude as an absolute value
 
     Returns
     -------
@@ -96,17 +99,18 @@ def get_template_amplitudes(
             if peak_sign == "both":
                 values = np.max(np.abs(template), axis=0)
             elif peak_sign == "neg":
-                values = -np.min(template, axis=0)
+                values = np.min(template, axis=0)
             elif peak_sign == "pos":
                 values = np.max(template, axis=0)
         elif mode == "at_index":
             if peak_sign == "both":
                 values = np.abs(template[before, :])
-            elif peak_sign == "neg":
-                values = -template[before, :]
-            elif peak_sign == "pos":
+            elif peak_sign in ["neg", "pos"]:
                 values = template[before, :]
 
+        if abs_value:
+            values = np.abs(values)
+
         peak_values[unit_id] = values
 
     return peak_values
@@ -160,7 +164,7 @@ def get_template_extremum_channel(
     extremum_channels_id = {}
     extremum_channels_index = {}
     for unit_id in unit_ids:
-        max_ind = np.argmax(peak_values[unit_id])
+        max_ind = np.argmax(np.abs(peak_values[unit_id]))
         extremum_channels_id[unit_id] = channel_ids[max_ind]
         extremum_channels_index[unit_id] = max_ind
 
@@ -227,6 +231,7 @@ def get_template_extremum_amplitude(
     templates_or_sorting_analyzer,
     peak_sign: "neg" | "pos" | "both" = "neg",
     mode: "extremum" | "at_index" = "at_index",
+    abs_value: bool = True,
 ):
     """
     Computes amplitudes on the best channel.
@@ -241,6 +246,9 @@ def get_template_extremum_amplitude(
         Where the amplitude is computed
         "extremum": max or min
         "at_index": take value at spike index
+    abs_value: bool, default: True
+        Whether to return the amplitude as an absolute value
+
 
     Returns
     -------
@@ -260,7 +268,7 @@ def get_template_extremum_amplitude(
         return_scaled = True
 
     extremum_amplitudes = get_template_amplitudes(
-        templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
+        templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled, abs_value=abs_value
     )
 
     unit_amplitudes = {}
diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py
index 944c3aa0b2..118b6092a9 100644
--- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py
+++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py
@@ -23,7 +23,6 @@ def test_channelsaggregationrecording():
 
     # test num channels
     recording_agg = aggregate_channels([recording1, recording2, recording3])
-    print(recording_agg)
     assert len(recording_agg.get_channel_ids()) == 3 * num_channels
 
     assert np.allclose(recording_agg.get_times(0), recording1.get_times(0))
@@ -37,21 +36,21 @@ def test_channelsaggregationrecording():
         traces2_0 = recording2.get_traces(channel_ids=[channel_ids[0]], segment_index=seg)
         traces3_2 = recording3.get_traces(channel_ids=[channel_ids[2]], segment_index=seg)
 
-        assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[channel_ids[1]], segment_index=seg))
+        assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
         assert np.allclose(
-            traces2_0, recording_agg.get_traces(channel_ids=[num_channels + channel_ids[0]], segment_index=seg)
+            traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
         )
         assert np.allclose(
-            traces3_2, recording_agg.get_traces(channel_ids=[2 * num_channels + channel_ids[2]], segment_index=seg)
+            traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
         )
 
         # all traces
         traces1 = recording1.get_traces(segment_index=seg)
         traces2 = recording2.get_traces(segment_index=seg)
         traces3 = recording3.get_traces(segment_index=seg)
-        assert np.allclose(traces1, recording_agg.get_traces(channel_ids=[0, 1, 2], segment_index=seg))
-        assert np.allclose(traces2, recording_agg.get_traces(channel_ids=[3, 4, 5], segment_index=seg))
-        assert np.allclose(traces3, recording_agg.get_traces(channel_ids=[6, 7, 8], segment_index=seg))
+        assert np.allclose(traces1, recording_agg.get_traces(channel_ids=["0", "1", "2"], segment_index=seg))
+        assert np.allclose(traces2, recording_agg.get_traces(channel_ids=["3", "4", "5"], segment_index=seg))
+        assert np.allclose(traces3, recording_agg.get_traces(channel_ids=["6", "7", "8"], segment_index=seg))
 
         # test rename channels
         renamed_channel_ids = [f"#Channel {i}" for i in range(3 * num_channels)]
@@ -81,5 +80,41 @@ def test_channelsaggregationrecording():
     print(recording_agg_prop.get_property("brain_area"))
 
 
+def test_channel_aggregation_preserve_ids():
+
+    recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
+    recording1 = recording1.rename_channels(new_channel_ids=["a", "b", "c"])
+    recording2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
+    recording2 = recording2.rename_channels(new_channel_ids=["d", "e"])
+
+    aggregated_recording = aggregate_channels([recording1, recording2])
+    assert aggregated_recording.get_num_channels() == 5
+    assert list(aggregated_recording.get_channel_ids()) == ["a", "b", "c", "d", "e"]
+
+
+def test_channel_aggregation_does_not_preserve_ids_if_not_unique():
+
+    recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
+    recording1 = recording1.rename_channels(new_channel_ids=["a", "b", "c"])
+    recording2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
+    recording2 = recording2.rename_channels(new_channel_ids=["a", "b"])
+
+    aggregated_recording = aggregate_channels([recording1, recording2])
+    assert aggregated_recording.get_num_channels() == 5
+    assert list(aggregated_recording.get_channel_ids()) == ["0", "1", "2", "3", "4"]
+
+
+def test_channel_aggregation_does_not_preserve_ids_not_the_same_type():
+
+    recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
+    recording1 = recording1.rename_channels(new_channel_ids=["a", "b", "c"])
+    recording2 = generate_recording(num_channels=2, durations=[10], set_probe=False)
+    recording2 = recording2.rename_channels(new_channel_ids=[1, 2])
+
+    aggregated_recording = aggregate_channels([recording1, recording2])
+    assert aggregated_recording.get_num_channels() == 5
+    assert list(aggregated_recording.get_channel_ids()) == ["0", "1", "2", "3", "4"]
+
+
 if __name__ == "__main__":
     test_channelsaggregationrecording()
diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py
index f79c830db6..6ef8267742 100644
--- a/src/spikeinterface/core/tests/test_template_tools.py
+++ b/src/spikeinterface/core/tests/test_template_tools.py
@@ -47,6 +47,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
         sparsity_mask=None,
         channel_ids=sorting_analyzer.channel_ids,
         unit_ids=sorting_analyzer.unit_ids,
+        is_scaled=sorting_analyzer.return_scaled,
     )
     return templates
 
diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py
index e835edf961..7d6b492325 100644
--- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py
+++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
-
 from pathlib import Path
 
-import probeinterface
+import packaging
+import packaging.version
+import probeinterface
 
 from spikeinterface.core.core_tools import define_function_from_class
 from .neobaseextractor import NeoBaseRecordingExtractor
@@ -38,7 +39,9 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None
         )
         self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id))
 
-        probegroup = probeinterface.read_spikegadgets(file_path, raise_error=False)
+        probegroup = None  # TODO remove once probeinterface is updated to 0.2.22 in the pyproject.toml
+        if packaging.version.parse(probeinterface.__version__) > packaging.version.parse("0.2.21"):
+            probegroup = probeinterface.read_spikegadgets(file_path, raise_error=False)
 
         if probegroup is not None:
             self.set_probes(probegroup, in_place=True)
diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py
index 8a658cd97d..7f617c3ade 100644
--- a/src/spikeinterface/generation/drifting_generator.py
+++ b/src/spikeinterface/generation/drifting_generator.py
@@ -404,6 +404,7 @@ def generate_drifting_recording(
         sampling_frequency=sampling_frequency,
         nbefore=nbefore,
         probe=probe,
+        is_scaled=True,
     )
 
     drifting_templates = DriftingTemplates.from_static(templates)
diff --git a/src/spikeinterface/generation/tests/test_drift_tools.py b/src/spikeinterface/generation/tests/test_drift_tools.py
index ab03b30d82..e64e64ffda 100644
--- a/src/spikeinterface/generation/tests/test_drift_tools.py
+++ b/src/spikeinterface/generation/tests/test_drift_tools.py
@@ -73,6 +73,7 @@ def make_some_templates():
         sampling_frequency=sampling_frequency,
         nbefore=nbefore,
         probe=probe,
+        is_scaled=True,
     )
 
     return templates
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index ba6870eef2..6575aba15e 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -250,13 +250,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         )
 
         templates = Templates(
-            templates_array,
-            sampling_frequency,
-            nbefore,
-            None,
-            recording_w.channel_ids,
-            unit_ids,
-            recording_w.get_probe(),
+            templates_array=templates_array,
+            sampling_frequency=sampling_frequency,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            channel_ids=recording_w.channel_ids,
+            unit_ids=unit_ids,
+            probe=recording_w.get_probe(),
+            is_scaled=False,
         )
 
         sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 6ac4d63e7e..c47cbb376d 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -220,7 +220,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             templates_array=templates_array,
             sampling_frequency=sampling_frequency,
             nbefore=nbefore,
+            sparsity_mask=None,
             probe=recording_for_peeler.get_probe(),
+            is_scaled=False,
         )
 
         # TODO : try other methods for sparsity
diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py
index 3401e36dd0..313f19537e 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py
@@ -77,6 +77,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret
         channel_ids=recording.channel_ids,
         unit_ids=gt_sorting.unit_ids,
         probe=recording.get_probe(),
+        is_scaled=return_scaled,
     )
     return gt_templates
 
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 4135bd4b6e..ce7a78e4c6 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -226,7 +226,14 @@ def main_function(cls, recording, peaks, params):
         )
 
         templates = Templates(
-            templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()
+            templates_array=templates_array,
+            sampling_frequency=fs,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            channel_ids=recording.channel_ids,
+            unit_ids=unit_ids,
+            probe=recording.get_probe(),
+            is_scaled=False,
         )
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
index d24af3c175..a07a6140e1 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
@@ -184,7 +184,12 @@ def main_function(cls, recording, peaks, params):
             **params["job_kwargs"],
         )
         templates = Templates(
-            templates_array=templates_array, sampling_frequency=fs, nbefore=nbefore, probe=recording.get_probe()
+            templates_array=templates_array,
+            sampling_frequency=fs,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            probe=recording.get_probe(),
+            is_scaled=False,
         )
 
         labels, peak_labels = remove_duplicates_via_matching(
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index efd63be55f..6c1ad75383 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -137,7 +137,14 @@ def main_function(cls, recording, peaks, params):
         )
 
         templates = Templates(
-            templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()
+            templates_array=templates_array,
+            sampling_frequency=fs,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            channel_ids=recording.channel_ids,
+            unit_ids=unit_ids,
+            probe=recording.get_probe(),
+            is_scaled=False,
        )
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 06dfd994f3..cf0d22c0c8 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -137,4 +137,5 @@ def remove_empty_templates(templates):
         channel_ids=templates.channel_ids,
         unit_ids=templates.unit_ids[not_empty],
        probe=templates.probe,
+        is_scaled=templates.is_scaled,
     )
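
A minimal sketch of the new channel-id behavior of aggregate_channels, lifted from the tests added above: ids are preserved only when they are unique across recordings and share a dtype, otherwise they fall back to stringified consecutive integers.

    from spikeinterface.core import aggregate_channels, generate_recording

    rec1 = generate_recording(num_channels=3, durations=[10], set_probe=False)
    rec1 = rec1.rename_channels(new_channel_ids=["a", "b", "c"])
    rec2 = generate_recording(num_channels=2, durations=[10], set_probe=False)

    # Unique ids of the same dtype are preserved by the aggregate
    rec2 = rec2.rename_channels(new_channel_ids=["d", "e"])
    agg = aggregate_channels([rec1, rec2])
    assert list(agg.get_channel_ids()) == ["a", "b", "c", "d", "e"]

    # Duplicated ids trigger the stringified default ids instead
    rec2 = rec2.rename_channels(new_channel_ids=["a", "b"])
    agg = aggregate_channels([rec1, rec2])
    assert list(agg.get_channel_ids()) == ["0", "1", "2", "3", "4"]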
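
The new is_scaled field on Templates travels through both serialization paths touched in this patch (to_dict/from_dict and the zarr group attrs). A small sketch, assuming the dataclass defaults fill in channel and unit ids when they are omitted:

    import numpy as np
    from spikeinterface.core import Templates

    rng = np.random.default_rng(0)
    templates = Templates(
        templates_array=rng.normal(size=(2, 40, 4)).astype("float32"),  # (units, samples, channels)
        sampling_frequency=30_000.0,
        nbefore=20,
        is_scaled=False,  # raw ADC values, as the internal sorters above now declare
    )
    # The flag round-trips through the dict representation
    assert Templates.from_dict(templates.to_dict()).is_scaled is False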
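
And a sketch of the reworked amplitude helpers: get_template_amplitudes now keeps signed peak values internally and only takes the absolute value when abs_value=True (the default), while get_template_extremum_channel picks the best channel by |amplitude| either way. Continuing from the templates object above, and matching return_scaled to its is_scaled flag to avoid any scaling mismatch:

    from spikeinterface.core.template_tools import get_template_amplitudes

    signed = get_template_amplitudes(
        templates, peak_sign="neg", return_scaled=False, abs_value=False
    )
    unsigned = get_template_amplitudes(templates, peak_sign="neg", return_scaled=False)
    for unit_id in templates.unit_ids:
        assert np.allclose(np.abs(signed[unit_id]), unsigned[unit_id])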