diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 8ddd31e75d..c7dd1b226c 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -4,15 +4,14 @@ from spikeinterface.core import generate_recording from spikeinterface.core import BaseRecording, BaseRecordingSegment from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix -from spikeinterface.preprocessing.whiten import compute_covariance_matrix +from spikeinterface.preprocessing.whiten import compute_sklearn_covariance_matrix import spikeinterface.full as si # TOOD: is this a bad idea? remove! class CustomRecording(BaseRecording): - """ + """ """ - """ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dtype): BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) @@ -32,14 +31,10 @@ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dty "sampling_frequency": sampling_frequency, } -# TODO -# 1) save covariance matrix -# 2) - class CustomRecordingSegment(BaseRecordingSegment): - """ - """ + """ """ + def __init__(self, num_samples, num_channels, sampling_frequency): self.num_samples = num_samples self.num_channels = num_channels @@ -67,94 +62,208 @@ def get_num_samples(self): # TODO: return random cghuns scaled vs unscled -@pytest.mark.parametrize("eps", [1e-8, 1]) -@pytest.mark.parametrize("num_segments", [1, 2]) -@pytest.mark.parametrize("dtype", [np.float32]) # np.int16 -def test_compute_whitening_matrix(eps, num_segments, dtype): - """ - """ - num_channels = 3 - - recording = CustomRecording( - durations=[10, 10] if num_segments == 2 else [10], # will auto-fill zeros - num_channels=num_channels, - channel_ids=np.arange(num_channels), - sampling_frequency=30000, - dtype=dtype - ) - num_samples = recording.get_num_samples(segment_index=0) - - 
# 1) setup the data with known mean and covariance. - mean_1 = mean_2 = np.zeros(num_channels) # TODO: diferent tests - # mean_2 = np.arange(num_channels) - - # Covariances for simulated data. Limit off-diagonals larger than variances - # for realism + stability / PSD. Actually, a better way is to just get - # some random data and compute x.T@x# - cov_1 = np.array( - [[1, 0.5, 0], - [0.5, 1, -0.25], - [0, -0.25, 1]] - ) - seg_1_data = np.random.multivariate_normal(mean_1, cov_1, recording.get_num_samples(segment_index=0)) - seg_1_data = seg_1_data.astype(dtype) - - recording._recording_segments[0].set_data(seg_1_data) - assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work." - - if num_segments == 2: + +class TestWhiten: + + def get_float_test_data(self, num_segments, dtype, mean=None, covar=None): + """ + mention the segment thing + """ + num_channels = 3 + dtype = np.float32 + + if mean is None: + mean = np.zeros(num_channels) + + if covar is None: + covar = np.array([[1, 0.5, 0], [0.5, 1, -0.25], [0, -0.25, 1]]) + + recording = self.get_empty_custom_recording(num_segments, num_channels, dtype) + + seg_1_data = np.random.multivariate_normal( + mean, covar, recording.get_num_samples(segment_index=0) # TODO: RENAME! + ) + if dtype == np.float32: + seg_1_data = seg_1_data.astype(dtype) + elif dtype == np.int16: + seg_1_data = np.round(seg_1_data * 32767).astype(np.int16) + else: + raise ValueError("dtype must be float32 or int16") + + recording._recording_segments[0].set_data(seg_1_data) + assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work." 
+ + return mean, covar, recording + + def get_empty_custom_recording(self, num_segments, num_channels, dtype): + """ """ + return CustomRecording( + durations=[10, 10] if num_segments == 2 else [10], # will auto-fill zeros + num_channels=num_channels, + channel_ids=np.arange(num_channels), + sampling_frequency=30000, + dtype=dtype, + ) + + def covar_from_whitening_mat(self, whitened_recording, eps): + """ + The whitening matrix is computed as the + inverse square root of the covariance matrix + (Sigma, 'S' below + some eps for regularising. + + Here the inverse process is performed to compute + the covariance matrix from the whitening matrix + for testing purposes. This allows the entire + workflow to be tested rather than subfunction only. + """ + W = whitened_recording._kwargs["W"] + U, D, Vt = np.linalg.svd(W) + D_new = (1 / D) ** 2 - eps + S = U @ np.diag(D_new) @ Vt + + return S + + @pytest.mark.parametrize("eps", [1e-8, 1]) + @pytest.mark.parametrize("dtype", [np.float32, np.int16]) + def test_compute_covariance_matrix(self, dtype, eps): + """ """ + eps = 1e-8 + mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=dtype) + + whitened_recording = si.whiten( + recording, + apply_mean=True, + regularize=False, + num_chunks_per_segment=1, + chunk_size=recording.get_num_samples(segment_index=0) - 1, + eps=eps, + ) + + test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + + assert np.allclose(test_covar, covar, rtol=0, atol=0.01) + + if eps != 1: + X = whitened_recording.get_traces() + X = X - np.mean(X, axis=0) + S = X.T @ X / X.shape[0] + + assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4) + + def test_compute_covariance_matrix_float_2_segments(self): + """ """ + eps = 1e-8 + mean, covar, recording = self.get_float_test_data(num_segments=2, dtype=np.float32) + recording._recording_segments[1].set_data( - np.zeros((num_samples, num_channels)) + np.zeros((recording.get_num_samples(segment_index=0), 
recording.get_num_channels())) ) - _, test_cov, test_mean = compute_covariance_matrix( - recording, - apply_mean=True, - regularize=False, - regularize_kwargs={}, - random_chunk_kwargs=dict( + whitened_recording = si.whiten( + recording, + apply_mean=True, + regularize=False, + regularize_kwargs={}, num_chunks_per_segment=1, - chunk_size=recording.get_num_samples(segment_index=0)-1, + chunk_size=recording.get_num_samples(segment_index=0) - 1, + eps=eps, ) - ) - - if num_segments == 1: - assert np.allclose(test_cov, cov_1, rtol=0, atol=0.01) - else: - assert np.allclose(test_cov, cov_1 / 2, rtol=0, atol=0.01) - - # test_cov - # TOOD: own test for mean - - whitened_recording = si.whiten( - recording, - apply_mean=True, - regularize=False, - regularize_kwargs={}, - num_chunks_per_segment=1, - chunk_size=recording.get_num_samples(segment_index=0) - 1, - eps=eps - ) - - W = whitened_recording._kwargs["W"] - U, S, Vt = np.linalg.svd(W) - S_ = (1 / S) ** 2 - eps - P = U @ np.diag(S_) @ Vt - - if num_segments == 1: - assert np.allclose(P, cov_1, rtol=0, atol=0.01) - else: - assert np.allclose(P, cov_1 / 2, rtol=0, atol=0.01) - - # TODO: - # 1) test int16, MVN is not going to work. Completely new test that just tests against X.T@T/n - # 2) test apply mean on / off and means - # 3) make clear eps is tested above - # 4) test regularisation (use existing approach). Maybe test directly against sklearn function - # 5 )test local vs. global - # 6) monkeypatch regularisation and random kwargs to check they are passed correctly. 
- # 7) test radius and int scale in the simplest way - # 8) test W, M are saved correctly + + test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + + assert np.allclose(test_covar, covar / 2, rtol=0, atol=0.01) + + @pytest.mark.parametrize("apply_mean", [True, False]) + def test_apply_mean(self, apply_mean): + + means = np.array([10, 20, 30]) + + eps = 1e-8 + mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=np.float32, mean=means) + + whitened_recording = si.whiten( + recording, + apply_mean=apply_mean, + regularize=False, + regularize_kwargs={}, + num_chunks_per_segment=1, + chunk_size=recording.get_num_samples(segment_index=0) - 1, + eps=eps, + ) + + test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + + if apply_mean: + assert np.allclose(test_covar, covar, rtol=0, atol=0.01) + else: + assert np.allclose(np.diag(test_covar), means**2, rtol=0, atol=5) + + # TODO: check whitened data is cov identity even when apply_mean=False... + + def test_compute_sklearn_covariance_matrix(self): + """ + TODO: assume centered is fixed to True + Test some random stuff + + # TODO: this is not appropriate for all covariance functions. only one with the fit method! e.g. 
does not work with ledoit_wolf + """ + from sklearn import covariance + + X = np.random.randn(100, 100) + + test_cov = compute_sklearn_covariance_matrix( + X, {"method": "GraphicalLasso", "alpha": 1, "enet_tol": 0.04} + ) # RENAME test_cov + cov = covariance.GraphicalLasso(alpha=1, enet_tol=0.04, assume_centered=True).fit(X).covariance_ + assert np.allclose(test_cov, cov) + + test_cov = compute_sklearn_covariance_matrix( + X, {"method": "ShrunkCovariance", "shrinkage": 0.3} + ) # RENAME test_cov + cov = covariance.ShrunkCovariance(shrinkage=0.3, assume_centered=True).fit(X).covariance_ + assert np.allclose(test_cov, cov) + + def test_whiten_regularisation_norm(self): + """ """ + from sklearn import covariance + + _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + + whitened_recording = si.whiten( + recording, + regularize=True, + regularize_kwargs={"method": "ShrunkCovariance", "shrinkage": 0.3}, + apply_mean=True, + num_chunks_per_segment=1, + chunk_size=recording.get_num_samples(segment_index=0) - 1, + eps=1e-8, + ) + + test_covar = self.covar_from_whitening_mat(whitened_recording, eps=1e-8) + + X = recording.get_traces()[:-1, :] + X = X - np.mean(X, axis=0) + + covar = covariance.ShrunkCovariance(shrinkage=0.3, assume_centered=True).fit(X).covariance_ + + assert np.allclose(test_covar, covar, rtol=0, atol=1e-4) + + def test_local_vs_global_whiten(self): + # Make test data with 4 channels, known covar between all + # do with radius = 2. compute manually. Will need to change channel locations + # check matches well. 
+ pass + + def test_passed_W_and_M(self): + pass + + def test_all_kwargs(self): + pass + + def test_saved_to_disk(self): + # check attributes are saved properly + pass + def test_whiten(create_cache_folder): cache_folder = create_cache_folder @@ -196,5 +305,5 @@ def test_whiten(create_cache_folder): assert np.linalg.norm(W1) > np.linalg.norm(W2) -#if __name__ == "__main__": - # test_compute_whitening_matrix() +# if __name__ == "__main__": +# test_compute_whitening_matrix() diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 1db5b3afbe..f430608543 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -7,7 +7,6 @@ from ..core import get_random_data_chunks, get_channel_distances from .filter import fix_dtype -from ..core.globals import get_global_job_kwargs class WhitenRecording(BasePreprocessor): @@ -76,8 +75,12 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, ("For recording with dtype=int you must set the output dtype to float " - " OR set a int_scale") + assert int_scale is not None, ( + "For recording with dtype=int you must set the output dtype to float " " OR set a int_scale" + ) + + if not apply_mean and regularize: + raise ValueError("`apply_mean` must be `True` if regularising. 
`assume_centered` is fixed to `True`.") if W is not None: W = np.asarray(W) @@ -144,10 +147,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): return whiten_traces.astype(self.dtype) -# function for API -whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") - - def compute_whitening_matrix( recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None ): @@ -187,9 +186,7 @@ def compute_whitening_matrix( The "mean" matrix """ - data, cov, M = compute_covariance_matrix( - recording, apply_mean, regularize, regularize_kwargs, random_chunk_kwargs - ) + data, cov, M = compute_covariance_matrix(recording, apply_mean, regularize, regularize_kwargs, random_chunk_kwargs) # Here we determine eps used below to avoid division by zero. # Typically we can assume that data is either unscaled integers or in units of @@ -229,9 +226,10 @@ def compute_whitening_matrix( return W, M -def compute_covariance_matrix(recording, apply_mean, regularize, regularize_kwargs, random_chunk_kwargs): # TODO: check order - """ - """ +def compute_covariance_matrix( + recording, apply_mean, regularize, regularize_kwargs, random_chunk_kwargs +): # TODO: check order + """ """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) random_data = random_data.astype(np.float64) @@ -249,16 +247,43 @@ def compute_covariance_matrix(recording, apply_mean, regularize, regularize_kwar cov = data.T @ data cov = cov / data.shape[0] else: - import sklearn.covariance - - method = regularize_kwargs.pop("method") - regularize_kwargs["assume_centered"] = True - estimator_class = getattr(sklearn.covariance, method) - estimator = estimator_class(**regularize_kwargs) - estimator.fit(data) - cov = estimator.covariance_ + cov = compute_sklearn_covariance_matrix(data, regularize_kwargs) + # NOTE: removed leftover debug breakpoint() here + # import sklearn.covariance + # method = 
regularize_kwargs.pop("method") + # regularize_kwargs["assume_centered"] = True + # estimator_class = getattr(sklearn.covariance, method) + # estimator = estimator_class(**regularize_kwargs) + # estimator.fit(data) + # cov = estimator.covariance_ return data, cov, M # TODO: rename data +# TODO: do we want to fix assume centered here or directly use `apply_mean`? + +def compute_sklearn_covariance_matrix(data, regularize_kwargs): + + import sklearn.covariance + + if "assume_centered" in regularize_kwargs and not regularize_kwargs["assume_centered"]: + raise ValueError("Cannot use `assume_centered=False` for `regularize_kwargs`. " "Fixing to `True`.") + + method = regularize_kwargs.pop("method") + regularize_kwargs["assume_centered"] = True + estimator_class = getattr(sklearn.covariance, method) + estimator = estimator_class(**regularize_kwargs) + estimator.fit(data) + cov = estimator.covariance_ + + return cov + + +# 1) factor out to own function +# 2) test function is simple way +# 3) monkeypatch compute covariance to check function is called and returned +# 4) check norm is smaller + +# function for API +whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")