From 63a58c1900627c78b977e117cc3f2900bcb7fc29 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 24 Jan 2025 22:29:14 +0100
Subject: [PATCH] Adding job_kwargs to all benchmarks

---
 .../benchmark/tests/common_benchmark_testing.py         | 9 ++++++---
 .../benchmark/tests/test_benchmark_clustering.py        | 2 +-
 .../benchmark/tests/test_benchmark_matching.py          | 4 ++--
 .../benchmark/tests/test_benchmark_merging.py           | 2 +-
 .../tests/test_benchmark_motion_interpolation.py        | 2 +-
 .../benchmark/tests/test_benchmark_peak_detection.py    | 7 ++++---
 .../benchmark/tests/test_benchmark_peak_localization.py | 2 +-
 7 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py
index 1e9f8abae9..5c473f1740 100644
--- a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py
+++ b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py
@@ -7,6 +7,7 @@
 import pytest
 from pathlib import Path
 import os
+from spikeinterface.core.job_tools import fix_job_kwargs
 
 import numpy as np
 
@@ -21,7 +22,9 @@
 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
 
 
-def make_dataset():
+def make_dataset(job_kwargs={}):
+
+    job_kwargs = fix_job_kwargs(job_kwargs)
     recording, gt_sorting = generate_ground_truth_recording(
         durations=[60.0],
         sampling_frequency=30000.0,
@@ -39,10 +42,10 @@
         seed=2205,
     )
 
-    gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory")
+    gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory", **job_kwargs)
     gt_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
     # analyzer.compute("waveforms")
-    gt_analyzer.compute("templates")
+    gt_analyzer.compute("templates", **job_kwargs)
     gt_analyzer.compute("noise_levels")
 
     return recording, gt_sorting, gt_analyzer
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
index 3f574fd058..1dda5dc269 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
@@ -16,7 +16,7 @@ def test_benchmark_clustering(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")
 
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)
 
     num_spikes = gt_sorting.to_spike_vector().size
     spike_indices = np.arange(0, num_spikes, 5)
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py
index 000a00faf5..44cf7efe8d 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py
@@ -21,13 +21,13 @@ def test_benchmark_matching(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")
 
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)
 
     # templates sparse
     gt_templates = compute_gt_templates(
         recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs
     )
-    noise_levels = get_noise_levels(recording)
+    noise_levels = get_noise_levels(recording, **job_kwargs)
     sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25)
     gt_templates = gt_templates.to_sparse(sparsity)
 
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_merging.py b/src/spikeinterface/benchmark/tests/test_benchmark_merging.py
index a61610b65c..022ba18ebd 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_merging.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_merging.py
@@ -15,7 +15,7 @@ def test_benchmark_merging(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")
 
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)
 
     # create study
     study_folder = cache_folder / "study_clustering"
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py
index f7afd7a8bc..a9b64d19ea 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py
@@ -22,7 +22,7 @@ def test_benchmark_motion_interpolation(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")
 
-    data = make_drifting_dataset()
+    data = make_drifting_dataset(job_kwargs)
 
     datasets = {
         "data_static": (data["static_rec"], data["sorting"]),
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py
index d45ac0b4ce..6a60b1a646 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py
@@ -15,7 +15,7 @@ def test_benchmark_peak_detection(create_cache_folder):
     job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")
 
     # recording, gt_sorting = make_dataset()
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)
 
     # create study
     study_folder = cache_folder / "study_peak_detection"
@@ -27,8 +27,9 @@
     recording, gt_sorting = datasets[dataset]
 
-    sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
-    sorting_analyzer.compute(["random_spikes", "templates"])
+    sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
+    sorting_analyzer.compute("random_spikes")
+    sorting_analyzer.compute("templates", **job_kwargs)
     extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index")
     spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
     peaks[dataset] = spikes
 
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py
index 3b6240cb10..dc4527b761 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py
@@ -15,7 +15,7 @@ def test_benchmark_peak_localization(create_cache_folder):
     job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")
 
     # recording, gt_sorting = make_dataset()
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)
 
     # create study
     study_folder = cache_folder / "study_peak_localization"
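
A minimal usage sketch of the job_kwargs plumbing this patch introduces (not
part of the diff; it reuses only calls that appear above, with a shorter
recording duration so it runs quickly):

    from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
    from spikeinterface.core.job_tools import fix_job_kwargs

    # The tests pass n_jobs as a fraction of the available cores; fix_job_kwargs()
    # validates the dict and resolves it into concrete defaults.
    job_kwargs = fix_job_kwargs(dict(n_jobs=0.8, chunk_duration="1s"))

    # Same generator make_dataset() uses, with 10 s instead of 60 s
    # so the sketch stays fast.
    recording, gt_sorting = generate_ground_truth_recording(
        durations=[10.0],
        sampling_frequency=30000.0,
        seed=2205,
    )

    # Every parallel step now receives the same job_kwargs, which is the point
    # of the commit: one dict controls chunking and parallelism end to end.
    gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory", **job_kwargs)
    gt_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
    gt_analyzer.compute("templates", **job_kwargs)
    gt_analyzer.compute("noise_levels")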