Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 9, 2025
1 parent fb9747d commit 6b3aba1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
8 changes: 6 additions & 2 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(clustering_folder / "prototype.npy", prototype)
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(recording_w, seed=params["seed"], **job_kwargs)
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs)
else:
waveforms = None
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(recording_w, seed=params["seed"], **job_kwargs)
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)

if verbose:
Expand Down
23 changes: 17 additions & 6 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j
return all_wfs


def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs):
def get_prototype_and_waveforms_from_peaks(
recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs
):
"""
Function to extract a prototype waveform from peaks.
Expand Down Expand Up @@ -100,6 +102,7 @@ def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_be
The selected peaks used to extract waveforms.
"""
from spikeinterface.sortingcomponents.peak_selection import select_peaks

_, job_kwargs = split_job_kwargs(all_kwargs)

nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
Expand All @@ -118,7 +121,9 @@ def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_be
return prototype, waveforms[:, :, 0], few_peaks


def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs):
def get_prototype_and_waveforms_from_recording(
recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs
):
"""
Function to extract a prototype waveform from peaks detected on the fly.
Expand Down Expand Up @@ -148,6 +153,7 @@ def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_befor
"""
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import ExtractSparseWaveforms

detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs)

nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
Expand All @@ -159,7 +165,7 @@ def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_befor
ms_after=ms_after,
radius_um=0,
)

pipeline_nodes = [node]

recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs)
Expand All @@ -172,7 +178,7 @@ def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_befor
**detection_kwargs,
**job_kwargs,
)

rng = np.random.RandomState(seed)
indices = rng.permutation(np.arange(len(res[0])))

Expand Down Expand Up @@ -219,9 +225,14 @@ def get_prototype_and_waveforms(
The extracted waveforms, returned if return_waveforms is True.
"""
if peaks is None:
return get_prototype_and_waveforms_from_peaks(recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs)
return get_prototype_and_waveforms_from_peaks(
recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs
)
else:
return get_prototype_and_waveforms_from_recording(recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs)
return get_prototype_and_waveforms_from_recording(
recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs
)


def check_probe_for_drift_correction(recording, dist_x_max=60):
num_channels = recording.get_num_channels()
Expand Down

0 comments on commit 6b3aba1

Please sign in to comment.