From b2e28f77317e4d3c100486d474a3cd5b85349be9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 May 2024 10:54:19 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/tests/test_whiten.py | 4 +++- src/spikeinterface/preprocessing/whiten.py | 22 ++++++++++++------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index f8093dd25f..f3e9a8221f 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -51,7 +51,9 @@ def test_whiten(): # test regularization with pytest.raises(AssertionError): - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True) + W, M = compute_whitening_matrix( + rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True + ) # W must be sparse np.sum(W == 0) == 6 diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 5c5d167ba8..f3f0a1368b 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -77,15 +77,21 @@ def __init__( if dtype_.kind == "i": assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" - + if W is not None: W = np.asarray(W) if M is not None: M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps, regularize=regularize, - regularize_kwargs=regularize_kwargs + recording, + mode, + random_chunk_kwargs, + apply_mean, + radius_um=radius_um, + eps=eps, + regularize=regularize, + regularize_kwargs=regularize_kwargs, ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -142,8 +148,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): def compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, - regularize_kwargs=None + recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None ): """ Compute whitening matrix @@ -197,12 +202,13 @@ def compute_whitening_matrix( cov = cov / data.shape[0] else: import sklearn.covariance + if regularize_kwargs is None: regularize_kwargs = {} - regularize_kwargs['assume_centered'] = True + regularize_kwargs["assume_centered"] = True job_kwargs = get_global_job_kwargs() - if 'n_jobs' in job_kwargs and 'n_jobs' not in regularize_kwargs: - regularize_kwargs['n_jobs'] = job_kwargs['n_jobs'] + if "n_jobs" in job_kwargs and "n_jobs" not in regularize_kwargs: + regularize_kwargs["n_jobs"] = job_kwargs["n_jobs"] estimator = sklearn.covariance.GraphicalLassoCV(**regularize_kwargs) estimator.fit(data) cov = estimator.covariance_