Skip to content

Commit

Permalink
Merge pull request #2885 from alejoe91/peak-to-peak-extremum
Browse files Browse the repository at this point in the history
Add peak_to_peak mode to get_templates_amplitude
  • Loading branch information
samuelgarcia authored May 21, 2024
2 parents 641be70 + 9331d86 commit 606a1cf
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
51 changes: 29 additions & 22 deletions src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_nbefore(one_object):
def get_template_amplitudes(
templates_or_sorting_analyzer,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" = "extremum",
mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
return_scaled: bool = True,
abs_value: bool = True,
):
Expand All @@ -67,11 +67,13 @@ def get_template_amplitudes(
----------
templates_or_sorting_analyzer: Templates | SortingAnalyzer
A Templates or a SortingAnalyzer object
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "extremum"
"extremum": max or min
"at_index": take value at spike index
peak_sign: "neg" | "pos" | "both"
Sign of the template to find extremum channels
mode: "extremum" | "at_index" | "peak_to_peak", default: "at_index"
Where the amplitude is computed
* "extremum": take the peak value (max or min depending on `peak_sign`)
* "at_index": take value at `nbefore` index
* "peak_to_peak": take the peak-to-peak amplitude
return_scaled: bool, default True
The amplitude is scaled or not.
abs_value: bool = True
Expand All @@ -83,7 +85,7 @@ def get_template_amplitudes(
Dictionary with unit ids as keys and template amplitudes as values
"""
assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'"
assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'"

unit_ids = templates_or_sorting_analyzer.unit_ids
before = _get_nbefore(templates_or_sorting_analyzer)
Expand All @@ -107,6 +109,8 @@ def get_template_amplitudes(
values = np.abs(template[before, :])
elif peak_sign in ["neg", "pos"]:
values = template[before, :]
elif mode == "peak_to_peak":
values = np.ptp(template, axis=0)

if abs_value:
values = np.abs(values)
Expand All @@ -119,7 +123,7 @@ def get_template_amplitudes(
def get_template_extremum_channel(
templates_or_sorting_analyzer,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" = "extremum",
mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
outputs: "id" | "index" = "id",
):
"""
Expand All @@ -129,11 +133,13 @@ def get_template_extremum_channel(
----------
templates_or_sorting_analyzer: Templates | SortingAnalyzer
A Templates or a SortingAnalyzer object
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "extremum"
"extremum": max or min
"at_index": take value at spike index
peak_sign: "neg" | "pos" | "both"
Sign of the template to find extremum channels
mode: "extremum" | "at_index" | "peak_to_peak", default: "at_index"
Where the amplitude is computed
* "extremum": take the peak value (max or min depending on `peak_sign`)
* "at_index": take value at `nbefore` index
* "peak_to_peak": take the peak-to-peak amplitude
outputs: "id" | "index", default: "id"
* "id": channel id
* "index": channel index
Expand All @@ -145,7 +151,7 @@ def get_template_extremum_channel(
as values
"""
assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`"
assert mode in ("extremum", "at_index"), "`mode` must be either `extremum` or `at_index`"
assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'"
assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`"

unit_ids = templates_or_sorting_analyzer.unit_ids
Expand Down Expand Up @@ -184,8 +190,8 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak
----------
templates_or_sorting_analyzer: Templates | SortingAnalyzer
A Templates or a SortingAnalyzer object
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
peak_sign: "neg" | "pos" | "both"
Sign of the template to find extremum channels
Returns
-------
Expand Down Expand Up @@ -230,7 +236,7 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak
def get_template_extremum_amplitude(
templates_or_sorting_analyzer,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" = "at_index",
mode: "extremum" | "at_index" | "peak_to_peak" = "at_index",
abs_value: bool = True,
):
"""
Expand All @@ -241,11 +247,12 @@ def get_template_extremum_amplitude(
templates_or_sorting_analyzer: Templates | SortingAnalyzer
A Templates or a SortingAnalyzer object
peak_sign: "neg" | "pos" | "both"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "at_index"
Sign of the template to find extremum channels
mode: "extremum" | "at_index" | "peak_to_peak", default: "at_index"
Where the amplitude is computed
"extremum": max or min
"at_index": take value at spike index
* "extremum": take the peak value (max or min depending on `peak_sign`)
* "at_index": take value at `nbefore` index
* "peak_to_peak": take the peak-to-peak amplitude
abs_value: bool = True
Whether the extremum amplitude should be returned as an absolute value or not
Expand All @@ -256,7 +263,7 @@ def get_template_extremum_amplitude(
Dictionary with unit ids as keys and amplitudes as values
"""
assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'"
assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'"
unit_ids = templates_or_sorting_analyzer.unit_ids
channel_ids = templates_or_sorting_analyzer.channel_ids

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/tests/test_template_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy as np

from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

Expand Down Expand Up @@ -54,10 +55,10 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):

def test_get_template_amplitudes(sorting_analyzer):
peak_values = get_template_amplitudes(sorting_analyzer)
print(peak_values)
templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer)
peak_values = get_template_amplitudes(templates)
print(peak_values)
peak_values = get_template_amplitudes(templates, abs_value=True)
peak_to_peak_values = get_template_amplitudes(templates, mode="peak_to_peak")
assert np.all(ptp > p for ptp, p in zip(peak_to_peak_values.values(), peak_values.values()))


def test_get_template_extremum_channel(sorting_analyzer):
Expand Down

0 comments on commit 606a1cf

Please sign in to comment.