diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 243c854bba..ce86f85e73 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -20,6 +20,7 @@ from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection +from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates @@ -125,11 +126,13 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): radius_um=radius_um, ) - node2 = TemporalPCAProjection( - recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder + node2 = HanningFilter(recording, parents=[node0, node1], return_output=False) + + node3 = TemporalPCAProjection( + recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder ) - pipeline_nodes = [node0, node1, node2] + pipeline_nodes = [node0, node1, node2, node3] if len(params["recursive_kwargs"]) == 0: from sklearn.decomposition import PCA diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_hanning_filter.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_hanning_filter.py new file mode 100644 index 0000000000..1b006af429 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_hanning_filter.py @@ -0,0 +1,33 @@ +import pytest + + +from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter + +from spikeinterface.core.node_pipeline import ( + PeakRetriever, + ExtractDenseWaveforms, + run_node_pipeline, +) + + +def test_hanning_filter(generated_recording, detected_peaks, chunk_executor_kwargs): + recording = generated_recording + peaks = detected_peaks + + # Parameters + ms_before = 1.0 + ms_after = 1.0 + + # Node initialization + peak_retriever = PeakRetriever(recording, peaks) + + extract_waveforms = ExtractDenseWaveforms( + recording=recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=True + ) + + hanning_filter = HanningFilter(recording=recording, parents=[peak_retriever, extract_waveforms]) + pipeline_nodes = [peak_retriever, extract_waveforms, hanning_filter] + + # Extract projected waveforms and compare + waveforms, denoised_waveforms = run_node_pipeline(recording, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs) + assert waveforms.shape == denoised_waveforms.shape diff --git a/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py new file mode 100644 index 0000000000..cf60419855 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py @@ -0,0 +1,47 @@ +from __future__ import annotations + + +from typing import List, Optional +import numpy as np +from spikeinterface.core import BaseRecording +from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type + + +class HanningFilter(WaveformsNode): + """ + Hanning Filtering to remove border effects while extracting waveforms + + Parameters + ---------- + recording: BaseRecording + The recording extractor object + return_output: bool, default: True + Whether to return output from this node + parents: list of PipelineNodes, default: None + The parent nodes of this node + """ + + def __init__( + self, + recording: BaseRecording, + return_output: bool = True, + parents: Optional[List[PipelineNode]] = None, + ): + waveform_extractor = find_parent_of_type(parents, WaveformsNode) + if waveform_extractor is None: + raise TypeError(f"HanningFilter should have a single {WaveformsNode.__name__} in its parents") + + super().__init__( + recording, + waveform_extractor.ms_before, + waveform_extractor.ms_after, + return_output=return_output, + parents=parents, + ) + + self.hanning = np.hanning(self.nbefore + self.nafter)[:, None] + self._kwargs.update(dict()) + + def compute(self, traces, peaks, waveforms): + denoised_waveforms = waveforms * self.hanning + return denoised_waveforms