Skip to content

Commit

Permalink
Add missing is_scaled
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed May 14, 2024
1 parent 4a8941b commit 08aa901
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ def generate_templates(
mode="ellipsoid",
):
"""
Generate some templates from the given channel positions and neuron position.s
Generate some templates from the given channel positions and neuron positions.
The implementation is very naive : it generates a mono channel waveform using generate_single_fake_waveform()
and duplicates this same waveform on all channel given a simple decay law per unit.
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/core/tests/test_template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
sparsity_mask=None,
channel_ids=sorting_analyzer.channel_ids,
unit_ids=sorting_analyzer.unit_ids,
is_scaled=sorting_analyzer.return_scaled,
)
return templates

Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def generate_drifting_recording(
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=True,
)

drifting_templates = DriftingTemplates.from_static(templates)
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/generation/tests/test_drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def make_some_templates():
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=True,
)

return templates
Expand Down
16 changes: 9 additions & 7 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from operator import is_

from .si_based import ComponentsBasedSorter

Expand Down Expand Up @@ -250,13 +251,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)

templates = Templates(
templates_array,
sampling_frequency,
nbefore,
None,
recording_w.channel_ids,
unit_ids,
recording_w.get_probe(),
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=None,
channel_ids=recording_w.channel_ids,
unit_ids=unit_ids,
probe=recording_w.get_probe(),
is_scaled=False,
)

sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=None,
probe=recording_w.get_probe(),
is_scaled=False,
)
# TODO : try other methods for sparsity
# sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret
channel_ids=recording.channel_ids,
unit_ids=gt_sorting.unit_ids,
probe=recording.get_probe(),
is_scaled=return_scaled,
)
return gt_templates

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,12 @@ def main_function(cls, recording, peaks, params):
**params["job_kwargs"],
)
templates = Templates(
templates_array=templates_array, sampling_frequency=fs, nbefore=nbefore, probe=recording.get_probe()
templates_array=templates_array,
sampling_frequency=fs,
nbefore=nbefore,
sparsity_mask=None,
probe=recording.get_probe(),
is_scaled=False,
)

labels, peak_labels = remove_duplicates_via_matching(
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,5 @@ def remove_empty_templates(templates):
channel_ids=templates.channel_ids,
unit_ids=templates.unit_ids[not_empty],
probe=templates.probe,
is_scaled=templates.is_scaled,
)

0 comments on commit 08aa901

Please sign in to comment.