Remove joblib in favor of ParallelProcessExecutor (#2218)
* Remove joblib in favor of ParallelProcessExecutor

* Use global tmp folder in PCA

* Fix indentation!

* Remove print and Ramon's suggestions

* Add futures in waveforms extractor
alejoe91 authored Nov 23, 2023
1 parent 91cfa2a commit 8cba6d9
Showing 6 changed files with 52 additions and 55 deletions.
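The change is mechanical across the touched files: joblib's Parallel/delayed API is replaced by the standard library's concurrent.futures.ProcessPoolExecutor, and the tqdm_joblib progress helper is dropped in favor of wrapping executor.map directly in tqdm. A minimal sketch of the before/after pattern (illustrative only; work and items are placeholders, not code from this commit):

    # before: results = Parallel(n_jobs=4)(delayed(work)(x) for x in items)
    # after:
    from concurrent.futures import ProcessPoolExecutor

    def work(x):
        # stand-in for the real workers further down this diff
        # (partial_fit_one_channel, pca_metrics_one_unit)
        return x * x

    if __name__ == "__main__":
        items = [1, 2, 3, 4]
        with ProcessPoolExecutor(max_workers=4) as executor:
            results = list(executor.map(work, items))
        print(results)  # [1, 4, 9, 16]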
1 change: 0 additions & 1 deletion pyproject.toml
@@ -22,7 +22,6 @@ classifiers = [
 dependencies = [
     "numpy",
     "neo>=0.12.0",
-    "joblib",
     "threadpoolctl",
     "tqdm",
     "probeinterface>=0.2.19",
22 changes: 1 addition & 21 deletions src/spikeinterface/core/job_tools.py
@@ -7,7 +7,6 @@
 import os
 import warnings
 
-import joblib
 import sys
 import contextlib
 from tqdm.auto import tqdm
@@ -95,25 +94,6 @@ def split_job_kwargs(mixed_kwargs):
     return specific_kwargs, job_kwargs
 
 
-# from https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution
-@contextlib.contextmanager
-def tqdm_joblib(tqdm_object):
-    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
-
-    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
-        def __call__(self, *args, **kwargs):
-            tqdm_object.update(n=self.batch_size)
-            return super().__call__(*args, **kwargs)
-
-    old_batch_callback = joblib.parallel.BatchCompletionCallBack
-    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
-    try:
-        yield tqdm_object
-    finally:
-        joblib.parallel.BatchCompletionCallBack = old_batch_callback
-        tqdm_object.close()
-
-
 def divide_segment_into_chunks(num_frames, chunk_size):
     if chunk_size is None:
         chunks = [(0, num_frames)]
@@ -156,7 +136,7 @@ def _mem_to_int(mem):
 
 def ensure_n_jobs(recording, n_jobs=1):
     if n_jobs == -1:
-        n_jobs = joblib.cpu_count()
+        n_jobs = os.cpu_count()
     elif n_jobs == 0:
         n_jobs = 1
     elif n_jobs is None:
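With joblib removed, n_jobs=-1 now resolves through os.cpu_count(). One behavioral nuance worth knowing: joblib.cpu_count() always returns an int, while os.cpu_count() can return None if the count is undetermined. A defensive variant (an assumption for illustration, not what the diff does; resolve_n_jobs is a hypothetical name):

    import os

    def resolve_n_jobs(n_jobs=1):
        # -1 means "use all cores"; os.cpu_count() may return None,
        # so fall back to a single job in that case
        if n_jobs == -1:
            n_jobs = os.cpu_count() or 1
        elif n_jobs == 0 or n_jobs is None:
            n_jobs = 1
        return n_jobs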
2 changes: 2 additions & 0 deletions src/spikeinterface/core/waveform_extractor.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 import pickle
 from pathlib import Path
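This two-line change is presumably what the "Add futures in waveforms extractor" bullet refers to: it enables postponed evaluation of annotations, so type hints are stored as strings and never evaluated at runtime, letting newer hint syntax parse on older interpreters. A tiny illustration (hypothetical function, not from the diff):

    from __future__ import annotations

    from pathlib import Path

    def load(folder: str | Path) -> None:
        # PEP 604 "str | Path" would raise a TypeError at runtime on
        # Python < 3.10 without the __future__ import above
        print(Path(folder))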
24 changes: 16 additions & 8 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -1,6 +1,7 @@
 import shutil
 import pickle
 import warnings
+import tempfile
 from pathlib import Path
 from tqdm.auto import tqdm
 
@@ -9,6 +10,7 @@
 
 from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs
 from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
+from spikeinterface.core.globals import get_global_tmp_folder
 
 _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"]
 
@@ -370,7 +372,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
 
     def _fit_by_channel_local(self, n_jobs, progress_bar):
         from sklearn.decomposition import IncrementalPCA
-        from joblib import delayed, Parallel
+        from concurrent.futures import ProcessPoolExecutor
 
         we = self.waveform_extractor
         p = self._params
@@ -385,12 +387,13 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
 
         tmp_folder = p["tmp_folder"]
         if tmp_folder is None:
-            tmp_folder = "tmp"
-        tmp_folder = Path(tmp_folder)
+            if n_jobs > 1:
+                tmp_folder = tempfile.mkdtemp(prefix="pca", dir=get_global_tmp_folder())
 
         for chan_ind, chan_id in enumerate(channel_ids):
             pca_model = pca_models[chan_ind]
             if n_jobs > 1:
+                tmp_folder = Path(tmp_folder)
                 tmp_folder.mkdir(exist_ok=True)
                 pca_model_file = tmp_folder / f"tmp_pca_model_{mode}_{chan_id}.pkl"
                 with pca_model_file.open("wb") as f:
@@ -411,10 +414,14 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
                     pca = pca_models[chan_ind]
                     pca.partial_fit(wfs[:, :, wf_ind])
             else:
-                Parallel(n_jobs=n_jobs)(
-                    delayed(partial_fit_one_channel)(pca_model_files[chan_ind], wfs[:, :, wf_ind])
-                    for wf_ind, chan_ind in enumerate(channel_inds)
-                )
+                # parallel
+                items = [(pca_model_files[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)]
+                n_jobs = min(n_jobs, len(items))
+
+                with ProcessPoolExecutor(max_workers=n_jobs) as executor:
+                    results = executor.map(partial_fit_one_channel, items)
+                    for res in results:
+                        pass
 
         # reload the models (if n_jobs > 1)
         if n_jobs not in (0, 1):
@@ -762,7 +769,8 @@ def compute_principal_components(
     return pc
 
 
-def partial_fit_one_channel(pca_file, wf_chan):
+def partial_fit_one_channel(args):
+    pca_file, wf_chan = args
     with open(pca_file, "rb") as fid:
         pca_model = pickle.load(fid)
     pca_model.partial_fit(wf_chan)
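Two details of this file's rewrite are easy to miss. First, executor.map passes each worker exactly one element of the iterable, which is why partial_fit_one_channel now takes a single args tuple and unpacks it; joblib's delayed(f)(a, b) had no such restriction. Second, the for res in results: pass loop is not dead code: map returns a lazy iterator, and draining it both waits for completion and re-raises any exception from a worker process. As an aside, map also accepts one iterable per parameter, so a two-argument worker would have worked too (a sketch under that assumption, not the commit's approach):

    from concurrent.futures import ProcessPoolExecutor

    def partial_fit(pca_file, wf_chan):
        # hypothetical two-argument worker
        return (pca_file, wf_chan)

    if __name__ == "__main__":
        files = ["a.pkl", "b.pkl"]
        waveforms = [[0.1, 0.2], [0.3, 0.4]]
        with ProcessPoolExecutor(max_workers=2) as executor:
            # map(f, iter1, iter2) zips the iterables across workers
            for res in executor.map(partial_fit, files, waveforms):
                print(res)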
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/tests/test_principal_component.py
@@ -204,5 +204,5 @@ def test_project_new(self):
     # test.test_extension()
     # test.test_shapes()
     # test.test_compute_for_all_spikes()
-    test.test_sparse()
-    # test.test_project_new()
+    # test.test_sparse()
+    test.test_project_new()
54 changes: 31 additions & 23 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -4,19 +4,18 @@
 
 import numpy as np
 from tqdm.auto import tqdm
+from concurrent.futures import ProcessPoolExecutor
 
 try:
     import scipy.stats
     import scipy.spatial.distance
     from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
     from sklearn.neighbors import NearestNeighbors
     from sklearn.decomposition import IncrementalPCA
-    from joblib import delayed, Parallel
 except:
     pass
 
 from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor
-from ..core.job_tools import tqdm_joblib
 from ..core.template_tools import get_template_extremum_channel
 
 from ..postprocessing import WaveformPrincipalComponent
@@ -25,7 +24,6 @@
 from .misc_metrics import compute_num_spikes, compute_firing_rates
 
 from ..core import get_random_data_chunks, load_waveforms, compute_sparsity, WaveformExtractor
-from ..core.job_tools import tqdm_joblib
 from ..core.template_tools import get_template_extremum_channel
 from ..postprocessing import WaveformPrincipalComponent
 
@@ -134,7 +132,9 @@ def calculate_pc_metrics(
         parallel_functions = []
 
     all_labels, all_pcs = pca.get_all_projections()
-    for unit_ind, unit_id in units_loop:
+
+    items = []
+    for unit_id in unit_ids:
         if we.is_sparse():
             neighbor_channel_ids = we.sparsity.unit_id_to_channel_ids[unit_id]
             neighbor_unit_ids = [
@@ -166,26 +166,23 @@
             n_spikes_all_units,
             fr_all_units,
         )
+        items.append(func_args)
 
-        if not run_in_parallel:
-            pca_metrics_unit = pca_metrics_one_unit(*func_args)
+    if not run_in_parallel:
+        for unit_ind, unit_id in units_loop:
+            pca_metrics_unit = pca_metrics_one_unit(items[unit_ind])
             for metric_name, metric in pca_metrics_unit.items():
                 pc_metrics[metric_name][unit_id] = metric
-        else:
-            parallel_functions.append(delayed(pca_metrics_one_unit)(*func_args))
 
-    if run_in_parallel:
-        if progress_bar:
-            units_loop = tqdm(units_loop, desc="Computing PCA metrics", total=len(unit_ids))
-            with tqdm_joblib(units_loop) as pb:
-                pc_metrics_units = Parallel(n_jobs=n_jobs)(parallel_functions)
-        else:
-            pc_metrics_units = Parallel(n_jobs=n_jobs)(parallel_functions)
+    else:
+        with ProcessPoolExecutor(n_jobs) as executor:
+            results = executor.map(pca_metrics_one_unit, items)
+            if progress_bar:
+                results = tqdm(results, total=len(unit_ids))
 
-        for ui, pca_metrics_unit in enumerate(pc_metrics_units):
-            unit_id = unit_ids[ui]
-            for metric_name, metric in pca_metrics_unit.items():
-                pc_metrics[metric_name][unit_id] = metric
+            for ui, pca_metrics_unit in enumerate(results):
+                unit_id = unit_ids[ui]
+                for metric_name, metric in pca_metrics_unit.items():
+                    pc_metrics[metric_name][unit_id] = metric
 
     return pc_metrics

@@ -888,9 +885,20 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):
     return isolation
 
 
-def pca_metrics_one_unit(
-    pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, seed, we_folder, n_spikes_all_units, fr_all_units
-):
+def pca_metrics_one_unit(args):
+    (
+        pcs_flat,
+        labels,
+        metric_names,
+        unit_id,
+        unit_ids,
+        qm_params,
+        seed,
+        we_folder,
+        n_spikes_all_units,
+        fr_all_units,
+    ) = args
+
     if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
         we = load_waveforms(we_folder)
 
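The progress-bar handling above is the replacement for the deleted tqdm_joblib context manager: executor.map yields results lazily in submission order, so wrapping the iterator in tqdm advances the bar as each unit's metrics arrive (total= is required because the iterator has no length). The same idea in isolation (slow_square is a placeholder for pca_metrics_one_unit):

    from concurrent.futures import ProcessPoolExecutor

    from tqdm.auto import tqdm

    def slow_square(x):
        # placeholder worker
        return x * x

    if __name__ == "__main__":
        items = list(range(8))
        with ProcessPoolExecutor(max_workers=2) as executor:
            results = executor.map(slow_square, items)
            # draining the iterator drives both the bar and exception
            # propagation from the worker processes
            out = list(tqdm(results, total=len(items)))
        print(out)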
