From 9a198325e7ee084c1aa66f629de72efcce576943 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Tue, 14 Jan 2025 21:13:59 +0100
Subject: [PATCH] New iterative and recursive splits to avoid leaking objects

---
 .../sortingcomponents/clustering/circus.py |   8 +-
 .../sortingcomponents/clustering/split.py  | 119 ++++++++++++++++--
 2 files changed, 114 insertions(+), 13 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 60937a2c54..44c0b01e48 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -52,7 +52,7 @@ class CircusClustering:
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
             "recursive": True,
-            "recursive_depth": 3,
+            # "recursive_depth": 3,
             "returns_split_count": True,
         },
         "radius_um": 100,
@@ -187,11 +187,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             np.save(features_folder / "peaks.npy", peaks)
 
         original_labels = peaks["channel_index"]
-        from spikeinterface.sortingcomponents.clustering.split import split_clusters
+        from spikeinterface.sortingcomponents.clustering.split import split_clusters_alternative
 
         min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)
 
-        peak_labels, _ = split_clusters(
+        peak_labels, _ = split_clusters_alternative(
             original_labels,
             recording,
             features_folder,
@@ -203,7 +203,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
                 waveforms_sparse_mask=sparse_mask,
                 min_size_split=min_size,
                 clusterer_kwargs=d["hdbscan_kwargs"],
-                n_pca_features=[2, 4, 6, 8, 10],
+                n_pca_features=[2, 4, 8, 16],
             ),
             **params["recursive_kwargs"],
             **job_kwargs,
diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py
index 7a74d75b10..700bcaf3e7 100644
--- a/src/spikeinterface/sortingcomponents/clustering/split.py
+++ b/src/spikeinterface/sortingcomponents/clustering/split.py
@@ -42,7 +42,7 @@ def split_clusters(
         Recording object
     features_dict_or_folder: dict or folder
         A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features
-    method: str, default: "hdbscan_on_local_pca"
+    method: str, default: "local_feature_clustering"
        The method name
     method_kwargs: dict, default: dict()
        The method option
@@ -82,6 +82,7 @@ def split_clusters(
     labels_set = np.setdiff1d(peak_labels, [-1])
     current_max_label = np.max(labels_set) + 1
     jobs = []
+
     for label in labels_set:
         peak_indices = np.flatnonzero(peak_labels == label)
         if peak_indices.size > 0:
@@ -129,6 +130,108 @@
     return peak_labels
 
 
+def split_clusters_alternative(
+    peak_labels,
+    recording,
+    features_dict_or_folder,
+    method="local_feature_clustering",
+    method_kwargs={},
+    recursive=True,
+    returns_split_count=False,
+    **job_kwargs,
+):
+    """
+    Run a local split method, recursively or not, in a multiprocessing pool.
+
+    Parameters
+    ----------
+    peak_labels: numpy.array
+        Peak labels before the split
+    recording: Recording
+        Recording object
+    features_dict_or_folder: dict or folder
+        A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features
+    method: str, default: "local_feature_clustering"
+        The method name
+    method_kwargs: dict, default: dict()
+        The method options
+    recursive: bool, default: True
+        Recursive or not. If True, splits are done until no more splits are possible.
+    returns_split_count: bool, default: False
+        Optionally return the split count vector. Same size as labels.
+
+    Returns
+    -------
+    new_labels: numpy.ndarray
+        The labels of peaks after split.
+    split_count: numpy.ndarray
+        Optionally returned
+    """
+
+    job_kwargs = fix_job_kwargs(job_kwargs)
+    n_jobs = job_kwargs["n_jobs"]
+    mp_context = job_kwargs.get("mp_context", None)
+    progress_bar = job_kwargs["progress_bar"]
+    max_threads_per_process = job_kwargs.get("max_threads_per_process", 1)
+    original_labels = peak_labels
+    peak_labels = peak_labels.copy()
+    split_count = np.zeros(peak_labels.size, dtype=int)
+    Executor = get_poolexecutor(n_jobs)
+
+    with Executor(
+        max_workers=n_jobs,
+        initializer=split_worker_init,
+        mp_context=get_context(method=mp_context),
+        initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process),
+    ) as pool:
+
+        has_been_splitted = np.ones(len(peak_labels), dtype=bool)
+        recursion_level = 1
+
+        while np.any(has_been_splitted):
+
+            labels_set = np.setdiff1d(peak_labels[has_been_splitted], [-1])
+            if labels_set.size == 0:
+                break
+
+            current_max_label = np.max(labels_set) + 1
+            jobs = []
+
+            for label in labels_set:
+                peak_indices = np.flatnonzero((peak_labels == label) & has_been_splitted)
+                if peak_indices.size > 0:
+                    jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
+
+            if progress_bar and recursion_level == 1:
+                iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set))
+            else:
+                iterator = jobs
+
+            for res in iterator:
+                is_split, local_labels, peak_indices = res.result()
+
+                if not is_split:
+                    has_been_splitted[peak_indices] = False
+                    continue
+
+                mask = local_labels >= 0
+                peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
+                peak_labels[peak_indices[~mask]] = local_labels[~mask]
+                split_count[peak_indices] += 1
+                current_max_label += np.max(local_labels[mask]) + 1
+
+            if not recursive:
+                break
+
+            recursion_level += 1
+
+    if returns_split_count:
+        return peak_labels, split_count
+    else:
+        return peak_labels
+
+
 global _ctx
 
 
@@ -186,6 +289,7 @@ def split(
         min_size_split=25,
         n_pca_features=2,
         minimum_overlap_ratio=0.25,
+        debug=False,
     ):
 
         local_labels = np.zeros(peak_indices.size, dtype=np.int64)
@@ -259,9 +363,7 @@ def split(
         else:
             raise ValueError(f"wrong clusterer {clusterer}. Possible options are 'hdbscan' or 'isocut5'.")
 
-        # DEBUG = True
-        DEBUG = False
-        if DEBUG:
+        if debug:
             import matplotlib.pyplot as plt
 
             labels_set = np.setdiff1d(possible_labels, [-1])
@@ -283,11 +385,10 @@ def split(
                     ax.set_xlabel("PCA features")
 
             axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {n_pca}, recursion_level={recursion_level}")
-            import time
-
-            plt.savefig(f"split_{recursion_level}_{time.time()}.png")
-            plt.close()
-            # plt.show()
+            # import time
+            # plt.savefig(f"split_{recursion_level}_{time.time()}.png")
+            # plt.close()
+            # plt.show()
 
         if is_split:
             break
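
Note (not part of the patch): unlike split_clusters, which fans the recursion out inside the
worker jobs, split_clusters_alternative keeps a single worklist in the parent process: every
label starts as a split candidate, each pass submits one job per candidate, and a cluster
leaves the worklist only when its job reports no further split. This iterative scheme is
presumably what the subject line's "leaking objects" refers to, since no objects accumulate
across recursive pool calls. A minimal calling sketch mirroring the circus.py change above;
the features_folder path, sparse_mask and the hdbscan values are illustrative placeholders,
not fixed by this patch:

    from spikeinterface.sortingcomponents.clustering.split import split_clusters_alternative

    peak_labels, split_count = split_clusters_alternative(
        original_labels,                  # initial labels, e.g. peaks["channel_index"]
        recording,
        features_folder,                  # folder of npz features from peak_pipeline
        method="local_feature_clustering",
        method_kwargs=dict(
            clusterer="hdbscan",                   # or "isocut5", per the split() check
            waveforms_sparse_mask=sparse_mask,     # placeholder sparsity mask
            min_size_split=50,                     # placeholder, matches min_cluster_size
            clusterer_kwargs={"min_cluster_size": 50},
            n_pca_features=[2, 4, 8, 16],          # same ladder as the circus.py change
        ),
        recursive=True,                   # iterate until no cluster splits further
        returns_split_count=True,
        n_jobs=4,
    )

Because recursive=True now means "iterate until stable" rather than "recurse to a fixed
depth", the recursive_depth key is commented out of the CircusClustering defaults above.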