diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index e07924b196..2f965b0483 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -7,7 +7,6 @@ from spikeinterface.core import (
     get_noise_levels,
     NumpySorting,
-    get_channel_distances,
     estimate_templates_with_accumulator,
     Templates,
     compute_sparsity,
@@ -18,15 +17,11 @@
 from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
 from spikeinterface.core.basesorting import minimum_spike_dtype

-from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing
+from spikeinterface.sortingcomponents.tools import cache_preprocessing

-# from spikeinterface.qualitymetrics import compute_snrs

 import numpy as np

-import pickle
-import json
-

 class Tridesclous2Sorter(ComponentsBasedSorter):
     sorter_name = "tridesclous2"
@@ -34,13 +29,14 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
     _default_params = {
         "apply_preprocessing": True,
         "apply_motion_correction": False,
+        "motion_correction": {"preset": "nonrigid_fast_and_accurate"},
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "waveforms": {
             "ms_before": 0.5,
             "ms_after": 1.5,
             "radius_um": 120.0,
         },
-        "filtering": {"freq_min": 300.0, "freq_max": 12000.0},
+        "filtering": {"freq_min": 300.0, "freq_max": 8000.0},
         "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0},
         "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
         "svd": {"n_components": 6},
@@ -53,7 +49,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
             "ms_before": 2.0,
             "ms_after": 3.0,
             "max_spikes_per_unit": 400,
-            "sparsity_threshold": 2.0,
+            "sparsity_threshold": 1.5,
             # "peak_shift_ms": 0.2,
         },
         # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
@@ -86,31 +82,18 @@ def get_sorter_version(cls):

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        job_kwargs = params["job_kwargs"].copy()
-        job_kwargs = fix_job_kwargs(job_kwargs)
-        job_kwargs["progress_bar"] = verbose

         from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
-        from spikeinterface.core.node_pipeline import (
-            run_node_pipeline,
-            ExtractDenseWaveforms,
-            ExtractSparseWaveforms,
-            PeakRetriever,
-        )
-        from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive
+        from spikeinterface.sortingcomponents.peak_detection import detect_peaks
         from spikeinterface.sortingcomponents.peak_selection import select_peaks
-        from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution
-        from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
-
-        from spikeinterface.sortingcomponents.clustering.split import split_clusters
-        from spikeinterface.sortingcomponents.clustering.merge import merge_clusters
-        from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse
         from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks
         from spikeinterface.sortingcomponents.tools import remove_empty_templates
+        from spikeinterface.preprocessing import correct_motion
+        from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording

-        from sklearn.decomposition import TruncatedSVD
-
-        import hdbscan
+        job_kwargs = params["job_kwargs"].copy()
+        job_kwargs = fix_job_kwargs(job_kwargs)
+        job_kwargs["progress_bar"] = verbose

         recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

@@ -119,10 +102,44 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         # preprocessing
         if params["apply_preprocessing"]:
-            recording = bandpass_filter(recording_raw, **params["filtering"])
+            if params["apply_motion_correction"]:
+                rec_for_motion = recording_raw
+                if params["apply_preprocessing"]:
+                    rec_for_motion = bandpass_filter(rec_for_motion, freq_min=300.0, freq_max=6000.0, dtype="float32")
+                    rec_for_motion = common_reference(rec_for_motion)
+                if verbose:
+                    print("Start correct_motion()")
+                _, motion_info = correct_motion(
+                    rec_for_motion,
+                    folder=sorter_output_folder / "motion",
+                    output_motion_info=True,
+                    **params["motion_correction"],
+                )
+                if verbose:
+                    print("Done correct_motion()")
+
+            recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32")
             recording = common_reference(recording)
+
+            if params["apply_motion_correction"]:
+                interpolate_motion_kwargs = dict(
+                    direction=1,
+                    border_mode="force_extrapolate",
+                    spatial_interpolation_method="kriging",
+                    sigma_um=20.0,
+                    p=2,
+                )
+
+                recording = InterpolateMotionRecording(
+                    recording,
+                    motion_info["motion"],
+                    motion_info["temporal_bins"],
+                    motion_info["spatial_bins"],
+                    **interpolate_motion_kwargs,
+                )
+
             recording = zscore(recording, dtype="float32")
-            recording = whiten(recording, dtype="float32")
+            recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0)

             # used only if "folder" or "zarr"
             cache_folder = sorter_output_folder / "cache_preprocessing"
@@ -141,7 +158,7 @@
         all_peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs)

         if verbose:
-            print("We found %d peaks in total" % len(all_peaks))
+            print(f"detect_peaks(): {len(all_peaks)} peaks found")

         # selection
         selection_params = params["selection"].copy()
@@ -150,36 +167,38 @@
         peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks)

         if verbose:
-            print("We kept %d peaks for clustering" % len(peaks))
+            print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

         clustering_kwargs = {}
         clustering_kwargs["folder"] = sorter_output_folder
         clustering_kwargs["waveforms"] = params["waveforms"].copy()
         clustering_kwargs["clustering"] = params["clustering"].copy()

-        labels_set, post_clean_label, extra_out = find_cluster_from_peaks(
+        labels_set, clustering_label, extra_out = find_cluster_from_peaks(
             recording, peaks, method="tdc_clustering", method_kwargs=clustering_kwargs, extra_outputs=True, **job_kwargs
         )
         peak_shifts = extra_out["peak_shifts"]
         new_peaks = peaks.copy()
         new_peaks["sample_index"] -= peak_shifts

-        mask = post_clean_label >= 0
+        mask = clustering_label >= 0
         sorting_pre_peeler = NumpySorting.from_times_labels(
             new_peaks["sample_index"][mask],
-            post_clean_label[mask],
+            clustering_label[mask],
             sampling_frequency,
             unit_ids=labels_set,
         )
-        # sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler")

-        recording_w = whiten(recording, mode="local", radius_um=100.0)
+        if verbose:
+            print(f"find_cluster_from_peaks(): {sorting_pre_peeler.unit_ids.size} clusters found")
+
+        recording_for_peeler = recording

         nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0)
         nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0)
-        sparsity_threshold = params["templates"]["sparsity_threshold"]
+
         templates_array = estimate_templates_with_accumulator(
-            recording_w,
+            recording_for_peeler,
             sorting_pre_peeler.to_spike_vector(),
             sorting_pre_peeler.unit_ids,
             nbefore,
@@ -192,51 +211,24 @@
             sampling_frequency=sampling_frequency,
             nbefore=nbefore,
             sparsity_mask=None,
-            probe=recording_w.get_probe(),
+            probe=recording_for_peeler.get_probe(),
             is_scaled=False,
         )

+        # TODO : try other methods for sparsity
+        sparsity_threshold = params["templates"]["sparsity_threshold"]
         # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
         sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=sparsity_threshold)
         templates = templates_dense.to_sparse(sparsity)
         templates = remove_empty_templates(templates)

-        # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
-        # print(snrs)
-
-        # matching_params = params["matching"].copy()
-        # matching_params["noise_levels"] = noise_levels
-        # matching_params["peak_sign"] = params["detection"]["peak_sign"]
-        # matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
-        # matching_params["radius_um"] = params["detection"]["radius_um"]
-
-        # spikes = find_spikes_from_templates(
-        #     recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
-        # )
-
+        ## peeler
         matching_method = params["matching"]["method"]
         matching_params = params["matching"]["method_kwargs"].copy()
-
         matching_params["templates"] = templates
         matching_params["noise_levels"] = noise_levels
-        # matching_params["peak_sign"] = params["detection"]["peak_sign"]
-        # matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
-        # matching_params["radius_um"] = params["detection"]["radius_um"]
-
-        # spikes = find_spikes_from_templates(
-        #     recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
-        # )
-        # )
-
-        # if matching_method == "circus-omp-svd":
-        #     job_kwargs = job_kwargs.copy()
-        #     for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
-        #         if value in job_kwargs:
-        #             job_kwargs.pop(value)
-        #     job_kwargs["chunk_duration"] = "100ms"
-
         spikes = find_spikes_from_templates(
-            recording_w, method=matching_method, method_kwargs=matching_params, **job_kwargs
+            recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
         )

         if params["save_array"]:
@@ -244,9 +236,8 @@
             np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
             np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
-            # np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
-            # np.save(sorter_output_folder / "split_count.npy", split_count)
-            # np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
+            np.save(sorter_output_folder / "peaks.npy", peaks)
+            np.save(sorter_output_folder / "clustering_label.npy", clustering_label)
             np.save(sorter_output_folder / "spikes.npy", spikes)

         final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
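For orientation (not part of the patch): with these defaults, the new drift-correction path is driven by the two parameters added above. A minimal sketch, assuming `rec` is an already-loaded recording extractor:

from spikeinterface.sorters import run_sorter

# sketch only, not part of the patch; `rec` is an assumed, already-loaded recording
sorting = run_sorter(
    "tridesclous2",
    rec,
    output_folder="tdc2_output",
    apply_preprocessing=True,
    apply_motion_correction=True,  # enables the correct_motion() + InterpolateMotionRecording path above
    motion_correction={"preset": "nonrigid_fast_and_accurate"},
)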
diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py
index ebddd2bd58..7a0f9ba253 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py
@@ -188,6 +188,8 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)):
             ax.set_title(self.cases[key]["label"])
             plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)

+        return fig
+
     def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
         if case_keys is None:
             case_keys = list(self.cases.keys())
@@ -210,6 +212,8 @@
                 if count == 2:
                     ax.legend()

+        return fig
+
     def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):

         if case_keys is None:
@@ -244,6 +248,8 @@
             label = self.cases[key]["label"]
             axs[0, count].set_title(label)

+        return fig
+
     def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):

         if case_keys is None:
@@ -296,6 +302,8 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):
             axs[0, count].set_title(label)
             axs[0, count].legend()

+        return fig
+
     def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):

         if case_keys is None:
@@ -354,6 +362,8 @@
             axs[0, count].set_title(label)
             # axs[0, count].legend()

+        return fig
+
     def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):

         fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
@@ -384,6 +394,7 @@
             fig.colorbar(im, ax=ax)
             ax.set_title(k)
             ax.set_ylabel("snr")
+        return fig

     def plot_comparison_clustering(
         self,
@@ -444,10 +455,13 @@

         plt.tight_layout(h_pad=0, w_pad=0)

+        return fig
+
     def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
         if case_keys is None:
             case_keys = list(self.cases.keys())

+        figs = []
         for count, key in enumerate(case_keys):
             label = self.cases[key]["label"]
             comp = self.get_result(key)["gt_comparison"]
@@ -475,13 +489,17 @@
                     ax.set_xticks([])

                 fig.suptitle(label)
+                figs.append(fig)
             else:
                 print(key, "no overmerged")

+        return figs
+
     def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
         if case_keys is None:
             case_keys = list(self.cases.keys())

+        figs = []
         for count, key in enumerate(case_keys):
             label = self.cases[key]["label"]
             comp = self.get_result(key)["gt_comparison"]
@@ -509,5 +527,8 @@
                     ax.set_xticks([])

                 fig.suptitle(label)
+                figs.append(fig)
             else:
                 print(key, "no over splited")
+
+        return figs
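Since the plotting methods above now return their figures (or a list of figures for the over-merged/over-split inspectors), callers can save or post-process them without touching global pyplot state. A minimal sketch, assuming `study` is an existing clustering benchmark study:

# sketch only: `study` is an assumed, already-created benchmark study instance
fig = study.plot_agreements()
fig.savefig("agreements.png")

figs = study.plot_some_over_merged(overmerged_score=0.05)
for i, fig in enumerate(figs):
    fig.savefig(f"overmerged_{i}.png")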
self.result["gt_comparison"] = comp - self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + if with_collision: + self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) _run_key_saved = [ ("sorting", "sorting"), @@ -73,6 +74,8 @@ def plot_agreements(self, case_keys=None, figsize=None): ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + return fig + def plot_performances_vs_snr(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -95,6 +98,8 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=None): if count == 2: ax.legend() + return fig + def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -112,6 +117,8 @@ def plot_collisions(self, case_keys=None, figsize=None): good_only=False, ) + return fig + def plot_comparison_matching( self, case_keys=None, @@ -170,6 +177,8 @@ def plot_comparison_matching( ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) + return fig + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): import pandas as pd @@ -240,3 +249,4 @@ def plot_unit_losses(self, before, after, figsize=None): # if count == 2: # ax.legend() + return fig diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 3c5623f202..5d3c9c207a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -289,13 +289,15 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): if lim is not None: ax.set_ylim(0, lim) - def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, figsize=(15, 5)): + def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) fig, axes = plt.subplots(1, 3, figsize=figsize) + colors = self.get_colors() + for count, key in enumerate(case_keys): bench = self.benchmarks[key] @@ -306,7 +308,9 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fig temporal_bins = bench.result["temporal_bins"] spatial_bins = bench.result["spatial_bins"] - c = colors[count] if colors is not None else None + # c = colors[count] if colors is not None else None + c = colors[key] + errors = gt_motion - motion mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 5e2ded5ecc..6afac8d13c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -6,6 +6,8 @@ import numpy as np import pandas as pd +import matplotlib.pyplot as plt + import time import os @@ -13,6 +15,7 @@ from spikeinterface.core import SortingAnalyzer from spikeinterface.core.core_tools import check_json from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer +from spikeinterface.widgets import get_some_colors import pickle @@ -43,6 +46,7 @@ def __init__(self, study_folder): self.cases = {} self.benchmarks 
diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py
index 5e2ded5ecc..6afac8d13c 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py
@@ -6,6 +6,8 @@
 import numpy as np
 import pandas as pd

+import matplotlib.pyplot as plt
+
 import time
 import os

@@ -13,6 +15,7 @@
 from spikeinterface.core import SortingAnalyzer
 from spikeinterface.core.core_tools import check_json
 from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer
+from spikeinterface.widgets import get_some_colors

 import pickle

@@ -43,6 +46,7 @@ def __init__(self, study_folder):
         self.cases = {}
         self.benchmarks = {}
         self.scan_folder()
+        self.colors = None

     @classmethod
     def create(cls, study_folder, datasets={}, cases={}, levels=None):
@@ -225,6 +229,20 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
                 benchmark.result["run_time"] = float(t1 - t0)
                 benchmark.save_main(bench_folder)

+    def set_colors(self, colors=None, map_name="tab20"):
+        if colors is None:
+            case_keys = list(self.cases.keys())
+            self.colors = get_some_colors(
+                case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0
+            )
+        else:
+            self.colors = colors
+
+    def get_colors(self):
+        if self.colors is None:
+            self.set_colors()
+        return self.colors
+
     def get_run_times(self, case_keys=None):
         if case_keys is None:
             case_keys = list(self.cases.keys())
@@ -245,7 +263,19 @@ def plot_run_times(self, case_keys=None):
             case_keys = list(self.cases.keys())

         run_times = self.get_run_times(case_keys=case_keys)
-        run_times.plot(kind="bar")
+        colors = self.get_colors()
+        fig, ax = plt.subplots()
+        labels = []
+        for i, key in enumerate(case_keys):
+            labels.append(self.cases[key]["label"])
+            rt = run_times.at[key, "run_times"]
+            ax.bar(i, rt, width=0.8, color=colors[key])
+        ax.set_xticks(np.arange(len(case_keys)))
+        ax.set_xticklabels(labels, rotation=45.0)
+        return fig
+
+        # ax = run_times.plot(kind="bar")
+        # return ax.figure

     def compute_results(self, case_keys=None, verbose=False, **result_params):
         if case_keys is None:
@@ -368,6 +398,8 @@ def __init__(self):

     def _save_keys(self, saved_keys, folder):
         for k, format in saved_keys:
+            if k not in self.result or self.result[k] is None:
+                continue
             if format == "npy":
                 np.save(folder / f"{k}.npy", self.result[k])
             elif format == "pickle":
diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py
index 90938f208d..a6c39c05e5 100644
--- a/src/spikeinterface/sortingcomponents/clustering/tdc.py
+++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py
@@ -158,13 +158,18 @@ def main_function(cls, recording, peaks, params):
             method="local_feature_clustering",
             method_kwargs=dict(
                 clusterer="hdbscan",
+                clusterer_kwargs={
+                    "min_cluster_size": min_cluster_size,
+                    "allow_single_cluster": True,
+                    "cluster_selection_method": "eom",
+                },
                 # clusterer="isocut5",
+                # clusterer_kwargs={"min_cluster_size": min_cluster_size},
                 feature_name="sparse_tsvd",
                 # feature_name="sparse_wfs",
                 neighbours_mask=neighbours_mask,
                 waveforms_sparse_mask=sparse_mask,
                 min_size_split=min_cluster_size,
-                clusterer_kwargs={"min_cluster_size": min_cluster_size},
                 n_pca_features=3,
                 scale_n_pca_by_depth=True,
             ),
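The new `set_colors()`/`get_colors()` pair in benchmark_tools above assigns one stable color per case key, which `plot_run_times()` and `plot_summary_errors()` now reuse so that a case keeps the same color across figures. A minimal sketch, assuming an existing `study`:

# sketch only: `study` is an assumed, already-created benchmark study
study.set_colors(map_name="tab20")  # or pass an explicit {case_key: color} dict
fig = study.plot_run_times()        # bars are colored via study.get_colors()
fig.savefig("run_times.png")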
diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py
index 05cd1a7024..337e253cfa 100644
--- a/src/spikeinterface/widgets/utils.py
+++ b/src/spikeinterface/widgets/utils.py
@@ -20,7 +20,9 @@
     HAVE_MPL = False


-def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None):
+def get_some_colors(
+    keys, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None, margin=None
+):
     """
     Return a dict of colors for given keys

@@ -39,6 +41,8 @@
         * set to False for distinctipy
     seed: int or None, default: None
         Set the seed
+    margin: None or int, default: None
+        If None, a default margin is used to avoid the colors at the borders of some matplotlib colormaps.

     Returns
     -------
@@ -75,9 +79,10 @@

     elif color_engine == "matplotlib":
         # some map have black or white at border so +10
-        margin = max(4, int(N * 0.08))
-        cmap = plt.colormaps[map_name].resampled(N + 2 * margin)
+        if margin is None:
+            margin = max(4, int(N * 0.08))
+        cmap = plt.colormaps[map_name].resampled(N + 2 * margin)
         colors = [cmap(i + margin) for i, key in enumerate(keys)]

     elif color_engine == "colorsys":
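For illustration (hypothetical keys): `set_colors()` above passes `margin=0` to use the full colormap range, which suits qualitative maps such as `tab20` that have no dark borders to trim, while `margin=None` keeps the previous automatic margin:

from spikeinterface.widgets import get_some_colors

# "case_a"/"case_b" are hypothetical keys for this sketch
colors = get_some_colors(["case_a", "case_b"], color_engine="matplotlib", map_name="tab20", margin=0)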