Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add limit for n_jobs with a max of os.cpu_count() #2940

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def fix_job_kwargs(runtime_job_kwargs):
else:
n_jobs = int(n_jobs)

job_kwargs["n_jobs"] = max(n_jobs, 1)
n_jobs = max(n_jobs, 1)
job_kwargs["n_jobs"] = min(n_jobs, os.cpu_count())

if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set():
warnings.warn(
Expand Down
7 changes: 6 additions & 1 deletion src/spikeinterface/core/tests/test_globals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import warnings
from pathlib import Path
from os import cpu_count

from spikeinterface import (
set_global_dataset_folder,
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_global_job_kwargs():
n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
)
# test that fix_job_kwargs grabs global kwargs
new_job_kwargs = dict(n_jobs=10)
new_job_kwargs = dict(n_jobs=cpu_count())
job_kwargs_split = fix_job_kwargs(new_job_kwargs)
assert job_kwargs_split["n_jobs"] == new_job_kwargs["n_jobs"]
assert job_kwargs_split["chunk_duration"] == job_kwargs["chunk_duration"]
Expand All @@ -74,6 +75,10 @@ def test_global_job_kwargs():
job_kwargs_split = fix_job_kwargs(none_job_kwargs)
assert job_kwargs_split["chunk_duration"] == job_kwargs["chunk_duration"]
assert job_kwargs_split["progress_bar"] == job_kwargs["progress_bar"]
# test that n_jobs are clipped if using more than virtual cores
excessive_n_jobs = dict(n_jobs=cpu_count() + 2)
job_kwargs_split = fix_job_kwargs(excessive_n_jobs)
assert job_kwargs_split["n_jobs"] == cpu_count()
reset_global_job_kwargs()


Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ def test_fix_job_kwargs():
assert fixed_job_kwargs["n_jobs"] == 1

# test float value > 1 is cast to correct int
job_kwargs = dict(n_jobs=4.0, progress_bar=False, chunk_duration="1s")
job_kwargs = dict(n_jobs=float(os.cpu_count()), progress_bar=False, chunk_duration="1s")
fixed_job_kwargs = fix_job_kwargs(job_kwargs)
assert fixed_job_kwargs["n_jobs"] == 4
assert fixed_job_kwargs["n_jobs"] == os.cpu_count()

# test wrong keys
with pytest.raises(AssertionError):
Expand Down
Loading