diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ec76fcbaa9..05c1ebc7ed 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1448,7 +1448,7 @@ def generate_templates( mode="ellipsoid", ): """ - Generate some templates from the given channel positions and neuron position.s + Generate some templates from the given channel positions and neuron positions. The implementation is very naive : it generates a mono channel waveform using generate_single_fake_waveform() and duplicates this same waveform on all channel given a simple decay law per unit. diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index f79c830db6..6ef8267742 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -47,6 +47,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer): sparsity_mask=None, channel_ids=sorting_analyzer.channel_ids, unit_ids=sorting_analyzer.unit_ids, + is_scaled=sorting_analyzer.return_scaled, ) return templates diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 8a658cd97d..7f617c3ade 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -404,6 +404,7 @@ def generate_drifting_recording( sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, + is_scaled=True, ) drifting_templates = DriftingTemplates.from_static(templates) diff --git a/src/spikeinterface/generation/tests/test_drift_tools.py b/src/spikeinterface/generation/tests/test_drift_tools.py index ab03b30d82..e64e64ffda 100644 --- a/src/spikeinterface/generation/tests/test_drift_tools.py +++ b/src/spikeinterface/generation/tests/test_drift_tools.py @@ -73,6 +73,7 @@ def make_some_templates(): sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, + is_scaled=True, ) return templates diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ba6870eef2..6575aba15e 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -1,4 +1,5 @@ from __future__ import annotations +from operator import is_ from .si_based import ComponentsBasedSorter @@ -250,13 +251,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) templates = Templates( - templates_array, - sampling_frequency, - nbefore, - None, - recording_w.channel_ids, - unit_ids, - recording_w.get_probe(), + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording_w.channel_ids, + unit_ids=unit_ids, + probe=recording_w.get_probe(), + is_scaled=False, ) sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"]) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index c2b9f4cfc7..e07924b196 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -191,7 +191,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, + sparsity_mask=None, probe=recording_w.get_probe(), + is_scaled=False, ) # TODO : try other methods for sparsity # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py index 3401e36dd0..313f19537e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py @@ -77,6 +77,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret channel_ids=recording.channel_ids, unit_ids=gt_sorting.unit_ids, probe=recording.get_probe(), + is_scaled=return_scaled, ) return gt_templates diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index d24af3c175..a07a6140e1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -184,7 +184,12 @@ def main_function(cls, recording, peaks, params): **params["job_kwargs"], ) templates = Templates( - templates_array=templates_array, sampling_frequency=fs, nbefore=nbefore, probe=recording.get_probe() + templates_array=templates_array, + sampling_frequency=fs, + nbefore=nbefore, + sparsity_mask=None, + probe=recording.get_probe(), + is_scaled=False, ) labels, peak_labels = remove_duplicates_via_matching( diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 06dfd994f3..cf0d22c0c8 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -137,4 +137,5 @@ def remove_empty_templates(templates): channel_ids=templates.channel_ids, unit_ids=templates.unit_ids[not_empty], probe=templates.probe, + is_scaled=templates.is_scaled, )