From 5781d9e63a6ea26a20fee770496af265dd25ed70 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 15 Jan 2025 10:09:46 +0100 Subject: [PATCH] Trying to change tqdm behavior --- .../sortingcomponents/clustering/circus.py | 6 +- .../sortingcomponents/clustering/split.py | 116 ++---------------- 2 files changed, 13 insertions(+), 109 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 8fce498fd9..24b5c7d7df 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -52,7 +52,7 @@ class CircusClustering: "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, - # "recursive_depth": 3, + "recursive_depth": 3, "returns_split_count": True, }, "radius_um": 100, @@ -187,11 +187,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): np.save(features_folder / "peaks.npy", peaks) original_labels = peaks["channel_index"] - from spikeinterface.sortingcomponents.clustering.split import split_clusters_alternative + from spikeinterface.sortingcomponents.clustering.split import split_clusters min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50) - peak_labels, _ = split_clusters_alternative( + peak_labels, _ = split_clusters( original_labels, recording, features_folder, diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 4127460f36..5e613f9cec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -89,15 +89,16 @@ def split_clusters( jobs.append(pool.submit(split_function_wrapper, peak_indices, 1)) if progress_bar: - iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) - else: - iterator = jobs + pbar = tqdm(desc=f"split_clusters with {method}", total=len(labels_set)) - for res in iterator: + for res in jobs: is_split, local_labels, peak_indices = res.result() # print(is_split, local_labels, peak_indices) if not is_split: continue + + if progress_bar: + pbar.update(1) mask = local_labels >= 0 peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label @@ -122,108 +123,11 @@ def split_clusters( # print('Relaunched', label, len(peak_indices), recursion_level) jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level)) if progress_bar: - iterator.total += 1 - - if returns_split_count: - return peak_labels, split_count - else: - return peak_labels - - -def split_clusters_alternative( - peak_labels, - recording, - features_dict_or_folder, - method="local_feature_clustering", - method_kwargs={}, - recursive=True, - returns_split_count=False, - **job_kwargs, -): - """ - Run recusrsively (or not) in a multi process pool a local split method. - - Parameters - ---------- - peak_labels: numpy.array - Peak label before split - recording: Recording - Recording object - features_dict_or_folder: dict or folder - A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features - method: str, default: "local_feature_clustering" - The method name - method_kwargs: dict, default: dict() - The method option - recursive: bool, default: True - Recursive or not. If True, splits are done until no more splits are possible. - returns_split_count: bool, default: False - Optionally return the split count vector. Same size as labels - - Returns - ------- - new_labels: numpy.ndarray - The labels of peaks after split. - split_count: numpy.ndarray - Optionally returned - """ - - job_kwargs = fix_job_kwargs(job_kwargs) - n_jobs = job_kwargs["n_jobs"] - mp_context = job_kwargs.get("mp_context", None) - progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) - original_labels = peak_labels - peak_labels = peak_labels.copy() - split_count = np.zeros(peak_labels.size, dtype=int) - Executor = get_poolexecutor(n_jobs) - - with Executor( - max_workers=n_jobs, - initializer=split_worker_init, - mp_context=get_context(method=mp_context), - initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), - ) as pool: - - has_been_splitted = np.ones(len(peak_labels), dtype=bool) - recursion_level = 1 - - while np.any(has_been_splitted): - - labels_set = np.setdiff1d(peak_labels[has_been_splitted], [-1]) - if labels_set.size == 0: - break - - current_max_label = np.max(labels_set) + 1 - jobs = [] - - for label in labels_set: - peak_indices = np.flatnonzero((peak_labels == label) * has_been_splitted) - if peak_indices.size > 0: - jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level)) - - if progress_bar and recursion_level == 1: - iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) - else: - iterator = jobs - - for res in iterator: - is_split, local_labels, peak_indices = res.result() - - if not is_split: - has_been_splitted[peak_indices] = False - continue - - mask = local_labels >= 0 - peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label - peak_labels[peak_indices[~mask]] = local_labels[~mask] - split_count[peak_indices] += 1 - current_max_label += np.max(local_labels[mask]) + 1 - - if not recursive: - break - - recursion_level += 1 + pbar.total += 1 + + if progress_bar: + pbar.close() + del pbar if returns_split_count: return peak_labels, split_count