diff --git a/doc/conf.py b/doc/conf.py
index 41659d2e84..d229dc18ee 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -119,7 +119,9 @@
 
 # for sphinx gallery plugin
 sphinx_gallery_conf = {
-    'only_warn_on_example_error': True,
+    # This is the default, but we set it explicitly: the full documentation is built and the build fails on gallery errors only.
+    # The other option, abort_on_example_error, stops at the first failure, so we decided against it.
+    'only_warn_on_example_error': False,
     'examples_dirs': ['../examples/tutorials'],
     'gallery_dirs': ['tutorials' ],  # path where to save gallery generated examples
     'subsection_order': ExplicitOrder([
diff --git a/readthedocs.yml b/readthedocs.yml
index 512fcbc709..c6c44d83a0 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -1,5 +1,9 @@
 version: 2
 
+sphinx:
+  # Path to your Sphinx configuration file.
+  configuration: doc/conf.py
+
 build:
   os: ubuntu-22.04
   tools:
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 55cbe6070a..fdad87287e 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2092,6 +2092,13 @@ def load_data(self):
                 import pandas as pd
 
                 ext_data = pd.read_csv(ext_data_file, index_col=0)
+                # we need to cast the index to the unit id dtype (int or str)
+                unit_ids = self.sorting_analyzer.unit_ids
+                if ext_data.shape[0] == unit_ids.size:
+                    # we force the index dtype to be the same as unit_ids
+                    if ext_data.index.dtype != unit_ids.dtype:
+                        ext_data.index = ext_data.index.astype(unit_ids.dtype)
+
             elif ext_data_file.suffix == ".pkl":
                 with ext_data_file.open("rb") as f:
                     ext_data = pickle.load(f)
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index b64f0610ea..3e3fcc7384 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -205,6 +205,7 @@ def to_sparse(self, sparsity):
             unit_ids=self.unit_ids,
             probe=self.probe,
             check_for_consistent_sparsity=self.check_for_consistent_sparsity,
+            is_scaled=self.is_scaled,
         )
 
     def get_one_template_dense(self, unit_index):
diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 5f85538b08..80f251ca43 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict):
         if not removed_units_set.issubset(unit_set):
             raise ValueError("Curation format: some removed units are not in the unit list")
 
+    for group in curation_dict["merge_unit_groups"]:
+        if len(group) < 2:
+            raise ValueError("Curation format: each group in 'merge_unit_groups' must contain at least 2 units")
+
     all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]]
     for gp_1, gp_2 in combinations(all_merging_groups, 2):
         if len(gp_1.intersection(gp_2)) != 0:
-            raise ValueError("Some units belong to multiple merge groups")
+            raise ValueError("Curation format: some units belong to multiple merge groups")
     if len(removed_units_set.intersection(merged_units_set)) != 0:
-        raise ValueError("Some units were merged and deleted")
+        raise ValueError("Curation format: some units were merged and deleted")
 
     # Check the labels exclusivity
     for lbl in curation_dict["manual_labels"]:
@@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
             all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
             for unit_ind, unit_id in enumerate(sorting.unit_ids):
                 if unit_id not in new_unit_ids:
-                    ind = curation_dict["unit_ids"].index(unit_id)
+                    ind = list(curation_dict["unit_ids"]).index(unit_id)
                     all_values[unit_ind] = values[ind]
             sorting.set_property(key, all_values)
 
@@ -253,7 +257,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
                         group_values.append(value)
                 if len(set(group_values)) == 1:
                     # all group has the same label or empty
-                    sorting.set_property(key, values=group_values, ids=[new_unit_id])
+                    sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
 
         else:
            for key in label_def["label_options"]:
@@ -339,18 +343,22 @@ def apply_curation(
 
     elif isinstance(sorting_or_analyzer, SortingAnalyzer):
         analyzer = sorting_or_analyzer
-        analyzer = analyzer.remove_units(curation_dict["removed_units"])
-        analyzer, new_unit_ids = analyzer.merge_units(
-            curation_dict["merge_unit_groups"],
-            censor_ms=censor_ms,
-            merging_mode=merging_mode,
-            sparsity_overlap=sparsity_overlap,
-            new_id_strategy=new_id_strategy,
-            return_new_unit_ids=True,
-            format="memory",
-            verbose=verbose,
-            **job_kwargs,
-        )
+        if len(curation_dict["removed_units"]) > 0:
+            analyzer = analyzer.remove_units(curation_dict["removed_units"])
+        if len(curation_dict["merge_unit_groups"]) > 0:
+            analyzer, new_unit_ids = analyzer.merge_units(
+                curation_dict["merge_unit_groups"],
+                censor_ms=censor_ms,
+                merging_mode=merging_mode,
+                sparsity_overlap=sparsity_overlap,
+                new_id_strategy=new_id_strategy,
+                return_new_unit_ids=True,
+                format="memory",
+                verbose=verbose,
+                **job_kwargs,
+            )
+        else:
+            new_unit_ids = []
         apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict)
         return analyzer
     else:
diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py
index 2a9fb34267..ec15506006 100644
--- a/src/spikeinterface/sorters/external/kilosort4.py
+++ b/src/spikeinterface/sorters/external/kilosort4.py
@@ -66,7 +66,7 @@ class Kilosort4Sorter(BaseSorter):
         "do_correction": True,
         "keep_good_only": False,
         "skip_kilosort_preprocessing": False,
-        "use_binary_file": None,
+        "use_binary_file": True,
         "delete_recording_dat": True,
     }
 
@@ -116,7 +116,7 @@ class Kilosort4Sorter(BaseSorter):
         "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)",
         "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. "
        "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. "
-        "Default is None. (spikeinterface parameter)",
+        "Default is True. (spikeinterface parameter)",
         "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. (spikeinterface parameter)",
     }
 
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 6301b664ba..eeaef4d4bc 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,12 +6,16 @@
 import numpy as np
 
 from spikeinterface.core import NumpySorting
-from spikeinterface.core.job_tools import fix_job_kwargs
+from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from spikeinterface.core.recording_tools import get_noise_levels
 from spikeinterface.core.template import Templates
 from spikeinterface.core.waveform_tools import estimate_templates
 from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
-from spikeinterface.sortingcomponents.tools import cache_preprocessing
+from spikeinterface.sortingcomponents.tools import (
+    cache_preprocessing,
+    get_prototype_and_waveforms_from_recording,
+    get_shuffled_recording_slices,
+)
 from spikeinterface.core.basesorting import minimum_spike_dtype
 from spikeinterface.core.sparsity import compute_sparsity
 
@@ -22,7 +26,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     _default_params = {
         "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
-        "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
+        "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10},
         "whitening": {"mode": "local", "regularize": False},
         "detection": {"peak_sign": "neg", "detect_threshold": 4},
         "selection": {
@@ -42,6 +46,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "multi_units_only": False,
         "job_kwargs": {"n_jobs": 0.5},
+        "seed": 42,
         "debug": False,
     }
 
@@ -63,18 +68,21 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "merging": "A dictionary to specify the final merging param to group cells after template matching (auto_merge_units)",
         "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
         "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
-            median reference + zscore",
+            median reference + whitening",
+        "apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not",
+        "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
         "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
            memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
         "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
         "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
+        "seed": "An int to control how chunks are shuffled while detecting peaks",
         "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging",
     }
 
     sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework
    It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise.
    In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes
-    being discovered."""
+    being discovered. The code is much faster and more memory efficient, and it inherits all the preprocessing capabilities of spikeinterface."""
 
     @classmethod
     def get_sorter_version(cls):
@@ -103,7 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
         from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
         from spikeinterface.sortingcomponents.tools import remove_empty_templates
-        from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
+        from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction
 
         job_kwargs = fix_job_kwargs(params["job_kwargs"])
         job_kwargs.update({"progress_bar": verbose})
@@ -120,10 +128,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if params["apply_preprocessing"]:
+            if verbose:
+                print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
             recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
             if num_channels > 1:
                 recording_f = common_reference(recording_f)
         else:
+            if verbose:
+                print("Skipping preprocessing (whitening only)")
             recording_f = recording
             recording_f.annotate(is_filtered=True)
@@ -146,12 +158,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         # TODO add , regularize=True chen ready
         whitening_kwargs = params["whitening"].copy()
         whitening_kwargs["dtype"] = "float32"
-        whitening_kwargs["radius_um"] = radius_um
+        whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
         if num_channels == 1:
             whitening_kwargs["regularize"] = False
+        if whitening_kwargs["regularize"]:
+            whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}
 
         recording_w = whiten(recording_f, **whitening_kwargs)
-        noise_levels = get_noise_levels(recording_w, return_scaled=False)
+        noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs)
 
         if recording_w.check_serializability("json"):
             recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
@@ -162,9 +176,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_params = params["detection"].copy()
-        detection_params.update(job_kwargs)
-
-        detection_params["radius_um"] = detection_params.get("radius_um", 50)
+        selection_params = params["selection"].copy()
+        detection_params["radius_um"] = radius_um
         detection_params["exclude_sweep_ms"] = exclude_sweep_ms
         detection_params["noise_levels"] = noise_levels
 
@@ -172,17 +185,47 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         nbefore = int(ms_before * fs / 1000.0)
         nafter = int(ms_after * fs / 1000.0)
 
+        skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform"
+        max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels
+        n_peaks = max(selection_params["min_n_peaks"], max_n_peaks)
+
+        if params["debug"]:
+            clustering_folder = sorter_output_folder / "clustering"
+            clustering_folder.mkdir(parents=True, exist_ok=True)
+            np.save(clustering_folder / "noise_levels.npy", noise_levels)
+
         if params["matched_filtering"]:
-            peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000)
-            prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
+            prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
+                recording_w,
+                n_peaks=10000,
+                ms_before=ms_before,
+                ms_after=ms_after,
+                seed=params["seed"],
+                **detection_params,
+                **job_kwargs,
+            )
             detection_params["prototype"] = prototype
             detection_params["ms_before"] = ms_before
-            peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)
+            if params["debug"]:
+                np.save(clustering_folder / "waveforms.npy", waveforms)
+                np.save(clustering_folder / "prototype.npy", prototype)
+            if skip_peaks:
+                detection_params["skip_after_n_peaks"] = n_peaks
+                detection_params["recording_slices"] = get_shuffled_recording_slices(
+                    recording_w, seed=params["seed"], **job_kwargs
+                )
+            peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs)
         else:
-            peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)
+            waveforms = None
+            if skip_peaks:
+                detection_params["skip_after_n_peaks"] = n_peaks
+                detection_params["recording_slices"] = get_shuffled_recording_slices(
+                    recording_w, seed=params["seed"], **job_kwargs
+                )
+            peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)
 
-        if verbose:
-            print("We found %d peaks in total" % len(peaks))
+        if not skip_peaks and verbose:
+            print("Found %d peaks in total" % len(peaks))
 
         if params["multi_units_only"]:
             sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids)
@@ -190,14 +233,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
            ## We subselect a subset of all the peaks, by making the distributions os SNRs over all
            ## channels as flat as possible
             selection_params = params["selection"]
-            selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
-            selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])
-
+            selection_params["n_peaks"] = n_peaks
             selection_params.update({"noise_levels": noise_levels})
             selected_peaks = select_peaks(peaks, **selection_params)
 
             if verbose:
-                print("We kept %d peaks for clustering" % len(selected_peaks))
+                print("Kept %d peaks for clustering" % len(selected_peaks))
 
            ## We launch a clustering (using hdbscan) relying on positions and features extracted on
            ## the fly from the snippets
@@ -207,10 +248,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             clustering_params["radius_um"] = radius_um
             clustering_params["waveforms"]["ms_before"] = ms_before
             clustering_params["waveforms"]["ms_after"] = ms_after
+            clustering_params["few_waveforms"] = waveforms
             clustering_params["noise_levels"] = noise_levels
-            clustering_params["ms_before"] = exclude_sweep_ms
-            clustering_params["ms_after"] = exclude_sweep_ms
+            clustering_params["ms_before"] = ms_before
+            clustering_params["ms_after"] = ms_after
+
             clustering_params["verbose"] = verbose
             clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
+            clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4)
 
             legacy = clustering_params.get("legacy", True)
@@ -235,12 +279,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"])))
             sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids)
 
-            clustering_folder = sorter_output_folder / "clustering"
-            clustering_folder.mkdir(parents=True, exist_ok=True)
-
-            if not params["debug"]:
-                shutil.rmtree(clustering_folder)
-            else:
+            if params["debug"]:
+                np.save(clustering_folder / "peak_labels", peak_labels)
                 np.save(clustering_folder / "labels", labels)
                 np.save(clustering_folder / "peaks", selected_peaks)
 
@@ -283,7 +323,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 np.save(fitting_folder / "spikes", spikes)
 
             if verbose:
-                print("We found %d spikes" % len(spikes))
+                print("Found %d spikes" % len(spikes))
 
             ## And this is it! We have a spyking circus
             sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype)
@@ -320,7 +360,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
 
             if verbose:
-                print(f"Final merging, keeping {len(sorting.unit_ids)} units")
+                print(f"Kept {len(sorting.unit_ids)} units after final merging")
 
         folder_to_delete = None
         cache_mode = params["cache_preprocessing"].get("mode", "memory")
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 243c854bba..bc173a6ff0 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -40,13 +40,7 @@ class CircusClustering:
     """
 
     _default_params = {
-        "hdbscan_kwargs": {
-            "min_cluster_size": 25,
-            "allow_single_cluster": True,
-            "core_dist_n_jobs": -1,
-            "cluster_selection_method": "eom",
-            # "cluster_selection_epsilon" : 5 ## To be optimized
-        },
+        "hdbscan_kwargs": {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5},
         "cleaning_kwargs": {},
         "waveforms": {"ms_before": 2, "ms_after": 2},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
@@ -57,8 +51,10 @@ class CircusClustering:
         },
         "radius_um": 100,
         "n_svd": [5, 2],
+        "few_waveforms": None,
         "ms_before": 0.5,
         "ms_after": 0.5,
+        "noise_threshold": 4,
         "rank": 5,
         "noise_levels": None,
         "tmp_folder": None,
@@ -86,12 +82,25 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             tmp_folder.mkdir(parents=True, exist_ok=True)
 
         # SVD for time compression
-        few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter))
-        few_wfs = extract_waveform_at_max_channel(
-            recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
-        )
+        if params["few_waveforms"] is None:
+            few_peaks = select_peaks(
+                peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)
+            )
+            few_wfs = extract_waveform_at_max_channel(
+                recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
+            )
+            wfs = few_wfs[:, :, 0]
+        else:
+            offset = int(params["waveforms"]["ms_before"] * fs / 1000)
+            wfs = params["few_waveforms"][:, offset - nbefore : offset + nafter]
+
+        # Ensure every waveform has a positive peak at the alignment sample
+        wfs *= np.sign(wfs[:, nbefore])[:, np.newaxis]
+
+        # Remove outliers, i.e. waveforms whose absolute maximum is not at the alignment sample
+        valid = np.argmax(np.abs(wfs), axis=1) == nbefore
+        wfs = wfs[valid]
 
-        wfs = few_wfs[:, :, 0]
         from sklearn.decomposition import TruncatedSVD
 
         tsvd = TruncatedSVD(params["n_svd"][0])
@@ -189,7 +198,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         original_labels = peaks["channel_index"]
         from spikeinterface.sortingcomponents.clustering.split import split_clusters
 
-        min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)
+        min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 10)
 
         peak_labels, _ = split_clusters(
             original_labels,
@@ -225,38 +234,54 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
         nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)
 
+        if params["noise_levels"] is None:
+            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
+
         templates_array = estimate_templates(
-            recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
+            recording,
+            spikes,
+            unit_ids,
+            nbefore,
+            nafter,
+            return_scaled=False,
+            job_name=None,
+            **job_kwargs,
         )
 
+        best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
+        peak_snrs = np.abs(templates_array[:, nbefore, :])
+        best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
+        valid_templates = best_snrs_ratio > params["noise_threshold"]
+
         if d["rank"] is not None:
             from spikeinterface.sortingcomponents.matching.circus import compress_templates
 
             _, _, _, templates_array = compress_templates(templates_array, d["rank"])
 
         templates = Templates(
-            templates_array=templates_array,
+            templates_array=templates_array[valid_templates],
             sampling_frequency=fs,
             nbefore=nbefore,
             sparsity_mask=None,
             channel_ids=recording.channel_ids,
-            unit_ids=unit_ids,
+            unit_ids=unit_ids[valid_templates],
             probe=recording.get_probe(),
             is_scaled=False,
         )
-        if params["noise_levels"] is None:
-            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
-
         sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
         templates = templates.to_sparse(sparsity)
         empty_templates = templates.sparsity_mask.sum(axis=1) == 0
         templates = remove_empty_templates(templates)
+
         mask = np.isin(peak_labels, np.where(empty_templates)[0])
         peak_labels[mask] = -1
 
+        mask = np.isin(peak_labels, np.where(~valid_templates)[0])
+        peak_labels[mask] = -1
+
         if verbose:
-            print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))
+            print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
 
         cleaning_job_kwargs = job_kwargs.copy()
         cleaning_job_kwargs["progress_bar"] = False
@@ -267,6 +292,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         )
 
         if verbose:
-            print("We kept %d non-duplicated clusters..." % len(labels))
+            print("Kept %d non-duplicated clusters" % len(labels))
 
         return labels, peak_labels
diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 08a1384333..93db9a268f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -570,7 +570,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,
         )
     else:
         recording = NumpyRecording(zdata, sampling_frequency=fs)
-        recording = SharedMemoryRecording.from_recording(recording)
+        recording = SharedMemoryRecording.from_recording(recording, **job_kwargs)
 
     recording = recording.set_probe(templates.probe)
     recording.annotate(is_filtered=True)
@@ -587,6 +587,8 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,
 
     keep_searching = True
 
+    local_job_kwargs = {"n_jobs": 1, "progress_bar": False}
+
     DEBUG = False
 
     while keep_searching:
@@ -604,7 +606,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,
             local_params.update({"ignore_inds": ignore_inds + [i]})
 
             spikes, more_outputs = find_spikes_from_templates(
-                sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs
+                sub_recording,
+                method="circus-omp-svd",
+                method_kwargs=local_params,
+                extra_outputs=True,
+                **local_job_kwargs,
             )
             local_params["precomputed"] = more_outputs
             valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin)
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index 484a7376c1..1d4d8881ad 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -53,6 +53,7 @@ class RandomProjectionClustering:
         "random_seed": 42,
         "noise_levels": None,
         "smoothing_kwargs": {"window_length_ms": 0.25},
+        "noise_threshold": 4,
         "tmp_folder": None,
         "verbose": True,
     }
@@ -129,28 +130,49 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
         nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)
 
+        if params["noise_levels"] is None:
+            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
+
         templates_array = estimate_templates(
-            recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
+            recording,
+            spikes,
+            unit_ids,
+            nbefore,
+            nafter,
+            return_scaled=False,
+            job_name=None,
+            **job_kwargs,
         )
 
+        best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
+        peak_snrs = np.abs(templates_array[:, nbefore, :])
+        best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
+        valid_templates = best_snrs_ratio > params["noise_threshold"]
+
         templates = Templates(
-            templates_array=templates_array,
+            templates_array=templates_array[valid_templates],
             sampling_frequency=fs,
             nbefore=nbefore,
             sparsity_mask=None,
             channel_ids=recording.channel_ids,
-            unit_ids=unit_ids,
+            unit_ids=unit_ids[valid_templates],
             probe=recording.get_probe(),
             is_scaled=False,
         )
 
-        if params["noise_levels"] is None:
-            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
**params["sparsity"]) + + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) + empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) + peak_labels[mask] = -1 + + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + if verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) cleaning_job_kwargs = job_kwargs.copy() cleaning_job_kwargs["progress_bar"] = False @@ -161,6 +183,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 2240357d27..12955e2c40 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -118,7 +118,11 @@ def detect_peaks( squeeze_output = True else: squeeze_output = False - job_name += f" + {len(pipeline_nodes)} nodes" + if len(pipeline_nodes) == 1: + plural = "" + else: + plural = "s" + job_name += f" + {len(pipeline_nodes)} node{plural}" # because node are modified inplace (insert parent) they need to copy incase # the same pipeline is run several times @@ -677,7 +681,6 @@ def __init__( medians = medians[:, None] noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817 self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("z", "float32")]) def get_dtype(self): @@ -727,8 +730,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (np.zeros(0, dtype=self._dtype),) peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) local_peaks["sample_index"] = peak_sample_ind local_peaks["channel_index"] = peak_chan_ind diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 08bcabf5e5..1e4e0edded 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -33,7 +33,7 @@ get_grid_convolution_templates_and_weights, ) -from .tools import get_prototype_spike +from .tools import get_prototype_and_waveforms_from_peaks def get_localization_pipeline_nodes( @@ -73,8 +73,8 @@ def get_localization_pipeline_nodes( assert isinstance(peak_source, (PeakRetriever, SpikeRetriever)) # extract prototypes silently job_kwargs["progress_bar"] = False - method_kwargs["prototype"] = get_prototype_spike( - recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + method_kwargs["prototype"], _, _ = get_prototype_and_waveforms_from_peaks( + recording, peaks=peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) extract_dense_waveforms = ExtractDenseWaveforms( recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False diff --git 
a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 7c34f5948d..341ed3426d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -22,7 +22,7 @@ ) from spikeinterface.core.node_pipeline import run_node_pipeline -from spikeinterface.sortingcomponents.tools import get_prototype_spike +from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_peaks from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ms_before = 1.0 ms_after = 1.0 - prototype = get_prototype_spike(recording, peaks_by_channel_np, ms_before, ms_after, **job_kwargs) + prototype, _, _ = get_prototype_and_waveforms_from_peaks( + recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) peaks_local_mf_filtering = detect_peaks( recording, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 439aee6db8..8171a330b5 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -70,25 +70,174 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks=1000, **job_kwargs): +def get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform from peaks. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. 
+ """ from spikeinterface.sortingcomponents.peak_selection import select_peaks + _, job_kwargs = split_job_kwargs(all_kwargs) + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=nb_peaks, margin=(nbefore, nafter)) - + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed + ) waveforms = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) + with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms_from_recording( + recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform from peaks detected on the fly. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + node = ExtractSparseWaveforms( + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) + + pipeline_nodes = [node] + + recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) + + res = detect_peaks( + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, + ) + + rng = np.random.RandomState(seed) + indices = rng.permutation(np.arange(len(res[0]))) + + few_peaks = res[0][indices[:n_peaks]] + waveforms = res[1][indices[:n_peaks]] + + with np.errstate(divide="ignore", invalid="ignore"): + prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform either from peaks or from a peak detection. Note that in case + of a peak detection, the detection stops as soon as n_peaks are detected. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. 
+ peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ + if peaks is None: + return get_prototype_and_waveforms_from_recording( + recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) + else: + return get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels() - if num_channels < 32: + if num_channels <= 32: return False else: locations = recording.get_channel_locations() @@ -175,3 +324,19 @@ def create_sorting_analyzer_with_existing_templates(sorting, recording, template sa.extensions["templates"].data["average"] = templates_array sa.extensions["templates"].data["std"] = np.zeros(templates_array.shape, dtype=np.float32) return sa + +def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): + from spikeinterface.core.job_tools import ensure_chunk_size + from spikeinterface.core.job_tools import divide_segment_into_chunks + + chunk_size = ensure_chunk_size(recording, **job_kwargs) + recording_slices = [] + for segment_index in range(recording.get_num_segments()): + num_frames = recording.get_num_samples(segment_index) + chunks = divide_segment_into_chunks(num_frames, chunk_size) + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + + rng = np.random.default_rng(seed) + recording_slices = rng.permutation(recording_slices) + + return recording_slices
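As a quick sanity check of the helpers introduced above, the sketch below builds a short synthetic recording, asks for the shuffled chunk list and extracts a prototype waveform from peaks detected on the fly. It is only an illustration under assumptions: the recording, chunk duration and peak counts are arbitrary, and the on-the-fly detection relies on detect_peaks' default method, which may require numba:

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sortingcomponents.tools import (
    get_prototype_and_waveforms_from_recording,
    get_shuffled_recording_slices,
)

recording, _ = generate_ground_truth_recording(durations=[10.0], num_channels=8, seed=0)

# Shuffled (segment_index, frame_start, frame_stop) triplets, reproducible through the seed.
recording_slices = get_shuffled_recording_slices(recording, seed=42, chunk_duration="1s")
print(len(recording_slices), recording_slices[0])

# Prototype and waveforms from on-the-fly detection; detection stops after roughly n_peaks peaks.
prototype, waveforms, peaks = get_prototype_and_waveforms_from_recording(
    recording,
    n_peaks=500,
    ms_before=0.5,
    ms_after=0.5,
    seed=42,
    chunk_duration="1s",
)
print(prototype.shape, waveforms.shape, len(peaks))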