Skip to content

Commit

Permalink
Merge pull request #2752 from zm711/template-tools
Browse files Browse the repository at this point in the history
Check for SortingAnalyzer `return_scaled` in template_tools
  • Loading branch information
alejoe91 authored Apr 29, 2024
2 parents c3bee78 + 7352ba7 commit c6f8f1f
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,23 @@ def get_template_extremum_channel(
Dictionary with unit ids as keys and extremum channels (id or index based on "outputs")
as values
"""
assert peak_sign in ("both", "neg", "pos")
assert mode in ("extremum", "at_index")
assert outputs in ("id", "index")
assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`"
assert mode in ("extremum", "at_index"), "`mode` must be either `extremum` or `at_index`"
assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`"

unit_ids = templates_or_sorting_analyzer.unit_ids
channel_ids = templates_or_sorting_analyzer.channel_ids

peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode)
# if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise
# we just use the previous default of return_scaled=True (for templates)
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = True

peak_values = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
)
extremum_channels_id = {}
extremum_channels_index = {}
for unit_id in unit_ids:
Expand Down Expand Up @@ -187,7 +196,14 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak

shifts = {}

templates_array = get_dense_templates_array(templates_or_sorting_analyzer)
# We need to use the SortingAnalyzer return_scaled if possible
# otherwise for Templates default to True
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = True

templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled)

for unit_ind, unit_id in enumerate(unit_ids):
template = templates_array[unit_ind, :, :]
Expand Down Expand Up @@ -238,7 +254,14 @@ def get_template_extremum_amplitude(

extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode)

extremum_amplitudes = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode)
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
return_scaled = templates_or_sorting_analyzer.return_scaled
else:
return_scaled = True

extremum_amplitudes = get_template_amplitudes(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
)

unit_amplitudes = {}
for unit_id in unit_ids:
Expand Down

0 comments on commit c6f8f1f

Please sign in to comment.