Trying to change tqdm behavior
yger committed Jan 15, 2025
1 parent 169e83d commit 5781d9e
Showing 2 changed files with 13 additions and 109 deletions.
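
In short, the commit moves `split_clusters` from tqdm's iterator-wrapping idiom to a bar that is created once and driven by hand. A minimal standalone sketch of the two idioms (illustrative only, not spikeinterface code):

```python
from tqdm import tqdm

jobs = list(range(5))

# Old idiom: wrap the iterable. The bar advances once per loop
# iteration and its total is fixed when the bar is created.
for job in tqdm(jobs, desc="wrapped"):
    pass

# New idiom: create a bare bar and drive it explicitly. The caller
# decides when to advance it, and `total` can be grown later on.
pbar = tqdm(desc="manual", total=len(jobs))
for job in jobs:
    pbar.update(1)
pbar.close()
```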
6 changes: 3 additions & 3 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
```diff
@@ -52,7 +52,7 @@ class CircusClustering:
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
             "recursive": True,
-            # "recursive_depth": 3,
+            "recursive_depth": 3,
             "returns_split_count": True,
         },
         "radius_um": 100,
@@ -187,11 +187,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         np.save(features_folder / "peaks.npy", peaks)

         original_labels = peaks["channel_index"]
-        from spikeinterface.sortingcomponents.clustering.split import split_clusters_alternative
+        from spikeinterface.sortingcomponents.clustering.split import split_clusters

         min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)

-        peak_labels, _ = split_clusters_alternative(
+        peak_labels, _ = split_clusters(
             original_labels,
             recording,
             features_folder,
```
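
The first hunk above enables the previously commented-out `recursive_depth` option. As a rough illustration of what a depth cap buys in a recursive splitter, here is a self-contained toy (the real splitter clusters local features rather than cutting at a median; `split_recursively` and its arguments are hypothetical names, not the spikeinterface API):

```python
import numpy as np

def split_recursively(values, depth=0, recursive_depth=3, min_size=50):
    # Toy depth-capped recursive split: cut a group at its median until
    # the depth cap is reached or the pieces would become too small.
    if depth >= recursive_depth or values.size < 2 * min_size:
        return [values]
    pivot = np.median(values)
    left, right = values[values < pivot], values[values >= pivot]
    if left.size == 0 or right.size == 0:
        return [values]
    return (split_recursively(left, depth + 1, recursive_depth, min_size)
            + split_recursively(right, depth + 1, recursive_depth, min_size))

groups = split_recursively(np.random.default_rng(0).normal(size=1000))
print([g.size for g in groups])  # at most 2**recursive_depth groups
```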
116 changes: 10 additions & 106 deletions src/spikeinterface/sortingcomponents/clustering/split.py
```diff
@@ -89,15 +89,16 @@ def split_clusters(
                 jobs.append(pool.submit(split_function_wrapper, peak_indices, 1))

         if progress_bar:
-            iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set))
-        else:
-            iterator = jobs
+            pbar = tqdm(desc=f"split_clusters with {method}", total=len(labels_set))

-        for res in iterator:
+        for res in jobs:
             is_split, local_labels, peak_indices = res.result()
             # print(is_split, local_labels, peak_indices)
             if not is_split:
                 continue
+
+            if progress_bar:
+                pbar.update(1)

             mask = local_labels >= 0
             peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
@@ -122,108 +123,11 @@
                         # print('Relaunched', label, len(peak_indices), recursion_level)
                         jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
                         if progress_bar:
-                            iterator.total += 1
-
-    if returns_split_count:
-        return peak_labels, split_count
-    else:
-        return peak_labels
-
-
-def split_clusters_alternative(
-    peak_labels,
-    recording,
-    features_dict_or_folder,
-    method="local_feature_clustering",
-    method_kwargs={},
-    recursive=True,
-    returns_split_count=False,
-    **job_kwargs,
-):
-    """
-    Run recursively (or not) in a multi process pool a local split method.
-
-    Parameters
-    ----------
-    peak_labels: numpy.array
-        Peak label before split
-    recording: Recording
-        Recording object
-    features_dict_or_folder: dict or folder
-        A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features
-    method: str, default: "local_feature_clustering"
-        The method name
-    method_kwargs: dict, default: dict()
-        The method option
-    recursive: bool, default: True
-        Recursive or not. If True, splits are done until no more splits are possible.
-    returns_split_count: bool, default: False
-        Optionally return the split count vector. Same size as labels
-
-    Returns
-    -------
-    new_labels: numpy.ndarray
-        The labels of peaks after split.
-    split_count: numpy.ndarray
-        Optionally returned
-    """
-
-    job_kwargs = fix_job_kwargs(job_kwargs)
-    n_jobs = job_kwargs["n_jobs"]
-    mp_context = job_kwargs.get("mp_context", None)
-    progress_bar = job_kwargs["progress_bar"]
-    max_threads_per_process = job_kwargs.get("max_threads_per_process", 1)
-    original_labels = peak_labels
-    peak_labels = peak_labels.copy()
-    split_count = np.zeros(peak_labels.size, dtype=int)
-    Executor = get_poolexecutor(n_jobs)
-
-    with Executor(
-        max_workers=n_jobs,
-        initializer=split_worker_init,
-        mp_context=get_context(method=mp_context),
-        initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process),
-    ) as pool:
-
-        has_been_splitted = np.ones(len(peak_labels), dtype=bool)
-        recursion_level = 1
-
-        while np.any(has_been_splitted):
-
-            labels_set = np.setdiff1d(peak_labels[has_been_splitted], [-1])
-            if labels_set.size == 0:
-                break
-
-            current_max_label = np.max(labels_set) + 1
-            jobs = []
-
-            for label in labels_set:
-                peak_indices = np.flatnonzero((peak_labels == label) * has_been_splitted)
-                if peak_indices.size > 0:
-                    jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
-
-            if progress_bar and recursion_level == 1:
-                iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set))
-            else:
-                iterator = jobs
-
-            for res in iterator:
-                is_split, local_labels, peak_indices = res.result()
-
-                if not is_split:
-                    has_been_splitted[peak_indices] = False
-                    continue
-
-                mask = local_labels >= 0
-                peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
-                peak_labels[peak_indices[~mask]] = local_labels[~mask]
-                split_count[peak_indices] += 1
-                current_max_label += np.max(local_labels[mask]) + 1
-
-            if not recursive:
-                break
-
-            recursion_level += 1
+                            pbar.total += 1
+
+    if progress_bar:
+        pbar.close()
+        del pbar

     if returns_split_count:
         return peak_labels, split_count
```
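
Taken together, the surviving `split_clusters` now submits jobs to a pool, grows `pbar.total` whenever a split is relaunched, and closes the bar at the end (in the committed version the bar is only advanced for jobs that actually split). A self-contained sketch of that pattern with a toy work function, advancing the bar once per finished job for simplicity (names here are illustrative, not the spikeinterface API):

```python
import random
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def maybe_split(n):
    # Toy stand-in for a split job: a group sometimes yields two children.
    random.seed(n)
    return [2 * n + 1, 2 * n + 2] if n < 7 and random.random() < 0.8 else []

with ThreadPoolExecutor(max_workers=4) as pool:
    jobs = [pool.submit(maybe_split, n) for n in range(4)]
    pbar = tqdm(desc="split_clusters (sketch)", total=len(jobs))

    # Iterating a list that grows during the loop mirrors the recursive
    # resubmission in split_clusters: futures appended below are picked
    # up by this same for-loop.
    for fut in jobs:
        children = fut.result()
        pbar.update(1)
        for child in children:
            jobs.append(pool.submit(maybe_split, child))
            pbar.total += 1  # grow the bar as new work is discovered
            pbar.refresh()

    pbar.close()
```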
