diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 729830f3a5..65640f05a5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -222,13 +222,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "prototype.npy", prototype) if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks - detection_params["recording_slices"] = get_shuffled_recording_slices(recording_w, seed=params["seed"], **job_kwargs) + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: waveforms = None if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks - detection_params["recording_slices"] = get_shuffled_recording_slices(recording_w, seed=params["seed"], **job_kwargs) + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) if verbose: diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index fb3dae669d..d26bc00eb8 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -69,7 +69,9 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs): +def get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): """ Function to extract a prototype waveform from peaks. @@ -100,6 +102,7 @@ def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_be The selected peaks used to extract waveforms. """ from spikeinterface.sortingcomponents.peak_selection import select_peaks + _, job_kwargs = split_job_kwargs(all_kwargs) nbefore = int(ms_before * recording.sampling_frequency / 1000.0) @@ -118,7 +121,9 @@ def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_be return prototype, waveforms[:, :, 0], few_peaks -def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs): +def get_prototype_and_waveforms_from_recording( + recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): """ Function to extract a prototype waveform from peaks detected on the fly. @@ -148,6 +153,7 @@ def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_befor """ from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) nbefore = int(ms_before * recording.sampling_frequency / 1000.0) @@ -159,7 +165,7 @@ def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_befor ms_after=ms_after, radius_um=0, ) - + pipeline_nodes = [node] recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) @@ -172,7 +178,7 @@ def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_befor **detection_kwargs, **job_kwargs, ) - + rng = np.random.RandomState(seed) indices = rng.permutation(np.arange(len(res[0]))) @@ -219,9 +225,14 @@ def get_prototype_and_waveforms( The extracted waveforms, returned if return_waveforms is True. """ if peaks is None: - return get_prototype_and_waveforms_from_peaks(recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs) + return get_prototype_and_waveforms_from_peaks( + recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) else: - return get_prototype_and_waveforms_from_recording(recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs) + return get_prototype_and_waveforms_from_recording( + recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) + def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels()