Skip to content

Commit

Permalink
New iterative and recursive splits to avoid leaking objects
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Jan 14, 2025
1 parent 0e3dc72 commit 9a19832
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CircusClustering:
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"recursive_kwargs": {
"recursive": True,
"recursive_depth": 3,
#"recursive_depth": 3,
"returns_split_count": True,
},
"radius_um": 100,
Expand Down Expand Up @@ -187,11 +187,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
np.save(features_folder / "peaks.npy", peaks)

original_labels = peaks["channel_index"]
from spikeinterface.sortingcomponents.clustering.split import split_clusters
from spikeinterface.sortingcomponents.clustering.split import split_clusters_alternative

min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)

peak_labels, _ = split_clusters(
peak_labels, _ = split_clusters_alternative(
original_labels,
recording,
features_folder,
Expand All @@ -203,7 +203,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
waveforms_sparse_mask=sparse_mask,
min_size_split=min_size,
clusterer_kwargs=d["hdbscan_kwargs"],
n_pca_features=[2, 4, 6, 8, 10],
n_pca_features=[2, 4, 8, 16],
),
**params["recursive_kwargs"],
**job_kwargs,
Expand Down
119 changes: 110 additions & 9 deletions src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def split_clusters(
Recording object
features_dict_or_folder: dict or folder
A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features
method: str, default: "hdbscan_on_local_pca"
method: str, default: "local_feature_clustering"
The method name
method_kwargs: dict, default: dict()
The method option
Expand Down Expand Up @@ -82,6 +82,7 @@ def split_clusters(
labels_set = np.setdiff1d(peak_labels, [-1])
current_max_label = np.max(labels_set) + 1
jobs = []

for label in labels_set:
peak_indices = np.flatnonzero(peak_labels == label)
if peak_indices.size > 0:
Expand Down Expand Up @@ -129,6 +130,108 @@ def split_clusters(
return peak_labels


def split_clusters_alternative(
    peak_labels,
    recording,
    features_dict_or_folder,
    method="local_feature_clustering",
    method_kwargs=None,
    recursive=True,
    returns_split_count=False,
    **job_kwargs,
):
    """
    Run recursively (or not) in a multi process pool a local split method.

    One worker pool is created and reused for every pass. At each pass, every
    cluster that was produced by a split during the previous pass (all clusters
    for the first pass) is submitted to the pool. Clusters that the split method
    reports as not split are frozen and never submitted again.

    Parameters
    ----------
    peak_labels: numpy.array
        Peak label before split
    recording: Recording
        Recording object
    features_dict_or_folder: dict or folder
        A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features
    method: str, default: "local_feature_clustering"
        The method name
    method_kwargs: dict or None, default: None
        The method option. None is treated as an empty dict.
    recursive: bool, default: True
        Recursive or not. If True, splits are done until no more splits are possible.
    returns_split_count: bool, default: False
        Optionally return the split count vector. Same size as labels
    **job_kwargs
        Job keyword arguments, normalized with fix_job_kwargs. The keys read here are
        "n_jobs", "progress_bar", "mp_context" and "max_threads_per_process".

    Returns
    -------
    new_labels: numpy.ndarray
        The labels of peaks after split. The input array is not modified.
    split_count: numpy.ndarray
        Optionally returned. Incremented by one for every peak of a cluster each
        time that cluster is split.
    """
    # Avoid sharing one mutable default dict between calls.
    if method_kwargs is None:
        method_kwargs = {}

    job_kwargs = fix_job_kwargs(job_kwargs)
    n_jobs = job_kwargs["n_jobs"]
    mp_context = job_kwargs.get("mp_context", None)
    progress_bar = job_kwargs["progress_bar"]
    max_threads_per_process = job_kwargs.get("max_threads_per_process", 1)

    # Workers are initialized with the labels as given by the caller; the copy
    # below is the one that is updated pass after pass.
    original_labels = peak_labels
    peak_labels = peak_labels.copy()
    split_count = np.zeros(peak_labels.size, dtype=int)

    Executor = get_poolexecutor(n_jobs)

    with Executor(
        max_workers=n_jobs,
        initializer=split_worker_init,
        mp_context=get_context(method=mp_context),
        initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process),
    ) as pool:

        # True for peaks whose cluster may still be split; every peak starts as a candidate.
        has_been_splitted = np.ones(len(peak_labels), dtype=bool)
        recursion_level = 1

        while np.any(has_been_splitted):

            # Only clusters that are still candidates are visited; -1 is never split.
            labels_set = np.setdiff1d(peak_labels[has_been_splitted], [-1])
            if labels_set.size == 0:
                break

            # New labels are always allocated above the labels existing at the start
            # of the pass, so the candidate labels of a later pass (all created by the
            # previous pass) already contain the largest label in use.
            current_max_label = np.max(labels_set) + 1

            jobs = []
            for label in labels_set:
                peak_indices = np.flatnonzero((peak_labels == label) & has_been_splitted)
                if peak_indices.size > 0:
                    jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))

            # The progress bar is only shown for the first pass.
            if progress_bar and recursion_level == 1:
                iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(jobs))
            else:
                iterator = jobs

            for res in iterator:
                is_split, local_labels, peak_indices = res.result()

                if not is_split:
                    # Freeze this cluster: it is not submitted again.
                    has_been_splitted[peak_indices] = False
                    continue

                mask = local_labels >= 0
                # Non-negative local labels are shifted into a fresh global range,
                # negative ones are copied through unchanged.
                peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
                peak_labels[peak_indices[~mask]] = local_labels[~mask]
                split_count[peak_indices] += 1

                # np.max raises on an empty selection: when the split produced only
                # negative labels no new label was consumed, so nothing to advance.
                if mask.any():
                    current_max_label += np.max(local_labels[mask]) + 1

            if not recursive:
                break

            recursion_level += 1

    if returns_split_count:
        return peak_labels, split_count
    else:
        return peak_labels



# NOTE(review): a `global` statement at module scope has no effect in Python;
# presumably `_ctx` is per-worker state assigned elsewhere in this module — confirm.
global _ctx


Expand Down Expand Up @@ -186,6 +289,7 @@ def split(
min_size_split=25,
n_pca_features=2,
minimum_overlap_ratio=0.25,
debug=False
):
local_labels = np.zeros(peak_indices.size, dtype=np.int64)

Expand Down Expand Up @@ -259,9 +363,7 @@ def split(
else:
raise ValueError(f"wrong clusterer {clusterer}. Possible options are 'hdbscan' or 'isocut5'.")

# DEBUG = True
DEBUG = False
if DEBUG:
if debug:
import matplotlib.pyplot as plt

labels_set = np.setdiff1d(possible_labels, [-1])
Expand All @@ -283,11 +385,10 @@ def split(
ax.set_xlabel("PCA features")

axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {n_pca}, recursion_level={recursion_level}")
import time

plt.savefig(f"split_{recursion_level}_{time.time()}.png")
plt.close()
# plt.show()
#import time
#plt.savefig(f"split_{recursion_level}_{time.time()}.png")
#plt.close()
#plt.show()

if is_split:
break
Expand Down

0 comments on commit 9a19832

Please sign in to comment.