From d27cd31924de014fb5e71d1a00a6cfd99b928271 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 09:10:16 +0100 Subject: [PATCH 01/64] Cleaning API --- .../sortingcomponents/clustering/circus.py | 5 +---- .../sortingcomponents/clustering/dummy.py | 2 +- .../sortingcomponents/clustering/main.py | 2 +- .../sortingcomponents/clustering/position.py | 5 ++--- .../clustering/position_and_features.py | 19 +++++-------------- .../clustering/position_and_pca.py | 9 ++++----- .../clustering/position_ptp_scaled.py | 5 ++--- .../clustering/random_projections.py | 6 +----- .../clustering/sliding_hdbscan.py | 7 +++---- .../clustering/sliding_nn.py | 13 ++++++------- .../sortingcomponents/clustering/tdc.py | 5 +---- .../sortingcomponents/clustering/tools.py | 4 ---- 12 files changed, 27 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 99c59f493e..78227a65f3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -63,16 +63,13 @@ class CircusClustering: "rank": 5, "noise_levels": None, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] diff --git a/src/spikeinterface/sortingcomponents/clustering/dummy.py b/src/spikeinterface/sortingcomponents/clustering/dummy.py index c1032ee6c6..b5761ad5cf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/dummy.py +++ b/src/spikeinterface/sortingcomponents/clustering/dummy.py @@ -13,7 +13,7 @@ class DummyClustering: _default_params = {} @classmethod - def main_function(cls, recording, peaks, params): + def 
main_function(cls, recording, peaks, params, job_kwargs=dict()): labels = np.arange(recording.get_num_channels(), dtype="int64") peak_labels = peaks["channel_index"] return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 99881f2f34..ba0fe6f9ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -41,7 +41,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, params = method_class._default_params.copy() params.update(**method_kwargs) - outputs = method_class.main_function(recording, peaks, params) + outputs = method_class.main_function(recording, peaks, params, job_kwargs=job_kwargs) if extra_outputs: return outputs diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py index ae772206bb..dc76d787f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position.py +++ b/src/spikeinterface/sortingcomponents/clustering/position.py @@ -25,18 +25,17 @@ class PositionClustering: "hdbscan_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "debug": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params if d["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) + peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **job_kwargs) else: peak_locations = d["peak_locations"] diff --git 
a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index d23eb26239..513e8085ed 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -42,23 +42,14 @@ class PositionAndFeaturesClustering: "ms_before": 1.5, "ms_after": 1.5, "cleaning_method": "dip", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): from sklearn.preprocessing import QuantileTransformer assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" - if "n_jobs" in params["job_kwargs"]: - if params["job_kwargs"]["n_jobs"] == -1: - params["job_kwargs"]["n_jobs"] = os.cpu_count() - - if "core_dist_n_jobs" in params["hdbscan_kwargs"]: - if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: - params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - d = params peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] @@ -80,7 +71,7 @@ def main_function(cls, recording, peaks, params): } features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **job_kwargs ) hdbscan_data = np.zeros((len(peaks), 3), dtype=np.float32) @@ -150,7 +141,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=None, copy=True, - **params["job_kwargs"], + **job_kwargs, ) noise_levels = get_noise_levels(recording, return_scaled=False) @@ -181,7 +172,7 @@ def main_function(cls, recording, peaks, params): nbefore, nafter, return_scaled=False, - **params["job_kwargs"], + **job_kwargs, ) templates = Templates( 
templates_array=templates_array, @@ -193,7 +184,7 @@ def main_function(cls, recording, peaks, params): ) labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"] + templates, peak_labels, job_kwargs=job_kwargs, **params["cleaning_kwargs"] ) shutil.rmtree(tmp_folder) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index 4dfe3c960c..3b730752c1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -38,7 +38,6 @@ class PositionAndPCAClustering: "ms_after": 2.5, "n_components_by_channel": 3, "n_components": 5, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "waveform_mode": "shared_memory", @@ -73,7 +72,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): # res = PositionClustering(recording, peaks, params) assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed" @@ -86,7 +85,7 @@ def main_function(cls, recording, peaks, params): from spikeinterface.sortingcomponents.peak_localization import localize_peaks peak_locations = localize_peaks( - recording, peaks, **params["peak_localization_kwargs"], **params["job_kwargs"] + recording, peaks, **params["peak_localization_kwargs"], **job_kwargs ) else: peak_locations = params["peak_locations"] @@ -155,7 +154,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=sparsity_mask, 
copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + **job_kwargs, ) noise = get_random_data_chunks( @@ -222,7 +221,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=sparsity_mask3, copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + **job_kwargs, ) clean_peak_labels, peak_sample_shifts = auto_clean_clustering( diff --git a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py index 788addf1e6..0f7390d7ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py @@ -26,7 +26,6 @@ class PositionPTPScaledClustering: "ptps": None, "scales": (1, 1, 10), "peak_localization_kwargs": {"method": "center_of_mass"}, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_kwargs": { "min_cluster_size": 20, "min_samples": 20, @@ -38,7 +37,7 @@ class PositionPTPScaledClustering: } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params @@ -60,7 +59,7 @@ def main_function(cls, recording, peaks, params): if d["ptps"] is None: (ptps,) = compute_features_from_peaks( - recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **d["job_kwargs"] + recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **job_kwargs ) else: ptps = d["ptps"] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f7ca999d53..36033c61e1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ 
b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -17,7 +17,6 @@ from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates @@ -55,16 +54,13 @@ class RandomProjectionClustering: "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 8b9acbc92d..ee56894b13 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -55,11 +55,10 @@ class SlidingHdbscanClustering: "auto_merge_quantile_limit": 0.8, "ratio_num_channel_intersect": 0.5, # ~ 'auto_trash_misalignment_shift' : 4, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "sliding_hdbscan clustering need hdbscan to be installed" params = cls._check_params(recording, peaks, params) wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, 
params) @@ -145,7 +144,7 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, sparsity_mask=sparsity_mask, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) # noise @@ -465,7 +464,7 @@ def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels dtype=recording.get_dtype(), sparsity_mask=sparsity_mask2, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) return wfs_arrays2, sparsity_mask2 diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index a6ffa5fdc2..40cedacdc5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -71,11 +71,10 @@ class SlidingNNClustering: "tmp_folder": None, "verbose": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1}, } @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_NUMBA, "SlidingNN needs numba to work" assert HAVE_TORCH, "SlidingNN needs torch to work" assert HAVE_NNDESCENT, "SlidingNN needs pynndescent to work" @@ -126,16 +125,16 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, sparsity_mask=sparsity_mask, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) return wfs_arrays, sparsity_mask @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): d = params - # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) + # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params, job_kwargs) # prepare neighborhood parameters fs = recording.get_sampling_frequency() @@ -228,7 +227,7 @@ def main_function(cls, recording, peaks, params): 
n_channel_neighbors=d["n_channel_neighbors"], low_memory=d["low_memory"], knn_verbose=d["verbose"], - n_jobs=d["job_kwargs"]["n_jobs"], + n_jobs=job_kwargs["n_jobs"], ) # remove the first nearest neighbor (which should be self) knn_distances = knn_distances[:, 1:] @@ -297,7 +296,7 @@ def main_function(cls, recording, peaks, params): # TODO HDBSCAN can be done on GPU with NVIDIA RAPIDS for speed clusterer = hdbscan.HDBSCAN( prediction_data=True, - core_dist_n_jobs=d["job_kwargs"]["n_jobs"], + core_dist_n_jobs=job_kwargs["n_jobs"], **d["hdbscan_kwargs"], ).fit(embeddings_chunk) diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py index 13af5b0fab..c6b94eaa48 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -50,15 +50,12 @@ class TdcClustering: "merge_radius_um": 40.0, "threshold_diff": 1.5, }, - "job_kwargs": {}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): import hdbscan - job_kwargs = params["job_kwargs"] - if params["folder"] is None: randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) clustering_folder = get_global_tmp_folder() / f"tdcclustering_{randname}" diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index e2a0d273d6..64cc0f39c4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -172,8 +172,6 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): """ - print("apply_waveforms_shift") - if inplace: aligned_waveforms = waveforms else: @@ -193,6 +191,4 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): else: aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] - print("apply_waveforms_shift 
DONE") - return aligned_waveforms From 32f2a38360e501381e71855496f7c405eb7098be Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 09:14:49 +0100 Subject: [PATCH 02/64] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 78227a65f3..3eae272fbe 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -257,13 +257,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=job_kwargs, **cleaning_params ) if verbose: From d8d5b7052cb20c7a5cc085031817e158bbac1550 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 09:16:25 +0100 Subject: [PATCH 03/64] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 3eae272fbe..5982c270cb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -257,10 +257,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("We found %d raw clusters, starting to clean with matching..." 
% (len(templates.unit_ids))) + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=job_kwargs, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: From bd82f45382eb89b9acb4aaa2f18bc0f61286dcec Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 09:24:15 +0100 Subject: [PATCH 04/64] Cleaning clustering --- .../sortingcomponents/clustering/clean.py | 2 -- .../clustering/random_projections.py | 12 +++--------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py index c7d57b14e4..e8bc5a1d49 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clean.py +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -32,7 +32,6 @@ def clean_clusters( count = np.zeros(n, dtype="int64") for i, label in enumerate(labels_set): count[i] = np.sum(peak_labels == label) - print(count) templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) @@ -42,6 +41,5 @@ def clean_clusters( max_values = -np.min(templates, axis=(1, 2)) elif peak_sign == "pos": max_values = np.max(templates, axis=(1, 2)) - print(max_values) return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 36033c61e1..40bb4ac987 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -152,18 +152,12 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("We found %d raw clusters, starting to clean with matching..." 
% (len(templates.unit_ids))) - cleaning_matching_params = job_kwargs.copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in cleaning_matching_params: - cleaning_matching_params[value] = None - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False - + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: From dbe5fa3f914838fe6f6bd10e23a6342396d44448 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 09:27:44 +0100 Subject: [PATCH 05/64] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index eed693b343..5cce8b54f5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -219,7 +219,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["radius_um"] = radius_um clustering_params["waveforms"]["ms_before"] = ms_before clustering_params["waveforms"]["ms_after"] = ms_after - clustering_params["job_kwargs"] = job_kwargs clustering_params["noise_levels"] = noise_levels clustering_params["ms_before"] = exclude_sweep_ms clustering_params["ms_after"] = exclude_sweep_ms @@ -233,7 +232,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_method = "random_projections" labels, peak_labels = find_cluster_from_peaks( - recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params 
+ recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) ## We get the labels for our peaks From 3a78e7ad216ff8e9b927204555c94e4d42aaad17 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 08:30:34 +0000 Subject: [PATCH 06/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/position_and_pca.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index 3b730752c1..c4f372fc21 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -84,9 +84,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if params["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks( - recording, peaks, **params["peak_localization_kwargs"], **job_kwargs - ) + peak_locations = localize_peaks(recording, peaks, **params["peak_localization_kwargs"], **job_kwargs) else: peak_locations = params["peak_locations"] From 035d8d2a4f24f27a5cb9f314e05acb2b3448fb98 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 15:36:44 +0100 Subject: [PATCH 07/64] Fix --- .../sortingcomponents/clustering/sliding_hdbscan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index ee56894b13..56f7e35096 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -65,7 +65,7 @@ def 
main_function(cls, recording, peaks, params, job_kwargs=dict()): peak_labels = cls._find_clusters(recording, peaks, wfs_arrays, sparsity_mask, noise, params) wfs_arrays2, sparsity_mask2 = cls._prepare_clean( - recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params + recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params, job_kwargs ) clean_peak_labels, peak_sample_shifts = cls._clean_cluster( @@ -400,7 +400,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): return peak_labels @classmethod - def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d): + def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d, job_kwargs): tmp_folder = d["tmp_folder"] if tmp_folder is None: wf_folder = None From 389c86cb07322107f927225868e61c6bf20c263b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 13 Nov 2024 16:01:34 +0100 Subject: [PATCH 08/64] Fixes --- .../sortingcomponents/clustering/sliding_hdbscan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 56f7e35096..2ae810ae20 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -99,7 +99,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): d = params tmp_folder = params["tmp_folder"] @@ -400,7 +400,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): return peak_labels @classmethod - def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d, job_kwargs): + def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d, 
job_kwargs=dict()): tmp_folder = d["tmp_folder"] if tmp_folder is None: wf_folder = None From 530864bceda4b436f5cf16fd5258efd19ccaa76d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:33:03 +0000 Subject: [PATCH 09/64] Replace qm_params with metric_params --- .../qualitymetrics/pca_metrics.py | 36 ++++++++++++------- .../quality_metric_calculator.py | 30 ++++++++++------ .../tests/test_metrics_functions.py | 10 +++--- 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 4c68dfea59..b4952bfe6d 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -6,6 +6,7 @@ from copy import deepcopy import platform from tqdm.auto import tqdm +from warnings import warn import numpy as np @@ -52,6 +53,7 @@ def get_quality_pca_metric_list(): def compute_pc_metrics( sorting_analyzer, metric_names=None, + metric_params=None, qm_params=None, unit_ids=None, seed=None, @@ -70,7 +72,7 @@ def compute_pc_metrics( metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. - qm_params : dict or None + metric_params : dict or None Dictionary with parameters for each PC metric function. unit_ids : list of int or None List of unit ids to compute metrics for. @@ -86,6 +88,14 @@ def compute_pc_metrics( pc_metrics : dict The computed PC metrics. 
""" + + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" @@ -93,8 +103,8 @@ def compute_pc_metrics( if metric_names is None: metric_names = _possible_pc_metric_names.copy() - if qm_params is None: - qm_params = _default_params + if metric_params is None: + metric_params = _default_params extremum_channels = get_template_extremum_channel(sorting_analyzer) @@ -147,7 +157,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_process) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -184,7 +194,7 @@ def compute_pc_metrics( units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) func = _nn_metric_name_to_func[metric_name] - metric_params = qm_params[metric_name] if metric_name in qm_params else {} + metric_params = metric_params[metric_name] if metric_name in metric_params else {} for _, unit_id in units_loop: try: @@ -213,7 +223,7 @@ def compute_pc_metrics( def calculate_pc_metrics( - sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): warnings.warn( "The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. 
Please use compute_pc_metrics instead", @@ -224,7 +234,7 @@ def calculate_pc_metrics( pc_metrics = compute_pc_metrics( sorting_analyzer, metric_names=metric_names, - qm_params=qm_params, + metric_params=metric_params, unit_ids=unit_ids, seed=seed, n_jobs=n_jobs, @@ -977,16 +987,16 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_process) = args if max_threads_per_process is None: - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) else: with threadpool_limits(limits=int(max_threads_per_process)): - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) -def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params): +def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: @@ -1015,7 +1025,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ if "nearest_neighbor" in metric_names: try: nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"] ) except: nn_hit_rate = np.nan @@ -1024,7 +1034,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: - silhouette_method = qm_params["silhouette"]["method"] + silhouette_method = 
metric_params["silhouette"]["method"] if "simplified" in silhouette_method: try: unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b6a50d60f5..eb380304b6 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,6 +6,7 @@ from copy import deepcopy import numpy as np +from warnings import warn from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -31,7 +32,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. - qm_params : dict or None + metric_params : dict or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -58,6 +59,7 @@ class ComputeQualityMetrics(AnalyzerExtension): def _set_params( self, metric_names=None, + metric_params=None, qm_params=None, peak_sign=None, seed=None, @@ -65,6 +67,12 @@ def _set_params( delete_existing_metrics=False, metrics_to_compute=None, ): + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -80,12 +88,12 @@ def _set_params( if "drift" in metric_names: metric_names.remove("drift") - qm_params_ = get_default_qm_params() - for k in qm_params_: - if qm_params is not None and k in qm_params: - qm_params_[k].update(qm_params[k]) - if "peak_sign" in 
qm_params_[k] and peak_sign is not None: - qm_params_[k]["peak_sign"] = peak_sign + metric_params_ = get_default_qm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") @@ -101,7 +109,7 @@ def _set_params( metric_names=metric_names, peak_sign=peak_sign, seed=seed, - qm_params=qm_params_, + metric_params=metric_params_, skip_pc_metrics=skip_pc_metrics, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, @@ -141,7 +149,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - qm_params = self.params["qm_params"] + metric_params = self.params["metric_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -177,7 +185,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri func = _misc_metric_name_to_func[metric_name] - params = qm_params[metric_name] if metric_name in qm_params else {} + params = metric_params[metric_name] if metric_name in metric_params else {} res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: @@ -205,7 +213,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # sparsity=sparsity, progress_bar=progress_bar, n_jobs=n_jobs, - qm_params=qm_params, + metric_params=metric_params, seed=seed, ) for col, values in pc_metrics.items(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 4c0890b62b..20869aa44a 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ 
b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -69,7 +69,7 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): assert calculated_metrics == ["snr"] small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} ) small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) @@ -96,13 +96,13 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): # check that, when parameters are changed, the data and metadata are updated old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") new_snr_data = new_quality_metric_extension.get_data()["snr"].values assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" # check that all quality metrics are deleted when parents are recomputed, even after # recomputation @@ -280,10 +280,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): } quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params ) quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params ) for metric, 
metric_2_data in quality_metrics_2.items(): From 908318a04731f6e8723a9cf3b34914d8e782e900 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:38:28 +0000 Subject: [PATCH 10/64] fix tests --- .../tests/test_quality_metric_calculator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index a6415c58e8..60f0490f51 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -24,14 +24,14 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=["snr"], - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) # print(metrics) qm = sorting_analyzer.get_extension("quality_metrics") - assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 assert "snr" in metrics.columns assert "isolation_distance" not in metrics.columns @@ -40,7 +40,7 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -54,7 +54,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -68,7 +68,7 @@ def 
test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics_norec = compute_quality_metrics( sorting_analyzer_norec, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -101,7 +101,7 @@ def test_empty_units(sorting_analyzer_simple): metrics_empty = compute_quality_metrics( sorting_analyzer_empty, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) From 2706a0bee9fef9dd3b3a4af99715366aae1c1625 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:45:28 +0000 Subject: [PATCH 11/64] Change metrics_kwargs to metric_params and add depreciation message --- .../postprocessing/template_metrics.py | 36 +++++++++++-------- .../tests/test_template_metrics.py | 2 +- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 6e7bcf21b8..ef6abfe51f 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -63,8 +63,8 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. - metrics_kwargs : dict + If True, any template metrics attached to the `sorting_analyzer` are deleted. 
If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 @@ -109,12 +109,20 @@ def _set_params( peak_sign="neg", upsampling_factor=10, sparsity=None, + metric_params=None, metrics_kwargs=None, include_multi_channel_metrics=False, delete_existing_metrics=False, **other_kwargs, ): + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = metrics_kwargs + warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -134,27 +142,27 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - if metrics_kwargs is None: - metrics_kwargs_ = _default_function_kwargs.copy() + if metric_params is None: + metric_params_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: for m in other_kwargs: - if m in metrics_kwargs_: - metrics_kwargs_[m] = other_kwargs[m] + if m in metric_params_: + metric_params_[m] = other_kwargs[m] else: - metrics_kwargs_ = _default_function_kwargs.copy() - metrics_kwargs_.update(metrics_kwargs) + metric_params_ = _default_function_kwargs.copy() + metric_params_.update(metric_params) metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metrics_kwargs"] + 
existing_params = tm_extension.params["metric_params"] # checks that existing metrics were calculated using the same params - if existing_params != metrics_kwargs_: + if existing_params != metric_params_: warnings.warn( f"The parameters used to calculate the previous template metrics are different" f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." + f"parameters: {metric_params_}\nDeleting previous template metrics..." ) tm_extension.params["metric_names"] = [] existing_metric_names = [] @@ -171,7 +179,7 @@ def _set_params( sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - metrics_kwargs=metrics_kwargs_, + metric_params=metric_params_, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, ) @@ -273,7 +281,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metrics_kwargs"], + **self.params["metric_params"], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -312,7 +320,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], + **self.params["metric_params"], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5056d4ff2a..1df723bfe3 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -47,7 +47,7 @@ def 
test_compute_new_template_metrics(small_sorting_analyzer): # check that, when parameters are changed, the old metrics are deleted small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}} ) From 22460710bca1114e5f11f5ba4cbdad1b82941d70 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:46:16 +0000 Subject: [PATCH 12/64] Update warning message (oups) --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ef6abfe51f..9b85f99c0d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -118,7 +118,7 @@ def _set_params( if metrics_kwargs is not None and metric_params is None: deprecation_msg = ( - "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" ) metric_params = metrics_kwargs warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) From 3579934baf43e25759ca1e8ee7f1e3288180be71 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:04:02 +0000 Subject: [PATCH 13/64] Make compute work and add `get_default_tm_params` --- .../postprocessing/template_metrics.py | 45 +++++++++---------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 9b85f99c0d..25e0d0d490 100644 --- 
a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -64,22 +64,10 @@ class ComputeTemplateMetrics(AnalyzerExtension): Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. - metric_params : dict - Additional arguments to pass to the metric functions. Including: - * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 - * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 - * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) - * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 - * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" - * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 - * spread_threshold: the threshold to compute the spread, default: 0.2 - * spread_smooth_um: the smoothing in um to compute the spread, default: 20 - * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + metric_params : dict of dicts + metric_params : dict of dicts or None + Dictionary with parameters for quality metrics calculation. 
+ Default parameters can be obtained with: `si.qualitymetrics.get_default_tm_params()` Returns ------- @@ -116,13 +104,6 @@ def _set_params( **other_kwargs, ): - if metrics_kwargs is not None and metric_params is None: - deprecation_msg = ( - "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" - ) - metric_params = metrics_kwargs - warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -142,6 +123,13 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = ( + "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = dict(zip(metric_names, [metrics_kwargs] * len(metric_names))) + warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + if metric_params is None: metric_params_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -281,7 +269,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metric_params"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -320,7 +308,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metric_params"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -380,6 +368,13 @@ def _get_data(self): ) 
+def get_default_tm_params(): + metric_names = get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() + base_tm_params = _default_function_kwargs + metric_params = dict(zip(metric_names, [base_tm_params] * len(metric_names))) + return metric_params + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough From 66190c3857bcabf30ca64af994693a3a029c41e1 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:18:46 +0000 Subject: [PATCH 14/64] Update compute_name_to_column_names to qm_compute_name_to_column_names --- .../qualitymetrics/quality_metric_calculator.py | 6 +++--- src/spikeinterface/qualitymetrics/quality_metric_list.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index eb380304b6..365d7bcc09 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -16,7 +16,7 @@ compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names, - compute_name_to_column_names, + qm_compute_name_to_column_names, ) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -32,7 +32,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. - metric_params : dict or None + metric_params : dict of dicts or None Dictionary with parameters for quality metrics calculation. 
Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -254,7 +254,7 @@ def _run(self, verbose=False, **job_kwargs): # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): # some metrics names produce data columns with other names. This deals with that. - for column_name in compute_name_to_column_names[metric_name]: + for column_name in qm_compute_name_to_column_names[metric_name]: computed_metrics[column_name] = qm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 375dd320ae..fc7e92b50d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -55,7 +55,7 @@ } # a dict converting the name of the metric for computation to the output of that computation -compute_name_to_column_names = { +qm_compute_name_to_column_names = { "num_spikes": ["num_spikes"], "firing_rate": ["firing_rate"], "presence_ratio": ["presence_ratio"], From 2de25b47013c8954ece03af1f47250e5db1f7ffb Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:19:27 +0000 Subject: [PATCH 15/64] Unify template param checks with quality param checks --- .../postprocessing/template_metrics.py | 63 +++++++++++-------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 25e0d0d490..cfdbd122b3 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -130,33 +130,18 @@ def _set_params( metric_params = dict(zip(metric_names, [metrics_kwargs] * len(metric_names))) 
warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - if metric_params is None: - metric_params_ = _default_function_kwargs.copy() - if len(other_kwargs) > 0: - for m in other_kwargs: - if m in metric_params_: - metric_params_[m] = other_kwargs[m] - else: - metric_params_ = _default_function_kwargs.copy() - metric_params_.update(metric_params) + metric_params_ = get_default_tm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metric_params"] - # checks that existing metrics were calculated using the same params - if existing_params != metric_params_: - warnings.warn( - f"The parameters used to calculate the previous template metrics are different" - f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metric_params_}\nDeleting previous template metrics..." 
- ) - tm_extension.params["metric_names"] = [] - existing_metric_names = [] - else: - existing_metric_names = tm_extension.params["metric_names"] - + existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] @@ -322,8 +307,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): - delete_existing_metrics = self.params["delete_existing_metrics"] metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] # compute the metrics which have been specified by the user computed_metrics = self._compute_metrics( @@ -339,9 +324,21 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] + existing_metrics = [] + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): + existing_metrics = tm_extension.params["metric_names"] + # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): - computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + # some metrics names produce data columns with other names. This deals with that. 
+ for column_name in tm_compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = tm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics @@ -369,12 +366,28 @@ def _get_data(self): def get_default_tm_params(): - metric_names = get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() + metric_names = get_template_metric_names() base_tm_params = _default_function_kwargs metric_params = dict(zip(metric_names, [base_tm_params] * len(metric_names))) return metric_params +# a dict converting the name of the metric for computation to the output of that computation +tm_compute_name_to_column_names = { + "peak_to_valley": ["peak_to_valley"], + "peak_trough_ratio": ["peak_trough_ratio"], + "half_width": ["half_width"], + "repolarization_slope": ["repolarization_slope"], + "recovery_slope": ["recovery_slope"], + "num_positive_peaks": ["num_positive_peaks"], + "num_negative_peaks": ["num_negative_peaks"], + "velocity_above": ["velocity_above"], + "velocity_below": ["velocity_below"], + "exp_decay": ["exp_decay"], + "spread": ["spread"], +} + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough From 8f6602423d53dc625689da2183a24fe43cbb8629 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:07:27 +0000 Subject: [PATCH 16/64] add some tests --- .../tests/test_template_metrics.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 1df723bfe3..1bf49f64c1 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,5 +1,5 @@ from spikeinterface.postprocessing.tests.common_extension_tests import 
AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeTemplateMetrics +from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics import pytest import csv @@ -8,6 +8,49 @@ template_metrics = list(_single_channel_metric_name_to_func.keys()) +def test_different_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using different params, and check that they are + actually calculated using the different params. + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread", "half_width"], + metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.7 + assert tm_params["half_width"]["recovery_window_ms"] == 0.7 + + assert tm_params["spread"]["spread_smooth_um"] == 15 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + assert tm_params["half_width"]["spread_smooth_um"] == 20 + + +def test_backwards_compat_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using the metrics_kwargs keyword + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread"], + metrics_kwargs={"recovery_window_ms": 0.8}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.8 + + assert tm_params["spread"]["spread_smooth_um"] == 20 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + + def test_compute_new_template_metrics(small_sorting_analyzer): """ Computes template metrics then 
computes a subset of template metrics, and checks @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. """ + small_sorting_analyzer.delete_extension("template_metrics") + # calculate just exp_decay small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") From fdc01f5adb81f55c5787166fa469100d7bc06239 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:00:15 +0000 Subject: [PATCH 17/64] little fixes --- .../postprocessing/template_metrics.py | 31 +++++++++++-------- .../qualitymetrics/pca_metrics.py | 4 +-- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index cfdbd122b3..cbcf38d19d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -66,8 +66,8 @@ class ComputeTemplateMetrics(AnalyzerExtension): If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. metric_params : dict of dicts metric_params : dict of dicts or None - Dictionary with parameters for quality metrics calculation. - Default parameters can be obtained with: `si.qualitymetrics.get_default_tm_params()` + Dictionary with parameters for template metrics calculation. 
+ Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` Returns ------- @@ -124,18 +124,17 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() if metrics_kwargs is not None and metric_params is None: - deprecation_msg = ( - "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" - ) - metric_params = dict(zip(metric_names, [metrics_kwargs] * len(metric_names))) - warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" + warnings.warn(deprecation_msg, category=DeprecationWarning) + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(metrics_kwargs) - metric_params_ = get_default_tm_params() + metric_params_ = get_default_tm_params(metric_names) for k in metric_params_: if metric_params is not None and k in metric_params: metric_params_[k].update(metric_params[k]) - if "peak_sign" in metric_params_[k] and peak_sign is not None: - metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") @@ -365,10 +364,16 @@ def _get_data(self): ) -def get_default_tm_params(): - metric_names = get_template_metric_names() +def get_default_tm_params(metric_names): + if metric_names is None: + metric_names = get_template_metric_names() + base_tm_params = _default_function_kwargs - metric_params = dict(zip(metric_names, [base_tm_params] * len(metric_names))) + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(base_tm_params) + return metric_params diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index b4952bfe6d..ca21f1e45f 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py 
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -91,10 +91,10 @@ def compute_pc_metrics( if qm_params is not None and metric_params is None: deprecation_msg = ( - "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + "`qm_params` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" ) - metric_params = qm_params warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + metric_params = qm_params pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" From 9db0b83f7c7ad2c773c8673fa3dd09e5c3cdecb6 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:11:01 +0000 Subject: [PATCH 18/64] backwards compatible loading --- .../postprocessing/template_metrics.py | 12 ++++++++++++ .../qualitymetrics/quality_metric_calculator.py | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index cbcf38d19d..477ad04440 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -344,6 +344,18 @@ def _run(self, verbose=False): def _get_data(self): return self.data["metrics"] + def load_params(self): + AnalyzerExtension.load_params(self) + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory() diff --git 
a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 365d7bcc09..e7e7c244ea 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -262,6 +262,13 @@ def _run(self, verbose=False, **job_kwargs): def _get_data(self): return self.data["metrics"] + def load_params(self): + AnalyzerExtension.load_params(self) + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] + register_result_extension(ComputeQualityMetrics) compute_quality_metrics = ComputeQualityMetrics.function_factory() From bdeb30041880e881b10a76cc03642757a020ae87 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 14 Nov 2024 16:38:30 +0100 Subject: [PATCH 19/64] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 +++- .../sortingcomponents/clustering/position_and_features.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 5982c270cb..993bd7fee0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -245,8 +245,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): probe=recording.get_probe(), is_scaled=False, ) + if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) empty_templates = templates.sparsity_mask.sum(axis=1) == 0 diff --git 
a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index 513e8085ed..20067a2eec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -144,7 +144,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): **job_kwargs, ) - noise_levels = get_noise_levels(recording, return_scaled=False) + noise_levels = get_noise_levels(recording, return_scaled=False, **job_kwargs) labels, peak_labels = remove_duplicates( wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] ) From d08df426ac5abfd756aa78d78712570517fdf341 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:40:29 +0000 Subject: [PATCH 20/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 993bd7fee0..32fe69ee38 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -245,10 +245,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): probe=recording.get_probe(), is_scaled=False, ) - + if params["noise_levels"] is None: params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) - + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) empty_templates = templates.sparsity_mask.sum(axis=1) == 0 From 9b2875fd06782bd6da722fed1f524dfda4203f0a Mon Sep 17 00:00:00 2001 
From: Alessio Buccino Date: Thu, 21 Nov 2024 11:53:01 +0100 Subject: [PATCH 21/64] Add stream_mode as extra_requirements for NWB when streaming --- src/spikeinterface/extractors/nwbextractors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d797e64910..171992f6b1 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -599,6 +599,8 @@ def __init__( else: gains, offsets, locations, groups = self._fetch_main_properties_backend() self.extra_requirements.append("h5py") + if stream_mode is not None: + self.extra_requirements.append(stream_mode) self.set_channel_gains(gains) self.set_channel_offsets(offsets) if locations is not None: @@ -1100,6 +1102,8 @@ def __init__( for property_name, property_values in properties.items(): values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] self.set_property(property_name, values) + if stream_mode is not None: + self.extra_requirements.append(stream_mode) if stream_mode is None and file_path is not None: file_path = str(Path(file_path).resolve()) From c9d02d884372e89109842a28f47e8dbf8b5a8f11 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 22 Nov 2024 09:44:33 +0100 Subject: [PATCH 22/64] Tests --- src/spikeinterface/sorters/internal/spyking_circus2.py | 3 ++- src/spikeinterface/sortingcomponents/clustering/main.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 5cce8b54f5..f74570806c 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -115,10 +115,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from
spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction - from spikeinterface.sortingcomponents.tools import get_prototype_spike job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) + print(job_kwargs) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -231,6 +231,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: clustering_method = "random_projections" + print('test') labels, peak_labels = find_cluster_from_peaks( recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index ba0fe6f9ac..fadcb07527 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -32,6 +32,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, peak_labels.shape[0] == peaks.shape[0] """ job_kwargs = fix_job_kwargs(job_kwargs) + print("toto", job_kwargs) assert ( method in clustering_methods From a1e97a54249b076768af8f790506bcc48e3b7fb8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 08:44:58 +0000 Subject: [PATCH 23/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f74570806c..bb6306fe15 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py 
@@ -231,7 +231,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: clustering_method = "random_projections" - print('test') + print("test") labels, peak_labels = find_cluster_from_peaks( recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) From d414c2d6561966ca1979d11ce0a16c9029011fab Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 22 Nov 2024 09:56:24 +0100 Subject: [PATCH 24/64] Remove prints --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 -- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- src/spikeinterface/sortingcomponents/clustering/main.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f74570806c..208d9f5bc6 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -118,7 +118,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) - print(job_kwargs) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -231,7 +230,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: clustering_method = "random_projections" - print('test') labels, peak_labels = find_cluster_from_peaks( recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 32fe69ee38..6a341047f4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -184,7 +184,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): sparse_mask = 
node1.neighbours_mask neighbours_mask = get_channel_distances(recording) <= radius_um - # np.save(features_folder / "sparse_mask.npy", sparse_mask) + # np.save(features_folder / "sparse_mask.npy", sparse_mask) np.save(features_folder / "peaks.npy", peaks) original_labels = peaks["channel_index"] diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index fadcb07527..ba0fe6f9ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -32,7 +32,6 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, peak_labels.shape[0] == peaks.shape[0] """ job_kwargs = fix_job_kwargs(job_kwargs) - print("toto", job_kwargs) assert ( method in clustering_methods From 4c20e0c588ade80025ed61ae67963581ab0e7ada Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 08:58:32 +0000 Subject: [PATCH 25/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 6a341047f4..32fe69ee38 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -184,7 +184,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): sparse_mask = node1.neighbours_mask neighbours_mask = get_channel_distances(recording) <= radius_um - # np.save(features_folder / "sparse_mask.npy", sparse_mask) + # np.save(features_folder / "sparse_mask.npy", sparse_mask) np.save(features_folder / "peaks.npy", peaks) original_labels = peaks["channel_index"] From 
9cc1673f5c1c520636ee7654a067de2b0a68ef96 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 22 Nov 2024 10:14:49 +0100 Subject: [PATCH 26/64] Cleaning imports. Need to test with mac --- .../sortingcomponents/clustering/circus.py | 1 - .../sortingcomponents/clustering/random_projections.py | 4 ++-- .../sortingcomponents/clustering/sliding_hdbscan.py | 2 +- src/spikeinterface/sortingcomponents/clustering/tdc.py | 9 ++------- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 6a341047f4..c1ec3f1aab 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -18,7 +18,6 @@ from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 40bb4ac987..484a7376c1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -16,7 +16,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances +from spikeinterface.core.recording_tools import get_noise_levels from 
spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates @@ -144,7 +144,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): is_scaled=False, ) if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) templates = remove_empty_templates(templates) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 2ae810ae20..5f8ac99848 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -23,7 +23,7 @@ get_random_data_chunks, extract_waveforms_to_buffers, ) -from .clustering_tools import auto_clean_clustering, auto_split_clustering +from .clustering_tools import auto_clean_clustering class SlidingHdbscanClustering: diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py index c6b94eaa48..6c361b0562 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -9,27 +9,22 @@ from spikeinterface.core import ( get_channel_distances, - Templates, - compute_sparsity, get_global_tmp_folder, ) from spikeinterface.core.node_pipeline import ( run_node_pipeline, - ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever, ) -from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing -from spikeinterface.sortingcomponents.peak_detection import detect_peaks, 
DetectPeakLocallyExclusive +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel from spikeinterface.sortingcomponents.peak_selection import select_peaks -from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from spikeinterface.sortingcomponents.clustering.split import split_clusters from spikeinterface.sortingcomponents.clustering.merge import merge_clusters -from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse + class TdcClustering: From 61351e7a60b7b2215cd3eea3374363e354c7fb1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 09:18:47 +0000 Subject: [PATCH 27/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/tdc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py index 6c361b0562..59472d1374 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -26,7 +26,6 @@ from spikeinterface.sortingcomponents.clustering.merge import merge_clusters - class TdcClustering: """ Here the implementation of clustering used by tridesclous2 From de5ee687781eddfd8ebdde3b83e9f91d3ac81163 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 22 Nov 2024 15:07:40 +0100 Subject: [PATCH 28/64] Less cores for mac ? 
--- src/spikeinterface/sorters/internal/spyking_circus2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 208d9f5bc6..6e84cc996f 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -52,7 +52,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, - "job_kwargs": {"n_jobs": 0.8}, + "job_kwargs": {"n_jobs": 0.5}, "debug": False, } @@ -282,11 +282,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"].pop("method") matching_params = params["matching"].copy() matching_params["templates"] = templates - matching_job_params = job_kwargs.copy() if matching_method is not None: spikes = find_spikes_from_templates( - recording_w, matching_method, method_kwargs=matching_params, **matching_job_params + recording_w, matching_method, method_kwargs=matching_params, **job_kwargs ) if params["debug"]: From 2cff62babd7816c14fef9246152b53a2bb59d991 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Nov 2024 14:36:53 +0100 Subject: [PATCH 29/64] plot drift with the scatter plot --- .../benchmark/benchmark_motion_estimation.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py index abb2a51bae..3a7d11fc35 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -109,6 +109,9 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) + + self.result["peaks"] = peaks + self.result["peak_locations"] = peak_locations 
self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion @@ -131,6 +134,8 @@ def compute_result(self, **result_params): self.result["motion"] = motion _run_key_saved = [ + ("peaks", "npy"), + ("peak_locations", "npy"), ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] @@ -161,7 +166,7 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6)): import matplotlib.pyplot as plt if case_keys is None: @@ -195,6 +200,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + if raster: + peaks = bench.result["peaks"] + peak_locations = bench.result["peak_locations"] + rec = bench.recording + x = peaks["sample_index"] / rec.sampling_frequency + y = peak_locations[bench.direction] + ax.scatter(x, y, alpha=.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] From 22882ef66a8389fdfd7aac30ea4633151f1cdd16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:38:09 +0000 Subject: [PATCH 30/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_motion_estimation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py 
index 3a7d11fc35..5a3c490d38 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -109,7 +109,6 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) - self.result["peaks"] = peaks self.result["peak_locations"] = peak_locations self.result["step_run_times"] = step_run_times @@ -166,7 +165,9 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift( + self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6) + ): import matplotlib.pyplot as plt if case_keys is None: @@ -206,7 +207,7 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, raster=Fa rec = bench.recording x = peaks["sample_index"] / rec.sampling_frequency y = peak_locations[bench.direction] - ax.scatter(x, y, alpha=.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") + ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] From 0c88b39a875e5068b0bfd4f63db7ff45f025e202 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 29 Nov 2024 05:46:02 +0100 Subject: [PATCH 31/64] Patch to force remove sorters --- src/spikeinterface/benchmark/benchmark_base.py | 5 +++-- src/spikeinterface/benchmark/benchmark_sorter.py | 9 +++++++++ src/spikeinterface/curation/auto_merge.py | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index b9cbf269c8..ddcf25f2ab 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ 
b/src/spikeinterface/benchmark/benchmark_base.py @@ -208,10 +208,11 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): for key in case_keys: result_folder = self.folder / "results" / self.key_to_str(key) - + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + if keep and result_folder.exists(): continue - elif not keep and result_folder.exists(): + elif not keep and (result_folder.exists() or sorter_folder.exists()): self.remove_benchmark(key) job_keys.append(key) diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index f9267c785a..8180c943be 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -56,6 +56,15 @@ def create_benchmark(self, key): benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) return benchmark + def remove_benchmark(self, key): + BenchmarkStudy.remove_benchmark(self, key) + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + import shutil + if sorter_folder.exists(): + shutil.rmtree(sorter_folder) + + def get_performance_by_unit(self, case_keys=None): import pandas as pd diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 4f4cff144e..89c24565c2 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -231,7 +231,7 @@ def compute_merge_unit_groups( params = _default_step_params.get(step).copy() if steps_params is not None and step in steps_params: params.update(steps_params[step]) - + # STEP : remove units with too few spikes if step == "num_spikes": From 32c74d43085848008e0a497342f4b3eb2d917020 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 29 Nov 2024 05:46:58 +0100 Subject: [PATCH 32/64] Spaces --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 89c24565c2..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -231,7 +231,7 @@ def compute_merge_unit_groups( params = _default_step_params.get(step).copy() if steps_params is not None and step in steps_params: params.update(steps_params[step]) - + # STEP : remove units with too few spikes if step == "num_spikes": From c35a706632065af280300330ba25f76326906590 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Nov 2024 04:49:39 +0000 Subject: [PATCH 33/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/benchmark/benchmark_base.py | 2 +- src/spikeinterface/benchmark/benchmark_sorter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index ddcf25f2ab..fc1b136d2d 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -209,7 +209,7 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): result_folder = self.folder / "results" / self.key_to_str(key) sorter_folder = self.folder / "sorters" / self.key_to_str(key) - + if keep and result_folder.exists(): continue elif not keep and (result_folder.exists() or sorter_folder.exists()): diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index 8180c943be..3cf6dca04f 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -61,9 +61,9 @@ def remove_benchmark(self, key): sorter_folder = self.folder / "sorters" / self.key_to_str(key) import shutil + if sorter_folder.exists(): shutil.rmtree(sorter_folder) - def 
get_performance_by_unit(self, case_keys=None): import pandas as pd From b0e7b1c60086fd818759b8df0f617d5441f4ff40 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 29 Nov 2024 06:04:48 +0100 Subject: [PATCH 34/64] Fix kwargs in silence periods --- src/spikeinterface/preprocessing/silence_periods.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 85169011d8..a188f5d8db 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -97,7 +97,12 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, noise_generator=noise_generator) + self._kwargs = dict(recording=recording, + list_periods=list_periods, + mode=mode, + noise_levels=noise_levels, + seed=seed, + random_chunk_kwargs=random_chunk_kwargs) class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): From 60d7ad53b59fac5b47bd976905d64efc62b3daeb Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 29 Nov 2024 06:10:02 +0100 Subject: [PATCH 35/64] Fix --- src/spikeinterface/preprocessing/silence_periods.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index a188f5d8db..5e410d51d5 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -100,7 +100,6 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, - noise_levels=noise_levels, seed=seed, 
random_chunk_kwargs=random_chunk_kwargs) From 01ae85cbb3ebcfd6f9e20eb191d92a278bbcb5e4 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 29 Nov 2024 06:14:16 +0100 Subject: [PATCH 36/64] WIP --- src/spikeinterface/preprocessing/silence_periods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 5e410d51d5..7c518d02a0 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -100,8 +100,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, - seed=seed, - random_chunk_kwargs=random_chunk_kwargs) + seed=seed) + self._kwargs.update(random_chunk_kwargs) class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): From db2b4d5130a095500637c06f9d74aa2f03d41b73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Nov 2024 05:15:54 +0000 Subject: [PATCH 37/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/silence_periods.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 7c518d02a0..00d9a1a407 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -97,10 +97,7 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, - list_periods=list_periods, - mode=mode, - seed=seed) + self._kwargs = 
dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) self._kwargs.update(random_chunk_kwargs) From 5af4f858268c18421c8d1e7a3cbaca3c9957491e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:19:15 +0000 Subject: [PATCH 38/64] Hard code synchrony_size for users, but leave flexible code underneath --- doc/get_started/quickstart.rst | 2 +- doc/modules/qualitymetrics/synchrony.rst | 4 +- .../qualitymetrics/misc_metrics.py | 27 ++++++------- .../tests/test_metrics_functions.py | 39 ++++++++----------- 4 files changed, 30 insertions(+), 42 deletions(-) diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 3d45606a78..1349802ce5 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions 'min_spikes': 0, 'window_size_s': 1}, 'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'}, - 'synchrony': {'synchrony_sizes': (2, 4, 8)}} + 'synchrony': {}} Since the recording is very short, let’s change some parameters to diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index d244fd0c0f..696dacbd3c 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index, within and across spike trains. -Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count. +Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.
@@ -29,7 +29,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Combine a sorting and recording into a sorting_analyzer - synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 8dfd41cf88..b0e0a0ad19 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -520,7 +520,7 @@ def compute_sliding_rp_violations( ) -def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): +def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])): """ Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`. @@ -528,10 +528,10 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): ---------- spikes : np.array Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). - synchrony_sizes : numpy array - The synchrony sizes to compute. Should be pre-sorted. all_unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. Expecting all units. + synchrony_sizes : numpy array + The synchrony sizes to compute. Should be pre-sorted. Returns ------- @@ -565,17 +565,15 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of - "synchrony_size" spikes at the exact same sample index. + spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. 
Parameters ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. - synchrony_sizes : list or tuple, default: (2, 4, 8) - The synchrony sizes to compute. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. @@ -583,19 +581,16 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - Returns are as many as synchrony_sizes. References ---------- Based on concepts described in [GrĂ¼n]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" - # Sort the synchrony times so we can slice numpy arrays, instead of using dicts - synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16) - synchrony_sizes_np.sort() - res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np]) + synchrony_sizes = np.array([2, 4, 8]) + + res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting @@ -606,10 +601,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids - synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids) + synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes) synchrony_metrics_dict = {} - for sync_idx, synchrony_size in enumerate(synchrony_sizes_np): + for sync_idx, synchrony_size in enumerate(synchrony_sizes): sync_id_metrics_dict = {} for i, unit_id in enumerate(all_unit_ids): if unit_id not in unit_ids: @@ -623,7 +618,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ return res(**synchrony_metrics_dict) -_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) 
+_default_params["synchrony"] = dict() def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 4c0890b62b..f51dc3e884 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -39,7 +39,7 @@ compute_firing_ranges, compute_amplitude_cv_metrics, compute_sd_ratio, - get_synchrony_counts, + _get_synchrony_counts, compute_quality_metrics, ) @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync(): one_spike["sample_index"] = spike_times one_spike["unit_index"] = spike_units - sync_count = get_synchrony_counts(one_spike, np.array((2)), [0]) + sync_count = _get_synchrony_counts(one_spike, [0]) assert np.all(sync_count[0] == np.array([0])) @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync(): two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1]) + sync_count = _get_synchrony_counts(two_spikes, [0, 1]) assert np.all(sync_count[0] == np.array([1, 1])) @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync(): four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3]) + sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3]) assert np.all(sync_count[0] == np.array([1, 1, 1, 1])) assert np.all(sync_count[1] == np.array([1, 1, 1, 1])) @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units(): three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) three_spikes["unit_index"] = 
np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2]) + sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2]) assert np.all(sync_count[0] == np.array([0, 1, 1])) @@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations): def test_synchrony_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple sorting = sorting_analyzer.sorting - synchrony_sizes = (2, 3, 4) - synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes) - print(synchrony_metrics) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer) + + synchrony_sizes = np.array([2, 4, 8]) # check returns for size in synchrony_sizes: @@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple): sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory") - previous_synchrony_metrics = compute_synchrony_metrics( - previous_sorting_analyzer, synchrony_sizes=synchrony_sizes - ) - current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes) + previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer) + current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync) print(current_synchrony_metrics) # check that all values increased for i, col in enumerate(previous_synchrony_metrics._fields): @@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple): unit_ids_subset = [3, 7] - synchrony_sizes = (2,) - (synchrony_metrics,) = compute_synchrony_metrics( - sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset - ) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset) - assert list(synchrony_metrics.keys()) == [3, 7] + assert 
list(synchrony_metrics.sync_spike_2.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7] def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple): - # all_unit_ids = sorting_analyzer_simple.sorting.unit_ids - - synchrony_sizes = (2,) - (synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes) - - assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple) + assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids) @pytest.mark.sortingcomponents From 45eb5b74e58061ee04dcb2a4bba10dbcf2a2c892 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:31:23 +0000 Subject: [PATCH 39/64] Add warning and ability to pass synchrony_sizes --- src/spikeinterface/qualitymetrics/__init__.py | 2 +- src/spikeinterface/qualitymetrics/misc_metrics.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/__init__.py b/src/spikeinterface/qualitymetrics/__init__.py index 9d604f6ae2..754c82d8e3 100644 --- a/src/spikeinterface/qualitymetrics/__init__.py +++ b/src/spikeinterface/qualitymetrics/__init__.py @@ -6,4 +6,4 @@ get_default_qm_params, ) from .pca_metrics import get_quality_pca_metric_list -from .misc_metrics import get_synchrony_counts +from .misc_metrics import _get_synchrony_counts diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index b0e0a0ad19..2f178c46f3 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -565,7 +565,7 @@ def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, return synchrony_counts -def 
compute_synchrony_metrics(sorting_analyzer, unit_ids=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -588,6 +588,10 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None): This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ + if synchrony_sizes is not None: + warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`" + warnings.warn(warning_message) + synchrony_sizes = np.array([2, 4, 8]) res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) From 039b408a59ce965b91908de82d0bc55114f8655e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:43:39 +0000 Subject: [PATCH 40/64] move backwards compat to `_handle_backward_compatibility_on_load` --- .../postprocessing/template_metrics.py | 25 ++++++++++--------- .../quality_metric_calculator.py | 14 +++++------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 477ad04440..7de6e8766a 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -88,9 +88,22 @@ class ComputeTemplateMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True min_channels_for_multi_channel_warning = 10 + def _handle_backward_compatibility_on_load(self): + + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in 
self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + def _set_params( self, metric_names=None, @@ -344,18 +357,6 @@ def _run(self, verbose=False): def _get_data(self): return self.data["metrics"] - def load_params(self): - AnalyzerExtension.load_params(self) - # For backwards compatibility - this reformats metrics_kwargs as metric_params - if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: - - metric_params = {} - for metric_name in self.params["metric_names"]: - metric_params[metric_name] = deepcopy(metrics_kwargs) - self.params["metric_params"] = metric_params - - del self.params["metrics_kwargs"] - register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory() diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index e7e7c244ea..d71450853f 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -55,6 +55,13 @@ class ComputeQualityMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] def _set_params( self, @@ -262,13 +269,6 @@ def _run(self, verbose=False, **job_kwargs): def _get_data(self): return self.data["metrics"] - def load_params(self): - AnalyzerExtension.load_params(self) - # For backwards compatibility - this renames qm_params as metric_params - if (qm_params := self.params.get("qm_params")) is not None: - self.params["metric_params"] = 
qm_params - del self.params["qm_params"] - register_result_extension(ComputeQualityMetrics) compute_quality_metrics = ComputeQualityMetrics.function_factory() From 771de98c6ddd07e064cb28d6e2450e599d54d2a6 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:51:28 +0000 Subject: [PATCH 41/64] Respond to z-man --- src/spikeinterface/postprocessing/template_metrics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 7de6e8766a..da917e673c 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -64,8 +64,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. - metric_params : dict of dicts - metric_params : dict of dicts or None + metric_params : dict of dicts or None, default: None Dictionary with parameters for template metrics calculation. Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` @@ -138,7 +137,7 @@ def _set_params( if metrics_kwargs is not None and metric_params is None: deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" - warnings.warn(deprecation_msg, category=DeprecationWarning) + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. 
Please use `metric_params` instead" metric_params = {} for metric_name in metric_names: From de7210a43135c1164ee2f214e117543441935375 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:52:16 +0000 Subject: [PATCH 42/64] oups --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index da917e673c..1969480503 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -136,7 +136,7 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() if metrics_kwargs is not None and metric_params is None: - deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. 
Please use `metric_params` instead" metric_params = {} From 2081916e33d467145223a7c3099aca556f6e3864 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:18:15 +0000 Subject: [PATCH 43/64] respond to review --- src/spikeinterface/qualitymetrics/misc_metrics.py | 10 ++++++---- .../qualitymetrics/tests/test_metrics_functions.py | 8 ++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 2f178c46f3..6007de379c 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -520,7 +520,7 @@ def compute_sliding_rp_violations( ) -def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])): +def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): """ Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`. @@ -530,7 +530,7 @@ def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). all_unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. Expecting all units. - synchrony_sizes : numpy array + synchrony_sizes : None or np.array, default: None The synchrony sizes to compute. Should be pre-sorted. Returns @@ -576,6 +576,8 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N A SortingAnalyzer object. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. + synchrony_sizes: None, default: None + Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. 
Returns ------- @@ -590,7 +592,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N if synchrony_sizes is not None: warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`" - warnings.warn(warning_message) + warnings.warn(warning_message, DeprecationWarning, stacklevel=2) synchrony_sizes = np.array([2, 4, 8]) @@ -605,7 +607,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids - synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes) + synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids) synchrony_metrics_dict = {} for sync_idx, synchrony_size in enumerate(synchrony_sizes): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index f51dc3e884..ae4c7ab62d 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync(): one_spike["sample_index"] = spike_times one_spike["unit_index"] = spike_units - sync_count = _get_synchrony_counts(one_spike, [0]) + sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0]) assert np.all(sync_count[0] == np.array([0])) @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync(): two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = _get_synchrony_counts(two_spikes, [0, 1]) + sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1]) assert np.all(sync_count[0] == np.array([1, 1])) @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync(): four_spikes["sample_index"] = 
np.concatenate((spike_indices, added_spikes_indices)) four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3]) + sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3]) assert np.all(sync_count[0] == np.array([1, 1, 1, 1])) assert np.all(sync_count[1] == np.array([1, 1, 1, 1])) @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units(): three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2]) + sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2]) assert np.all(sync_count[0] == np.array([0, 1, 1])) From 09ff624817b53d30d35f1e4f9060edabab45a308 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Dec 2024 10:33:40 +0100 Subject: [PATCH 44/64] Remove venv in full-tests-with-codecov --- .../actions/build-test-environment/action.yml | 36 +++++++------------ .github/workflows/all-tests.yml | 2 +- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 723e8a702f..a212bd64d5 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -1,41 +1,20 @@ name: Install packages description: This action installs the package and its dependencies for testing -inputs: - python-version: - description: 'Python version to set up' - required: false - os: - description: 'Operating system to set up' - required: false - runs: using: "composite" steps: - name: Install dependencies run: | - sudo apt install git git config --global user.email "CI@example.com" git config --global user.name "CI Almighty" - python -m venv ${{ github.workspace }}/test_env # Environment used in the caching step - python -m 
pip install -U pip # Official recommended way - source ${{ github.workspace }}/test_env/bin/activate pip install tabulate # This produces summaries at the end pip install -e .[test,extractors,streaming_extractors,test_extractors,full] shell: bash - - name: Force installation of latest dev from key-packages when running dev (not release) - run: | - source ${{ github.workspace }}/test_env/bin/activate - spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") - if [ $spikeinterface_is_dev_version = "True" ]; then - echo "Running spikeinterface dev version" - pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo - pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface - fi - echo "Running tests for release, using pyproject.toml versions of neo and probeinterface" + - name: Install git-annex shell: bash - - name: git-annex install run: | + pip install datalad-installer wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz mkdir /home/runner/work/installation mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ @@ -44,4 +23,15 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH cd $workdir + git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency + - name: Force installation of latest dev from key-packages when running dev (not release) + run: | + source ${{ github.workspace }}/test_env/bin/activate + spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") + if [ $spikeinterface_is_dev_version = "True" ]; then + echo "Running spikeinterface dev version" + pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo + pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface + fi + echo "Running tests for release, using pyproject.toml versions of neo and 
probeinterface" shell: bash diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index dcaec8b272..a9c840d5d5 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -47,7 +47,7 @@ jobs: echo "$file was changed" done - - name: Set testing environment # This decides which tests are run and whether to install especial dependencies + - name: Set testing environment # This decides which tests are run and whether to install special dependencies shell: bash run: | changed_files="${{ steps.changed-files.outputs.all_changed_files }}" From 8500b9d0f4488794dcc6d6b71afec2ebf4697b1d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Dec 2024 10:48:11 +0100 Subject: [PATCH 45/64] Oups --- .github/actions/build-test-environment/action.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index a212bd64d5..c2524d2c16 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -26,7 +26,6 @@ runs: git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency - name: Force installation of latest dev from key-packages when running dev (not release) run: | - source ${{ github.workspace }}/test_env/bin/activate spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") if [ $spikeinterface_is_dev_version = "True" ]; then echo "Running spikeinterface dev version" From 922606b6d4d279da103b7e7edde3ecb79a76e3c8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Dec 2024 11:16:01 +0100 Subject: [PATCH 46/64] Oups 2 --- .github/workflows/full-test-with-codecov.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 407c614ebf..f8ed2aa7a9 100644 --- 
a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -45,7 +45,6 @@ jobs: env: HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell run: | - source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY From 986a74a30c94a49ed2a2dd6183e8ddc078105b85 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Dec 2024 09:02:31 +0100 Subject: [PATCH 47/64] Pin ONE-API version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fc09ad9198..22fbdc7f22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ extractors = [ ] streaming_extractors = [ - "ONE-api>=2.7.0", # alf sorter and streaming IBL + "ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL "ibllib>=2.36.0", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", From 96da22f7ac509bfc83a2a90eed06d58e3f71f990 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 5 Dec 2024 10:00:30 +0000 Subject: [PATCH 48/64] Correct method default in docstring --- src/spikeinterface/postprocessing/unit_locations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 3f6dd47eec..bea06fd8f5 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension): ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - method : "center_of_mass" | 
"monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + method : "monopolar_triangulation" or "center_of_mass" or "grid_convolution", default: "monopolar_triangulation" The method to use for localization **method_kwargs : dict, default: {} Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. From 10d459f3d45315cee3079e3b14428222487ef9c6 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 5 Dec 2024 15:18:59 +0000 Subject: [PATCH 49/64] change or to | in docstring --- src/spikeinterface/postprocessing/unit_locations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index bea06fd8f5..df19458316 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension): ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - method : "monopolar_triangulation" or "center_of_mass" or "grid_convolution", default: "monopolar_triangulation" + method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation" The method to use for localization **method_kwargs : dict, default: {} Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. 
From 4c7b6a5be65af4aa4ce6461e84956455f970942f Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 11 Dec 2024 09:53:59 +0100 Subject: [PATCH 50/64] Patch --- src/spikeinterface/widgets/unit_waveforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index c593836061..3b31eacee5 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -565,7 +565,7 @@ def _update_plot(self, change): channel_locations = self.sorting_analyzer.get_channel_locations() else: unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] - templates = self.templates.templates_array[unit_indices] + templates = self.templates.get_dense_templates()[unit_indices] templates_shadings = None channel_locations = self.templates.get_channel_locations() From 0bf2b08248b836c6323524c1f54cf3690cd6c5f8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 11:14:08 -0600 Subject: [PATCH 51/64] use strings as ids in generators --- src/spikeinterface/core/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0316b3bab1..d03c08b480 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -134,7 +134,7 @@ def generate_sorting( seed = _ensure_seed(seed) rng = np.random.default_rng(seed) num_segments = len(durations) - unit_ids = np.arange(num_units) + unit_ids = [str(id) for id in np.arange(num_units)] spikes = [] for segment_index in range(num_segments): @@ -1111,7 +1111,7 @@ def __init__( """ - unit_ids = np.arange(num_units) + unit_ids = [str(id) for id in np.arange(num_units)] super().__init__(sampling_frequency, unit_ids) self.num_units = num_units @@ -1280,7 +1280,7 @@ def __init__( noise_block_size: int = 30000, ): - channel_ids = 
np.arange(num_channels) + channel_ids = [str(id) for id in np.arange(num_channels)] dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") From 212a974ea7fa17eacb91f37495c764dc2eb8f828 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 12:16:16 -0600 Subject: [PATCH 52/64] change to strings --- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/generate.py | 7 ++++-- .../core/tests/test_basesnippets.py | 10 ++++----- .../test_channelsaggregationrecording.py | 6 +++-- .../core/tests/test_sortinganalyzer.py | 14 ++++++------ .../core/tests/test_unitsselectionsorting.py | 22 +++++++++++-------- 6 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2af48407a3..9a0e242d62 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -135,7 +135,7 @@ def get_total_duration(self) -> float: def get_unit_spike_train( self, - unit_id, + unit_id: str | int, segment_index: Union[int, None] = None, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d03c08b480..5824a75ab8 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Literal +from typing import Literal, Optional from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -1138,6 +1138,7 @@ def __init__( firing_rates=firing_rates, refractory_period_seconds=self.refractory_period_seconds, seed=segment_seed, + unit_ids=unit_ids, t_start=None, ) self.add_sorting_segment(segment) @@ -1161,6 +1162,7 @@ def __init__( firing_rates: float | np.ndarray, refractory_period_seconds: float | 
np.ndarray, seed: int, + unit_ids: list[str], t_start: Optional[float] = None, ): self.num_units = num_units @@ -1177,7 +1179,8 @@ def __init__( self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") self.segment_seed = seed - self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} + self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 64f7f76819..f243dd9d9f 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder): assert snippets.get_num_segments() == len(duration) assert snippets.get_num_channels() == num_channels - assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2]) - assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None)) + assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2]) + assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None)) # annotations / properties snippets.annotate(gre="ta") @@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder): ) # missing property - snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1]) + snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"]) values = snippets.get_property("string_property") assert values[2] == "" @@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder): snippets.set_property, key="string_property_nan", values=["hola", "chabon"], - ids=[0, 1], + ids=["0", "1"], missing_value=np.nan, ) # int properties without missing values raise an error assert_raises(Exception, snippets.set_property, key="int_property", 
values=[5, 6], ids=[1, 2]) - snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200) + snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200) values = snippets.get_property("int_property") assert values.dtype.kind == "i" diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 118b6092a9..99d6890dfd 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -38,10 +38,12 @@ def test_channelsaggregationrecording(): assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) assert np.allclose( - traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg) + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), ) assert np.allclose( - traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg) + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), ) # all traces traces1 = recording1.get_traces(segment_index=seg) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 35ab18b5f2..899993d840 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -76,8 +76,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset): # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above - select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) - assert len(select_units_sorting_analyer.unit_ids) == 1 + 
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"]) + assert len(select_units_sorting_analyer.unit_ids) == "1" folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): @@ -121,11 +121,11 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above - select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) - assert len(select_units_sorting_analyer.unit_ids) == 1 - remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=[1]) + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"]) + assert len(select_units_sorting_analyer.unit_ids) == "1" + remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=["1"]) assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 - assert 1 not in remove_units_sorting_analyer.unit_ids + assert "1" not in remove_units_sorting_analyer.unit_ids # test no compression sorting_analyzer_no_compression = create_sorting_analyzer( @@ -358,7 +358,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): shutil.rmtree(folder) else: folder = None - sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder) + sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[["0", "1"]], format=format, folder=folder) if format != "memory": if format == "zarr": diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 1e72b0ab28..3ecb702aa2 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -10,25 +10,29 @@ def test_basic_functions(): sorting = generate_sorting(num_units=3, durations=[0.100, 
0.100], sampling_frequency=30000.0) - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2]) - assert np.array_equal(sorting2.unit_ids, [0, 2]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"]) + assert np.array_equal(sorting2.unit_ids, ["0", "2"]) assert sorting2.get_parent() == sorting - sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"]) + sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"]) assert np.array_equal(sorting3.unit_ids, ["a", "b"]) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting2.get_unit_spike_train(unit_id="0", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting3.get_unit_spike_train(unit_id="a", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting2.get_unit_spike_train(unit_id="2", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting3.get_unit_spike_train(unit_id="b", segment_index=0), ) @@ -36,13 +40,13 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_custom_cache_spike_vector(): sorting = 
generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) assert np.all(cached_spike_vector == computed_spike_vector) From c36e49e6c5057496cd59ab74f23c805732bf708c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 12:24:41 -0600 Subject: [PATCH 53/64] keep sorting analyzer tests as they were --- src/spikeinterface/core/tests/test_sortinganalyzer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 899993d840..8d8beaa491 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -31,6 +31,14 @@ def get_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), seed=2205, ) + + # TODO: the tests or the sorting analyzer make assumptions about the ids being integers + # So keeping this the way it was + integer_channel_ids = [int(id) for id in recording.get_channel_ids()] + integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] + + recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting @@ -358,7 +366,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): shutil.rmtree(folder) else: folder = None - sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[["0", "1"]], format=format, folder=folder) + sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder) if format != 
"memory": if format == "zarr": From 3dd6b359daa86252878e480dddbe9dc719e98c2b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 12:27:04 -0600 Subject: [PATCH 54/64] fully restore sorting anlayzer --- .../core/tests/test_sortinganalyzer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 8d8beaa491..15f089f784 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -84,8 +84,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset): # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above - select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"]) - assert len(select_units_sorting_analyer.unit_ids) == "1" + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) + assert len(select_units_sorting_analyer.unit_ids) == 1 folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): @@ -129,11 +129,11 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above - select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=["1"]) - assert len(select_units_sorting_analyer.unit_ids) == "1" - remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=["1"]) + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) + assert len(select_units_sorting_analyer.unit_ids) == 1 + remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=[1]) assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 - assert "1" not in 
remove_units_sorting_analyer.unit_ids + assert 1 not in remove_units_sorting_analyer.unit_ids # test no compression sorting_analyzer_no_compression = create_sorting_analyzer( From 61f40187cb372f64a2752136c6df461bbad89705 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 12:41:22 -0600 Subject: [PATCH 55/64] fix mda extractor --- src/spikeinterface/extractors/tests/test_mdaextractors.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 0ef6697c6c..78e6afb65e 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -9,6 +9,12 @@ def test_mda_extractors(create_cache_folder): cache_folder = create_cache_folder rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) + ids_as_integers = [id for id in range(rec.get_num_channels())] + rec = rec.rename_channels(new_channel_ids=ids_as_integers) + + ids_as_integers = [id for id in range(sort.get_num_units())] + sort = sort.rename_units(new_unit_ids=ids_as_integers) + MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest") rec_mda = MdaRecordingExtractor(cache_folder / "mdatest") probe = rec_mda.get_probe() From 0429152bfa7141b4c1428fc960833d9141da2168 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 13:11:51 -0600 Subject: [PATCH 56/64] fix preprocessing --- src/spikeinterface/preprocessing/tests/test_clip.py | 8 ++++---- .../preprocessing/tests/test_interpolate_bad_channels.py | 4 +++- .../preprocessing/tests/test_normalize_scale.py | 2 +- src/spikeinterface/preprocessing/tests/test_rectify.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index 724ba2c963..c18c7d37af 100644 --- 
a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -14,12 +14,12 @@ def test_clip(): rec1 = clip(rec, a_min=-1.5) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(-2 <= traces0[0] <= 3) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0, 1]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0", "1"]) assert traces1.shape[1] == 2 assert np.all(-1.5 <= traces1[1]) @@ -34,11 +34,11 @@ def test_blank_staturation(): rec1 = blank_staturation(rec, quantile_threshold=0.01, direction="both", chunk_size=10000) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(traces0 < 3.0) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"]) assert traces1.shape[1] == 1 # use a smaller value to be sure a_min = rec1._recording_segments[0].a_min diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 1189f04f7d..06bde4e3d1 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -163,7 +163,9 @@ def test_output_values(): expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) - si_interpolated_recording = spre.interpolate_bad_channels(recording, bad_channel_indexes, sigma_um=1, p=1) + si_interpolated_recording = spre.interpolate_bad_channels( + recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1 + ) si_interpolated = si_interpolated_recording.get_traces() expected_ts = 
si_interpolated[:, 1:] @ expected_weights diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 576b570832..151752e0e6 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -15,7 +15,7 @@ def test_normalize_by_quantile(): rec2 = normalize_by_quantile(rec, mode="by_channel") rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 rec2 = normalize_by_quantile(rec, mode="pool_channel") diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index b8bb31015e..a2a06e7a1f 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -15,7 +15,7 @@ def test_rectify(): rec2 = rectify(rec) rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 # import matplotlib.pyplot as plt From 6d70a154426a2391e6712fb38b27dcf0fbd95a05 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 13:18:20 -0600 Subject: [PATCH 57/64] fix quality metrics --- src/spikeinterface/qualitymetrics/tests/conftest.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..ac1789a375 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -16,6 +16,11 @@ def small_sorting_analyzer(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in 
range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -60,6 +65,11 @@ def sorting_analyzer_simple(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) From d94ccf1e56bf1c8cd0d89b71bbba42817c59106f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 13:45:55 -0600 Subject: [PATCH 58/64] fix post processing --- .../postprocessing/tests/test_multi_extensions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index bf0000135c..be0070d94a 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -23,6 +23,11 @@ def get_dataset(): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + # since templates are going to be averaged and this might be a problem for amplitude scaling # we select the 3 units with the largest templates to split analyzer_raw = 
create_sorting_analyzer(sorting, recording, format="memory", sparse=False) From 7f461db713a377725ea51888859930401d224a98 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 13:48:07 -0600 Subject: [PATCH 59/64] fix motion --- src/spikeinterface/sortingcomponents/tests/common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index 01e4445a13..d5e5b6be1b 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -21,4 +21,10 @@ def make_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), seed=2205, ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + return recording, sorting From 33359cc1f1646545d7beb2a16bad8528e019d428 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 14:21:07 -0600 Subject: [PATCH 60/64] fix curation --- src/spikeinterface/curation/tests/common.py | 5 +++++ .../curation/tests/test_sortingview_curation.py | 9 +++++++++ 2 files changed, 14 insertions(+) diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 9cd20f4bfc..e9c4c4a463 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -19,6 +19,11 @@ def make_sorting_analyzer(sparse=True): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + 
sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 945aca7937..ff80be365d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -49,6 +49,9 @@ def test_gh_curation(): Test curation using GitHub URI. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) + # curated link: # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5 gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" @@ -76,6 +79,8 @@ def test_sha1_curation(): Test curation using SHA1 URI. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from SHA1 # curated link: @@ -105,6 +110,8 @@ def test_json_curation(): Test curation using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" @@ -248,6 +255,8 @@ def test_json_no_merge_curation(): Test curation with no merges using a JSON file. 
""" sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) json_file = parent_folder / "sv-sorting-curation-no-merge.json" sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file) From bb48b63a05e6e933bce31e380f97f83cfb3cddb9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 19 Dec 2024 11:03:57 -0600 Subject: [PATCH 61/64] Update src/spikeinterface/core/generate.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 5824a75ab8..118ce384f3 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1111,7 +1111,7 @@ def __init__( """ - unit_ids = [str(id) for id in np.arange(num_units)] + unit_ids = [str(idx) for idx in np.arange(num_units)] super().__init__(sampling_frequency, unit_ids) self.num_units = num_units From 2f26983798026145f9455cacb5e312f11efc4bf3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 19 Dec 2024 11:04:03 -0600 Subject: [PATCH 62/64] Update src/spikeinterface/core/generate.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 118ce384f3..fb10f26a2e 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -134,7 +134,7 @@ def generate_sorting( seed = _ensure_seed(seed) rng = np.random.default_rng(seed) num_segments = len(durations) - unit_ids = [str(id) for id in np.arange(num_units)] + unit_ids = [str(idx) for idx in np.arange(num_units)] spikes = [] for segment_index in range(num_segments): From 7dea3b2b39568c1a9fb8dd95e1a3fa1de8ed01a4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: 
Tue, 7 Jan 2025 11:35:15 +0100 Subject: [PATCH 63/64] Update src/spikeinterface/core/generate.py --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index fb10f26a2e..aa69fe585b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1283,7 +1283,7 @@ def __init__( noise_block_size: int = 30000, ): - channel_ids = [str(id) for id in np.arange(num_channels)] + channel_ids = [str(idx) for idx in np.arange(num_channels)] dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") From e74aa00e2c8ee5d6e94f79da491a565fcef322c8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:21:52 +0000 Subject: [PATCH 64/64] string-ify unit_ids in plot_2_sort_gallery --- examples/tutorials/widgets/plot_2_sort_gallery.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tutorials/widgets/plot_2_sort_gallery.py b/examples/tutorials/widgets/plot_2_sort_gallery.py index da5c611ce4..056b5e3a8d 100644 --- a/examples/tutorials/widgets/plot_2_sort_gallery.py +++ b/examples/tutorials/widgets/plot_2_sort_gallery.py @@ -31,14 +31,14 @@ # plot_autocorrelograms() # ~~~~~~~~~~~~~~~~~~~~~~~~ -w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5]) +w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5']) ############################################################################## # plot_crosscorrelograms() # ~~~~~~~~~~~~~~~~~~~~~~~~ -w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5]) +w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5']) plt.show()