From 747fc0b5eb5d0f6d02c30a67de95f6ab2b57299a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 26 Apr 2024 16:21:59 +0200 Subject: [PATCH] Benchmarks for sorting components Improve benchmark sorting components --- .../benchmark/benchmark_clustering.py | 544 ++++++------------ .../benchmark/benchmark_matching.py | 77 ++- .../benchmark/benchmark_peak_detection.py | 424 ++++---------- .../benchmark/benchmark_peak_localization.py | 67 +-- .../benchmark/benchmark_peak_selection.py | 86 +-- .../tests/test_benchmark_clustering.py | 23 +- .../tests/test_benchmark_peak_detection.py | 76 +++ .../tests/test_benchmark_peak_localization.py | 17 +- 8 files changed, 496 insertions(+), 818 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index fbdf939fae..9d7d202098 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -30,18 +30,10 @@ class ClusteringBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): + def __init__(self, recording, gt_sorting, params, indices, peaks, exhaustive_gt=True): self.recording = recording self.gt_sorting = gt_sorting self.indices = indices - - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) - sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") - - peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - if self.indices is None: - self.indices = np.arange(len(peaks)) self.peaks = peaks[self.indices] self.params = params self.exhaustive_gt = exhaustive_gt @@ -64,7 +56,9 @@ def compute_result(self, **result_params): ) data = spikes[self.indices][~self.noise] - data["unit_index"] = self.result["peak_labels"][~self.noise] + # data["unit_index"] = self.result["peak_labels"][~self.noise] + positions = self.gt_sorting.get_property("gt_unit_locations") + self.result["sliced_gt_sorting"].set_property("gt_unit_locations", positions) self.result["clustering"] = NumpySorting.from_times_labels( data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency @@ -135,6 +129,47 @@ def homogeneity_score(self, ignore_noise=True, case_keys=None): np.mean(noise), ) + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.get_result(case_keys[0])["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + assert comp is not None, "You need to do study.run_comparisons() first" + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + def plot_unit_counts(self, case_keys=None, figsize=None, **extra_kwargs): + from spikeinterface.widgets.widget_list import plot_study_unit_counts + + plot_study_unit_counts(self, case_keys, figsize=figsize, **extra_kwargs) + def plot_agreements(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -184,8 +219,8 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): unit_ids2 = scores.columns.values inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1) inds_2 = result["gt_comparison"].sorting2.ids_to_indices(unit_ids2) - t1 = result["sliced_gt_templates"].templates_array - t2 = result["clustering_templates"].templates_array + t1 = result["sliced_gt_templates"].templates_array[:] + t2 = result["clustering_templates"].templates_array[:] a = t1.reshape(len(t1), -1)[inds_1] b = t2.reshape(len(t2), -1)[inds_2] @@ -196,13 +231,13 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): else: distances = sklearn.metrics.pairwise_distances(a, b, metric) - im = axs[count].imshow(distances, aspect="auto") - axs[count].set_title(metric) - fig.colorbar(im, ax=axs[count]) + im = axs[0, count].imshow(distances, aspect="auto") + axs[0, count].set_title(metric) + fig.colorbar(im, ax=axs[0, count]) label = self.cases[key]["label"] axs[0, count].set_title(label) - def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): + def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -212,17 +247,20 @@ def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): for count, key in enumerate(case_keys): result = self.get_result(key) - scores = result["gt_comparison"].get_ordered_agreement_scores() + scores = result["gt_comparison"].agreement_scores analyzer = self.get_sorting_analyzer(key) metrics = analyzer.get_extension("quality_metrics").get_data() - unit_ids1 = scores.index.values - unit_ids2 = scores.columns.values - inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1) - inds_2 = result["gt_comparison"].sorting2.ids_to_indices(unit_ids2) - t1 = result["sliced_gt_templates"].templates_array - t2 = result["clustering_templates"].templates_array + unit_ids1 = result["gt_comparison"].unit1_ids + matched_ids2 = result["gt_comparison"].hungarian_match_12.values + mask = matched_ids2 > -1 + + inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1[mask]) + inds_2 = result["gt_comparison"].sorting2.ids_to_indices(matched_ids2[mask]) + + t1 = result["sliced_gt_templates"].templates_array[:] + t2 = result["clustering_templates"].templates_array[:] a = t1.reshape(len(t1), -1) b = t2.reshape(len(t2), -1) @@ -230,18 +268,111 @@ def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): if metric == "cosine": distances = sklearn.metrics.pairwise.cosine_similarity(a, b) - else: + elif metric == "l2": + distances = sklearn.metrics.pairwise_distances(a, b, metric) + + snr_matched = metrics["snr"][unit_ids1[mask]] + snr_missed = metrics["snr"][unit_ids1[~mask]] + + to_plot = [] + if metric in ["cosine", "l2"]: + for found, real in zip(inds_2, inds_1): + to_plot += [distances[real, found]] + elif metric == "agreement": + for found, real in zip(matched_ids2[mask], unit_ids1[mask]): + to_plot += [scores.at[real, found]] + axs[0, count].plot(snr_matched, to_plot, ".", label="matched") + axs[0, count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed") + axs[0, count].set_xlabel("snr") + axs[0, count].set_ylabel(metric) + label = self.cases[key]["label"] + axs[0, count].set_title(label) + axs[0, count].legend() + + def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + + for count, key in enumerate(case_keys): + + result = self.get_result(key) + scores = result["gt_comparison"].agreement_scores + + # positions = result["gt_comparison"].sorting1.get_property('gt_unit_locations') + positions = self.datasets[key[1]][1].get_property("gt_unit_locations") + depth = positions[:, 1] + + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension("quality_metrics").get_data() + + unit_ids1 = result["gt_comparison"].unit1_ids + matched_ids2 = result["gt_comparison"].hungarian_match_12.values + mask = matched_ids2 > -1 + + inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1[mask]) + inds_2 = result["gt_comparison"].sorting2.ids_to_indices(matched_ids2[mask]) + + t1 = result["sliced_gt_templates"].templates_array[:] + t2 = result["clustering_templates"].templates_array[:] + a = t1.reshape(len(t1), -1) + b = t2.reshape(len(t2), -1) + + import sklearn + + if metric == "cosine": + distances = sklearn.metrics.pairwise.cosine_similarity(a, b) + elif metric == "l2": distances = sklearn.metrics.pairwise_distances(a, b, metric) - snr = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] + snr_matched = metrics["snr"][unit_ids1[mask]] + snr_missed = metrics["snr"][unit_ids1[~mask]] + depth_matched = depth[mask] + depth_missed = depth[~mask] + to_plot = [] - for found, real in zip(inds_2, inds_1): - to_plot += [distances[real, found]] - axs[0, count].plot(snr, to_plot, ".") + if metric in ["cosine", "l2"]: + for found, real in zip(inds_2, inds_1): + to_plot += [distances[real, found]] + elif metric == "agreement": + for found, real in zip(matched_ids2[mask], unit_ids1[mask]): + to_plot += [scores.at[real, found]] + axs[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched") + axs[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed") axs[0, count].set_xlabel("snr") axs[0, count].set_ylabel(metric) label = self.cases[key]["label"] axs[0, count].set_title(label) + axs[0, count].legend() + + def plot_unit_losses(self, before, after, metric="agreement", figsize=None): + + fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + + for count, k in enumerate(("accuracy", "recall", "precision")): + + ax = axs[count] + + label = self.cases[after]["label"] + + positions = self.get_result(before)["gt_comparison"].sorting1.get_property("gt_unit_locations") + + analyzer = self.get_sorting_analyzer(before) + metrics_before = analyzer.get_extension("quality_metrics").get_data() + x = metrics_before["snr"].values + + y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values + y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values + if count < 2: + ax.set_xticks([], []) + elif count == 2: + ax.set_xlabel("depth (um)") + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper") + fig.colorbar(im, ax=ax) + ax.set_title(k) + ax.set_ylabel("snr") def plot_comparison_clustering( self, @@ -299,360 +430,5 @@ def plot_comparison_clustering( ax.spines["right"].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - -# def _scatter_clusters( -# self, -# xs, -# ys, -# sorting, -# colors=None, -# labels=None, -# ax=None, -# n_std=2.0, -# force_black_for=[], -# s=1, -# alpha=0.5, -# show_ellipses=True, -# ): -# if colors is None: -# from spikeinterface.widgets import get_unit_colors - -# colors = get_unit_colors(sorting) - -# from matplotlib.patches import Ellipse -# import matplotlib.transforms as transforms - -# ax = ax or plt.gca() -# # scatter and collect gaussian info -# means = {} -# covs = {} -# labels = sorting.to_spike_vector(concatenated=False)[0]["unit_index"] - -# for unit_ind, unit_id in enumerate(sorting.unit_ids): -# where = np.flatnonzero(labels == unit_ind) - -# xk = xs[where] -# yk = ys[where] - -# if unit_id not in force_black_for: -# ax.scatter(xk, yk, s=s, color=colors[unit_id], alpha=alpha, marker=".") -# x_mean, y_mean = xk.mean(), yk.mean() -# xycov = np.cov(xk, yk) -# means[unit_id] = x_mean, y_mean -# covs[unit_id] = xycov -# ax.annotate(unit_id, (x_mean, y_mean)) -# ax.scatter([x_mean], [y_mean], s=50, c="k") -# else: -# ax.scatter(xk, yk, s=s, color="k", alpha=alpha, marker=".") - -# for unit_id in means.keys(): -# mean_x, mean_y = means[unit_id] -# cov = covs[unit_id] - -# with np.errstate(invalid="ignore"): -# vx, vy = cov[0, 0], cov[1, 1] -# rho = cov[0, 1] / np.sqrt(vx * vy) -# if not np.isfinite([vx, vy, rho]).all(): -# continue - -# if show_ellipses: -# ell = Ellipse( -# (0, 0), -# width=2 * np.sqrt(1 + rho), -# height=2 * np.sqrt(1 - rho), -# facecolor=(0, 0, 0, 0), -# edgecolor=colors[unit_id], -# linewidth=1, -# ) -# transform = ( -# transforms.Affine2D() -# .rotate_deg(45) -# .scale(n_std * np.sqrt(vx), n_std * np.sqrt(vy)) -# .translate(mean_x, mean_y) -# ) -# ell.set_transform(transform + ax.transData) -# ax.add_patch(ell) - -# def plot_clusters(self, show_probe=True, show_ellipses=True): -# fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 10)) -# fig.suptitle(f"Clustering results with {self.method}") -# ax = axs[0] -# ax.set_title("Full gt clusters") -# if show_probe: -# plot_probe_map(self.recording_f, ax=ax) - -# from spikeinterface.widgets import get_unit_colors - -# colors = get_unit_colors(self.gt_sorting) -# self._scatter_clusters( -# self.gt_positions["x"], -# self.gt_positions["y"], -# self.gt_sorting, -# colors, -# s=1, -# alpha=0.5, -# ax=ax, -# show_ellipses=show_ellipses, -# ) -# xlim = ax.get_xlim() -# ylim = ax.get_ylim() -# ax.set_xlabel("x") -# ax.set_ylabel("y") - -# ax = axs[1] -# ax.set_title("Sliced gt clusters") -# if show_probe: -# plot_probe_map(self.recording_f, ax=ax) - -# self._scatter_clusters( -# self.sliced_gt_positions["x"], -# self.sliced_gt_positions["y"], -# self.sliced_gt_sorting, -# colors, -# s=1, -# alpha=0.5, -# ax=ax, -# show_ellipses=show_ellipses, -# ) -# if self.exhaustive_gt: -# ax.set_xlim(xlim) -# ax.set_ylim(ylim) -# ax.set_xlabel("x") -# ax.set_yticks([], []) - -# ax = axs[2] -# ax.set_title("Found clusters") -# if show_probe: -# plot_probe_map(self.recording_f, ax=ax) -# ax.scatter(self.positions["x"][self.noise], self.positions["y"][self.noise], c="k", s=1, alpha=0.1) -# self._scatter_clusters( -# self.positions["x"][~self.noise], -# self.positions["y"][~self.noise], -# self.clustering, -# s=1, -# alpha=0.5, -# ax=ax, -# show_ellipses=show_ellipses, -# ) - -# ax.set_xlabel("x") -# if self.exhaustive_gt: -# ax.set_xlim(xlim) -# ax.set_ylim(ylim) -# ax.set_yticks([], []) - -# def plot_found_clusters(self, show_probe=True, show_ellipses=True): -# fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(10, 10)) -# fig.suptitle(f"Clustering results with {self.method}") -# ax.set_title("Found clusters") -# if show_probe: -# plot_probe_map(self.recording_f, ax=ax) -# ax.scatter(self.positions["x"][self.noise], self.positions["y"][self.noise], c="k", s=1, alpha=0.1) -# self._scatter_clusters( -# self.positions["x"][~self.noise], -# self.positions["y"][~self.noise], -# self.clustering, -# s=1, -# alpha=0.5, -# ax=ax, -# show_ellipses=show_ellipses, -# ) - -# ax.set_xlabel("x") -# if self.exhaustive_gt: -# ax.set_yticks([], []) - -# def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5): -# fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10)) - -# fig.suptitle(f"Clustering results with {self.method}") -# metrics = compute_quality_metrics(self.waveforms["gt"], metric_names=["snr"], load_if_exists=False) - -# ax = axs[0, 0] -# plot_agreement_matrix(self.comp, ax=ax) -# scores = self.comp.get_ordered_agreement_scores() -# ymin, ymax = ax.get_ylim() -# xmin, xmax = ax.get_xlim() -# unit_ids1 = scores.index.values -# unit_ids2 = scores.columns.values -# inds_1 = self.comp.sorting1.ids_to_indices(unit_ids1) -# snrs = metrics["snr"][inds_1] - -# nb_detectable = len(unit_ids1) - -# if detect_threshold is not None: -# for count, snr in enumerate(snrs): -# if snr < detect_threshold: -# ax.plot([xmin, xmax], [count, count], "k") -# nb_detectable -= 1 - -# ax.plot([nb_detectable + 0.5, nb_detectable + 0.5], [ymin, ymax], "r") - -# # import MEArec as mr -# # mearec_recording = mr.load_recordings(self.mearec_file) -# # positions = mearec_recording.template_locations[:] - -# # self.found_positions = np.zeros((len(self.labels), 2)) -# # for i in range(len(self.labels)): -# # data = self.positions[self.selected_peaks_labels == self.labels[i]] -# # self.found_positions[i] = np.median(data['x']), np.median(data['y']) - -# unit_ids1 = scores.index.values -# unit_ids2 = scores.columns.values -# inds_1 = self.comp.sorting1.ids_to_indices(unit_ids1) -# inds_2 = self.comp.sorting2.ids_to_indices(unit_ids2) - -# a = self.templates["gt"].reshape(len(self.templates["gt"]), -1)[inds_1] -# b = self.templates["clustering"].reshape(len(self.templates["clustering"]), -1)[inds_2] - -# import sklearn - -# if metric == "cosine": -# distances = sklearn.metrics.pairwise.cosine_similarity(a, b) -# else: -# distances = sklearn.metrics.pairwise_distances(a, b, metric) - -# ax = axs[0, 1] -# nb_peaks = np.array( -# [len(self.sliced_gt_sorting.get_unit_spike_train(i)) for i in self.sliced_gt_sorting.unit_ids] -# ) - -# nb_potentials = np.sum(scores.max(1).values > 0.1) - -# ax.plot( -# metrics["snr"][unit_ids1][inds_1[:nb_potentials]], -# nb_peaks[inds_1[:nb_potentials]], -# markersize=10, -# marker=".", -# ls="", -# c="k", -# label="Cluster potentially found", -# ) -# ax.plot( -# metrics["snr"][unit_ids1][inds_1[nb_potentials:]], -# nb_peaks[inds_1[nb_potentials:]], -# markersize=10, -# marker=".", -# ls="", -# c="r", -# label="Cluster clearly missed", -# ) - -# if annotations: -# for l, x, y in zip( -# unit_ids1[: len(inds_2)], -# metrics["snr"][unit_ids1][inds_1[: len(inds_2)]], -# nb_peaks[inds_1[: len(inds_2)]], -# ): -# ax.annotate(l, (x, y)) - -# for l, x, y in zip( -# unit_ids1[len(inds_2) :], -# metrics["snr"][unit_ids1][inds_1[len(inds_2) :]], -# nb_peaks[inds_1[len(inds_2) :]], -# ): -# ax.annotate(l, (x, y), c="r") - -# if detect_threshold is not None: -# ymin, ymax = ax.get_ylim() -# ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") - -# ax.legend() -# ax.set_xlabel("template snr") -# ax.set_ylabel("nb spikes") -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) - -# ax = axs[0, 2] -# im = ax.imshow(distances, aspect="auto") -# ax.set_title(metric) -# fig.colorbar(im, ax=ax) - -# if detect_threshold is not None: -# for count, snr in enumerate(snrs): -# if snr < detect_threshold: -# ax.plot([xmin, xmax], [count, count], "w") - -# ymin, ymax = ax.get_ylim() -# ax.plot([nb_detectable + 0.5, nb_detectable + 0.5], [ymin, ymax], "r") - -# ax.set_yticks(np.arange(0, len(scores.index))) -# ax.set_yticklabels(scores.index, fontsize=8) - -# res = [] -# nb_spikes = [] -# energy = [] -# nb_channels = [] - -# noise_levels = get_noise_levels(self.recording_f, return_scaled=False) - -# for found, real in zip(unit_ids2, unit_ids1): -# wfs = self.waveforms["clustering"].get_waveforms(found) -# wfs_real = self.waveforms["gt"].get_waveforms(real) -# template = self.waveforms["clustering"].get_template(found) -# template_real = self.waveforms["gt"].get_template(real) -# nb_channels += [np.sum(np.std(template_real, 0) < noise_levels)] - -# wfs = wfs.reshape(len(wfs), -1) -# template = template.reshape(template.size, 1).T -# template_real = template_real.reshape(template_real.size, 1).T - -# if metric == "cosine": -# dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist() -# else: -# dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist() -# res += dist -# nb_spikes += [self.sliced_gt_sorting.get_unit_spike_train(real).size] -# energy += [np.linalg.norm(template_real)] - -# ax = axs[1, 0] -# res = np.array(res) -# nb_spikes = np.array(nb_spikes) -# nb_channels = np.array(nb_channels) -# energy = np.array(energy) - -# snrs = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] -# cm = ax.scatter(snrs, nb_spikes, c=res) -# ax.set_xlabel("template snr") -# ax.set_ylabel("nb spikes") -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) -# cb = fig.colorbar(cm, ax=ax) -# cb.set_label(metric) -# if detect_threshold is not None: -# ymin, ymax = ax.get_ylim() -# ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") - -# if annotations: -# for l, x, y in zip(unit_ids1[: len(inds_2)], snrs, nb_spikes): -# ax.annotate(l, (x, y)) - -# ax = axs[1, 1] -# cm = ax.scatter(energy, nb_channels, c=res) -# ax.set_xlabel("template energy") -# ax.set_ylabel("nb channels") -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) -# cb = fig.colorbar(cm, ax=ax) -# cb.set_label(metric) - -# if annotations: -# for l, x, y in zip(unit_ids1[: len(inds_2)], energy, nb_channels): -# ax.annotate(l, (x, y)) - -# ax = axs[1, 2] -# for performance_name in ["accuracy", "recall", "precision"]: -# perf = self.comp.get_performance()[performance_name] -# ax.plot(metrics["snr"], perf, markersize=10, marker=".", ls="", label=performance_name) -# ax.set_xlabel("template snr") -# ax.set_ylabel("performance") -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) -# ax.legend() -# if detect_threshold is not None: -# ymin, ymax = ax.get_ylim() -# ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") - -# plt.tight_layout() + plt.tight_layout(h_pad=0, w_pad=0) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index bb6d0f7683..5dd0778f76 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -1,8 +1,6 @@ from __future__ import annotations -from spikeinterface.postprocessing import compute_template_similarity from spikeinterface.sortingcomponents.matching import find_spikes_from_templates -from spikeinterface.core.template import Templates from spikeinterface.core import NumpySorting from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth from spikeinterface.widgets import ( @@ -10,12 +8,10 @@ plot_comparison_collision_by_similarity, ) -from pathlib import Path import pylab as plt import matplotlib.patches as mpatches import numpy as np -import pandas as pd -from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype @@ -173,3 +169,74 @@ def plot_comparison_matching( ax.set_xticks([]) ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.get_result(case_keys[0])["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + assert comp is not None, "You need to do study.run_comparisons() first" + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + def plot_unit_counts(self, case_keys=None, figsize=None): + from spikeinterface.widgets.widget_list import plot_study_unit_counts + + plot_study_unit_counts(self, case_keys, figsize=figsize) + + def plot_unit_losses(self, before, after, figsize=None): + + fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + + for count, k in enumerate(("accuracy", "recall", "precision")): + + ax = axs[count] + + label = self.cases[after]["label"] + + positions = self.get_result(before)["gt_comparison"].sorting1.get_property("gt_unit_locations") + + analyzer = self.get_sorting_analyzer(before) + metrics_before = analyzer.get_extension("quality_metrics").get_data() + x = metrics_before["snr"].values + + y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values + y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values + if count < 2: + ax.set_xticks([], []) + elif count == 2: + ax.set_xlabel("depth (um)") + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper") + fig.colorbar(im, ax=ax) + ax.set_title(k) + ax.set_ylabel("snr") + + # if count == 2: + # ax.legend() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index 18e736e3aa..09220d162a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -21,7 +21,7 @@ import os import numpy as np -from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -29,31 +29,36 @@ class PeakDetectionBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, params, exhaustive_gt=True): + def __init__(self, recording, gt_sorting, params, gt_peaks, exhaustive_gt=True, delta_t_ms=0.2): self.recording = recording self.gt_sorting = gt_sorting - - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) - sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") - self.gt_peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + self.gt_peaks = gt_peaks self.params = params self.exhaustive_gt = exhaustive_gt - self.method = params["method"] - self.method_kwargs = params["method_kwargs"] - self.result = {"gt_peaks": self.gt_peaks} - self.result["gt_amplitudes"] = sorting_analyzer.get_extension("spike_amplitudes").get_data() + assert "method" in self.params, "Method should be specified in the params!" + self.method = self.params.get("method") + self.delta_frames = int(delta_t_ms * self.recording.sampling_frequency / 1000) + self.params = self.params["method_kwargs"] + self.result = {} def run(self, **job_kwargs): - peaks = detect_peaks(self.recording, method=self.method, **self.method_kwargs, **job_kwargs) + peaks = detect_peaks(self.recording, self.method, **self.params, **job_kwargs) self.result["peaks"] = peaks def compute_result(self, **result_params): + + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) + sorting_analyzer.compute("random_spikes") + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes") + self.result["gt_amplitudes"] = sorting_analyzer.get_extension("spike_amplitudes").get_data() + self.result["gt_templates"] = sorting_analyzer.get_extension("templates").get_data() + spikes = self.result["peaks"] self.result["peak_on_channels"] = NumpySorting.from_peaks( spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids ) - spikes = self.result["gt_peaks"] + spikes = self.gt_peaks self.result["gt_on_channels"] = NumpySorting.from_peaks( spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids ) @@ -63,46 +68,56 @@ def compute_result(self, **result_params): ) gt_peaks = self.gt_sorting.to_spike_vector() - times1 = self.result["gt_peaks"]["sample_index"] - times2 = self.result["peaks"]["sample_index"] + peaks = self.result["peaks"] + times1 = peaks["sample_index"] + times2 = spikes["sample_index"] print("The gt recording has {} peaks and {} have been detected".format(len(times1), len(times2))) - matches = make_matching_events(times1, times2, int(0.4 * self.recording.sampling_frequency / 1000)) - self.matches = matches - self.gt_matches = matches["index1"] + matches = make_matching_events(times1, times2, self.delta_frames) + gt_matches = matches["index2"] + detected_matches = matches["index1"] - self.deltas = {"labels": [], "channels": [], "delta": matches["delta_frame"]} - self.deltas["labels"] = gt_peaks["unit_index"][self.gt_matches] - self.deltas["channels"] = self.result["gt_peaks"]["unit_index"][self.gt_matches] + self.result["matches"] = {"deltas": matches["delta_frame"]} + self.result["matches"]["labels"] = gt_peaks["unit_index"][gt_matches] + self.result["matches"]["channels"] = spikes["unit_index"][gt_matches] + sorting = np.zeros(gt_matches.size, dtype=minimum_spike_dtype) + sorting["sample_index"] = peaks[detected_matches]["sample_index"] + sorting["unit_index"] = gt_peaks["unit_index"][gt_matches] + sorting["segment_index"] = peaks[detected_matches]["segment_index"] + order = np.lexsort((sorting["sample_index"], sorting["segment_index"])) + sorting = sorting[order] self.result["sliced_gt_sorting"] = NumpySorting( - gt_peaks[self.gt_matches], self.recording.sampling_frequency, self.gt_sorting.unit_ids + sorting, self.recording.sampling_frequency, self.gt_sorting.unit_ids + ) + self.result["sliced_gt_comparison"] = GroundTruthComparison( + self.gt_sorting, self.result["sliced_gt_sorting"], exhaustive_gt=self.exhaustive_gt ) - ratio = 100 * len(self.gt_matches) / len(times1) + ratio = 100 * len(gt_matches) / len(times2) print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) - # matches = make_matching_events(times2, times1, int(delta * self.sampling_rate / 1000)) - # self.good_matches = matches["index1"] - - # garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) - # garbage_channels = self.peaks["channel_index"][garbage_matches] - # garbage_peaks = times2[garbage_matches] - # nb_garbage = len(garbage_peaks) - - # ratio = 100 * len(garbage_peaks) / len(times2) - # self.garbage_sorting = NumpySorting.from_times_labels(garbage_peaks, garbage_channels, self.sampling_rate) + sorting_analyzer = create_sorting_analyzer( + self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False + ) + sorting_analyzer.compute({"random_spikes": {}, "templates": {}}) - # print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) + self.result["templates"] = sorting_analyzer.get_extension("templates").get_data() - _run_key_saved = [("peaks", "npy"), ("gt_peaks", "npy"), ("gt_amplitudes", "npy")] + _run_key_saved = [("peaks", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), ("sliced_gt_sorting", "sorting"), + ("sliced_gt_comparison", "pickle"), + ("sliced_gt_sorting", "sorting"), ("peak_on_channels", "sorting"), ("gt_on_channels", "sorting"), + ("matches", "pickle"), + ("templates", "npy"), + ("gt_amplitudes", "npy"), + ("gt_templates", "npy"), ] @@ -118,7 +133,7 @@ def create_benchmark(self, key): benchmark = PeakDetectionBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark - def plot_agreements(self, case_keys=None, figsize=(15, 15)): + def plot_agreements_by_channels(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -129,7 +144,18 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)): ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): + def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + + for count, key in enumerate(case_keys): + ax = axs[0, count] + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["sliced_gt_comparison"], ax=ax) + + def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15), detect_threshold=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -144,14 +170,17 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): analyzer = self.get_sorting_analyzer(key) metrics = analyzer.get_extension("quality_metrics").get_data() x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values + y = self.get_result(key)["sliced_gt_comparison"].get_performance()[k].values ax.scatter(x, y, marker=".", label=label) ax.set_title(k) + if detect_threshold is not None: + ymin, ymax = ax.get_ylim() + ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") if count == 2: ax.legend() - def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5)): + def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -167,283 +196,66 @@ def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5)): ax.hist(data2, bins=bins, alpha=0.5, label="gt") ax.set_title(self.cases[key]["label"]) ax.legend() + if detect_threshold is not None: + noise_levels = get_noise_levels(self.benchmarks[key].recording, return_scaled=False).mean() + ymin, ymax = ax.get_ylim() + abs_threshold = -detect_threshold * noise_levels + ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--") + def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)): -# def run(self, peaks=None, positions=None, delta=0.2): -# t_start = time.time() - -# if peaks is not None: -# self._peaks = peaks - -# nb_peaks = len(self.peaks) - -# if positions is not None: -# self._positions = positions - -# spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0]["sample_index"] -# times2 = self.peaks["sample_index"] - -# print("The gt recording has {} peaks and {} have been detected".format(len(times1[0]), len(times2))) - -# matches = make_matching_events(spikes1["sample_index"], times2, int(delta * self.sampling_rate / 1000)) -# self.matches = matches - -# self.deltas = {"labels": [], "delta": matches["delta_frame"]} -# self.deltas["labels"] = spikes1["unit_index"][matches["index1"]] + if case_keys is None: + case_keys = list(self.cases.keys()) -# gt_matches = matches["index1"] -# self.sliced_gt_sorting = NumpySorting(spikes1[gt_matches], self.sampling_rate, self.gt_sorting.unit_ids) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + for count, key in enumerate(case_keys): + ax = axs[0, count] + gt_sorting = self.benchmarks[key].gt_sorting + data = self.get_result(key)["matches"] + for unit_ind, unit_id in enumerate(gt_sorting.unit_ids): + mask = data["labels"] == unit_id + ax.violinplot( + data["deltas"][mask], [unit_ind], widths=2, showmeans=True, showmedians=False, showextrema=False + ) + ax.set_title(self.cases[key]["label"]) + ax.set_xticks(np.arange(len(gt_sorting.unit_ids)), gt_sorting.unit_ids) + ax.set_ylabel("# frames") + ax.set_xlabel("unit id") -# ratio = 100 * len(gt_matches) / len(spikes1) -# print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) + def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5), detect_threshold=None): -# matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) -# self.good_matches = matches["index1"] + if case_keys is None: + case_keys = list(self.cases.keys()) -# garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) -# garbage_channels = self.peaks["channel_index"][garbage_matches] -# garbage_peaks = times2[garbage_matches] -# nb_garbage = len(garbage_peaks) + fig, ax = plt.subplots(ncols=1, nrows=1, figsize=figsize, squeeze=True) + for key in case_keys: -# ratio = 100 * len(garbage_peaks) / len(times2) -# self.garbage_sorting = NumpySorting.from_times_labels(garbage_peaks, garbage_channels, self.sampling_rate) + import sklearn -# print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) + gt_templates = self.get_result(key)["gt_templates"] + found_templates = self.get_result(key)["templates"] + num_templates = len(gt_templates) + distances = np.zeros(num_templates) -# self.comp = GroundTruthComparison(self.gt_sorting, self.sliced_gt_sorting, exhaustive_gt=self.exhaustive_gt) + for i in range(num_templates): -# for label, sorting in zip( -# ["gt", "full_gt", "garbage"], [self.sliced_gt_sorting, self.gt_sorting, self.garbage_sorting] -# ): -# tmp_folder = os.path.join(self.tmp_folder, label) -# if os.path.exists(tmp_folder): -# import shutil + a = gt_templates[i].flatten() + b = found_templates[i].flatten() -# shutil.rmtree(tmp_folder) + if metric == "cosine": + distances[i] = sklearn.metrics.pairwise.cosine_similarity(a[None, :], b[None, :])[0, 0] + else: + distances[i] = sklearn.metrics.pairwise_distances(a[None, :], b[None, :], metric)[0, 0] -# if not (label == "full_gt" and label in self.waveforms): -# if self.verbose: -# print(f"Extracting waveforms for {label}") + label = self.cases[key]["label"] + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension("quality_metrics").get_data() + x = metrics["snr"].values + ax.scatter(x, distances, marker=".", label=label) + if detect_threshold is not None: + ymin, ymax = ax.get_ylim() + ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") -# self.waveforms[label] = extract_waveforms( -# self.recording, -# sorting, -# tmp_folder, -# load_if_exists=True, -# ms_before=2.5, -# ms_after=3.5, -# max_spikes_per_unit=500, -# return_scaled=False, -# **self.job_kwargs, -# ) - -# self.templates[label] = self.waveforms[label].get_all_templates(mode="median") - -# if self.gt_peaks is None: -# if self.verbose: -# print("Computing gt peaks") -# gt_peaks_ = self.gt_sorting.to_spike_vector() -# self.gt_peaks = np.zeros( -# gt_peaks_.size, -# dtype=[ -# ("sample_index", "