diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index c036e5020c..4cb1c9a8d8 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -6,14 +6,36 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix from spikeinterface.preprocessing.whiten import compute_sklearn_covariance_matrix -import spikeinterface.full as si # TOOD: is this a bad idea? remove! +try: + from sklearn import covariance as sklearn_covariance + + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + ################################################# # Custom Recording - TODO: get feedback and move ################################################# + +# TODO: do we want to fix assume centered here or directly use `apply_mean`? +# TODO: ask about custom recording +# TODO: test `detect_bad_channels` should use `CustomRecording` +# TODO: +# 1) apply mean by default +# 2) apply mean argument in sklearn +# 3) some sklearn functions have a different signature +# 4) does it make sense to estimate covariance matrix from a multi-segment recording? +# # TODO: this is not appropraite for all covariance functions. +# only one with the fit method! e.g. does not work with leodit_wolf + + class CustomRecording(BaseRecording): - """ """ + """ + A convenience class for adding custom data to + a recording for test purposes. + """ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dtype): BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) @@ -35,14 +57,16 @@ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dty } def set_data(self, data, segment_index=0): - + """ + Set the `data` on on the segment of index `segment_index`. + `data` must be the same size (num_samples, num_channels) + and dtype as the reocrding. + """ if data.shape[0] != self.get_num_samples(segment_index=segment_index): - raise ValueError("The first dimension must be the same size as" - "the number of samples.") + raise ValueError("The first dimension must be the same size as" "the number of samples.") if data.shape[1] != self.get_num_channels(): - raise ValueError("The second dimension of the data be the same" - "size as the number of channels.") + raise ValueError("The second dimension of the data be the same" "size as the number of channels.") if data.dtype != self.dtype: raise ValueError("The dtype of the data must match the recording dtype.") @@ -52,6 +76,9 @@ def set_data(self, data, segment_index=0): class CustomRecordingSegment(BaseRecordingSegment): """ + Segment for the `CustomRecording` which simply returns + the set data when `get_traces()` is called. See + `CustomRecording.set_data()` for details on the set data. """ def __init__(self, num_samples, num_channels, sampling_frequency): @@ -74,56 +101,87 @@ def get_traces( def get_num_samples(self): return self.num_samples + ################################################# # Test Class ################################################# + class TestWhiten: """ + Test the whitening preprocessing step. - + The strategy is to generate a recording that has data + with a known covariance matrix, then testing that the + covariance matrix is computed properly and that the + returned data is indeed white. """ - def get_float_test_data(self, num_segments, dtype, mean=None, covar=None): + def get_float_test_data(self, num_segments, dtype, means=None): """ - mention the segment thing + Generate a set of test data with known covariance matrix and mean. + Test data is drawn from a multivariate Gaussian distribute with + means `mean` and covariance matrix `cov_mat`. + + A fixture is not used because we often want to change the options, + and it is very quick to generate this test data. + + The number of channels (3) and covariance matrix is fixed + and directly tested against in below tests. + + Parameters + ---------- + + num_segments : int + Number of segments for the recording. Note that only the first + segment is filled with data. Data for other segments must be + set manually. + + dtype : np.float32 | np.int16 + Datatype of the generated recording. + + means : None | np.ndarray + The `means` should be an array of length 3 (num samples) + or `None`. If `None`, means will be zero. """ num_channels = 3 - if mean is None: - mean = np.zeros(num_channels) + if means is None: + means = np.zeros(num_channels) - if covar is None: - covar = np.array([[1, 0.5, 0], [0.5, 1, -0.25], [0, -0.25, 1]]) + cov_mat = np.array([[1, 0.5, 0], [0.5, 1, -0.25], [0, -0.25, 1]]) + # Generate recording and multivariate Gaussian data to set recording = self.get_empty_custom_recording(num_segments, num_channels, dtype) - seg_1_data = np.random.multivariate_normal( - mean, covar, recording.get_num_samples(segment_index=0) # TODO: RENAME! - ) + seg_1_data = np.random.multivariate_normal(means, cov_mat, recording.get_num_samples(segment_index=0)) + + # Set the dtype, if `int16`, first scale to +/- 1 then cast to int16 range. if dtype == np.float32: seg_1_data = seg_1_data.astype(dtype) + elif dtype == np.int16: - seg_1_data = np.round(seg_1_data * 32767).astype(np.int16) + seg_1_data /= seg_1_data.max() + seg_1_data = np.round((seg_1_data) * 32767).astype(np.int16) else: raise ValueError("dtype must be float32 or int16") + # Set the data on the recording and return recording.set_data(seg_1_data) assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work." - return mean, covar, recording + return means, cov_mat, recording def get_empty_custom_recording(self, num_segments, num_channels, dtype): - """ """ return CustomRecording( - durations=[10, 10] if num_segments == 2 else [10], # will auto-fill zeros + durations=[10 for _ in range(num_segments)], num_channels=num_channels, channel_ids=np.arange(num_channels), sampling_frequency=30000, dtype=dtype, ) - def covar_from_whitening_mat(self, whitened_recording, eps): + def cov_mat_from_whitening_mat(self, whitened_recording, eps): """ The whitening matrix is computed as the inverse square root of the covariance matrix @@ -142,6 +200,26 @@ def covar_from_whitening_mat(self, whitened_recording, eps): return S + def assert_recording_is_white(self, recording): + """ + Compute the covariance matrix of the recording, + and assert that it is close to identity. + """ + X = recording.get_traces() + S = self.compute_cov_mat(X) + + assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4) + + def compute_cov_mat(self, X): + """ + Estimate the covariance matrix from data + using the standard linear algebra approach. + """ + X = X - np.mean(X, axis=0) + S = X.T @ X / X.shape[0] + + return S + ################################################################################### # Tests ################################################################################### @@ -149,12 +227,16 @@ def covar_from_whitening_mat(self, whitened_recording, eps): @pytest.mark.parametrize("dtype", [np.float32, np.int16]) def test_compute_covariance_matrix(self, dtype): """ - + Test that the covariance matrix is computed as expected and + data is white after whitening step. Test against float32 and + int16, testing int16 is important to ensure data + is cast to float before computing the covariance matrix, + otherwise it can overflow. """ eps = 1e-8 - mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=dtype) + _, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=dtype) - whitened_recording = si.whiten( + whitened_recording = whiten( recording, apply_mean=False, regularize=False, @@ -164,25 +246,27 @@ def test_compute_covariance_matrix(self, dtype): dtype=np.float32, ) - if dtype == np.float32: - test_covar = self.covar_from_whitening_mat(whitened_recording, eps) - assert np.allclose(test_covar, covar, rtol=0, atol=0.01) + test_cov_mat = self.cov_mat_from_whitening_mat(whitened_recording, eps) - # TODO: OWN FUNCTION - X = whitened_recording.get_traces() - X = X - np.mean(X, axis=0) - S = X.T @ X / X.shape[0] + # If the data is in `int16` the covariance matrix will be scaled up + # as data is set to +/32767 range before cast. + if dtype == np.int16: + test_cov_mat /= test_cov_mat[0][0] + assert np.allclose(test_cov_mat, cov_mat, rtol=0, atol=0.01) - assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4) + self.assert_recording_is_white(whitened_recording) def test_non_default_eps(self): """ - + Try a new non-default eps and check that it is correctly + propagated to the matrix computation. The test is that + the `cov_mat_from_whitening_mat` should recovery exctly + the cov mat if the correct eps is used. """ eps = 1 - mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + _, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) - whitened_recording = si.whiten( + whitened_recording = whiten( recording, apply_mean=False, regularize=False, @@ -190,15 +274,20 @@ def test_non_default_eps(self): chunk_size=recording.get_num_samples(segment_index=0) - 1, eps=eps, ) - test_covar = self.covar_from_whitening_mat(whitened_recording, eps) - assert np.allclose(test_covar, covar, rtol=0, atol=0.01) + test_cov_mat = self.cov_mat_from_whitening_mat(whitened_recording, eps) + assert np.allclose(test_cov_mat, cov_mat, rtol=0, atol=0.01) - def test_compute_covariance_matrix_float_2_segments(self): + def test_compute_covariance_matrix_2_segments(self): """ - + Check that the covariance marix is estimated from across + segments in a multi-segment recording. This is done simply + by setting the second segment as all zeros and checking the + estimated covariances are all halved. This makes sense as + the zeros do not affect the covariance estimation + but the covariance matrix is scaled by 1 / N. """ eps = 1e-8 - mean, covar, recording = self.get_float_test_data(num_segments=2, dtype=np.float32) + _, cov_mat, recording = self.get_float_test_data(num_segments=2, dtype=np.float32) all_zero_data = np.zeros( (recording.get_num_samples(segment_index=0), recording.get_num_channels()), @@ -210,7 +299,7 @@ def test_compute_covariance_matrix_float_2_segments(self): segment_index=1, ) - whitened_recording = si.whiten( + whitened_recording = whiten( recording, apply_mean=True, regularize=False, @@ -220,21 +309,25 @@ def test_compute_covariance_matrix_float_2_segments(self): eps=eps, ) - test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + test_cov_mat = self.cov_mat_from_whitening_mat(whitened_recording, eps) - assert np.allclose(test_covar, covar / 2, rtol=0, atol=0.01) + assert np.allclose(test_cov_mat, cov_mat / 2, rtol=0, atol=0.01) @pytest.mark.parametrize("apply_mean", [True, False]) def test_apply_mean(self, apply_mean): """ - + Test that the `apply_mean` argument is propagated correctly. + Note that in the case `apply_mean=False`, the covariance matrix + is in unusual scaling and so the varainces alone are checked. + Also, the data is not as well whitened and so this is not + tested against. """ means = np.array([10, 20, 30]) eps = 1e-8 - mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=np.float32, mean=means) + _, cov_mat, recording = self.get_float_test_data(num_segments=1, dtype=np.float32, means=means) - whitened_recording = si.whiten( + whitened_recording = whiten( recording, apply_mean=apply_mean, regularize=False, @@ -244,49 +337,53 @@ def test_apply_mean(self, apply_mean): eps=eps, ) - test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + test_cov_mat = self.cov_mat_from_whitening_mat(whitened_recording, eps) if apply_mean: - assert np.allclose(test_covar, covar, rtol=0, atol=0.01) + assert np.allclose(test_cov_mat, cov_mat, rtol=0, atol=0.01) else: - assert np.allclose(np.diag(test_covar), means**2, rtol=0, atol=5) + assert np.allclose(np.diag(test_cov_mat), means**2, rtol=0, atol=5) - # TODO: insert test cov is white function - # breakpoint() # TODO: check whitened data is cov identity even when apply_mean=False... + # Note the recording is typically not white if the mean is + # not removed prior to covariance matrix estimation. + if apply_mean: + self.assert_recording_is_white(whitened_recording) + @pytest.mark.skipif(not HAS_SKLEARN, reason="sklearn must be installed.") def test_compute_sklearn_covariance_matrix(self): """ - TODO: assume centered is fixed to True - Test some random stuff + A basic check that the `compute_sklearn_covariance_matrix` + function from `whiten.py` computes the same matrix + as using the sklearn function directly for some + arbitraily chosen methods / parameters. - # TODO: this is not appropraite for all covariance functions. - only one with the fit method! e.g. does not work with leodit_wolf + Note that the function-style sklearn covariance + methods are not supported. """ - from sklearn import covariance - X = np.random.randn(100, 100) test_cov = compute_sklearn_covariance_matrix( X, {"method": "GraphicalLasso", "alpha": 1, "enet_tol": 0.04} ) # RENAME test_cov - cov = covariance.GraphicalLasso(alpha=1, enet_tol=0.04, assume_centered=True).fit(X).covariance_ + cov = sklearn_covariance.GraphicalLasso(alpha=1, enet_tol=0.04, assume_centered=True).fit(X).covariance_ assert np.allclose(test_cov, cov) test_cov = compute_sklearn_covariance_matrix( X, {"method": "ShrunkCovariance", "shrinkage": 0.3} ) # RENAME test_cov - cov = covariance.ShrunkCovariance(shrinkage=0.3, assume_centered=True).fit(X).covariance_ + cov = sklearn_covariance.ShrunkCovariance(shrinkage=0.3, assume_centered=True).fit(X).covariance_ assert np.allclose(test_cov, cov) + @pytest.mark.skipif(not HAS_SKLEARN, reason="sklearn must be installed.") def test_whiten_regularisation_norm(self): """ - + Check that the covariance matrix estimated by the + whitening preprocessing is the same as the one + computed from sklearn when regularise kwargs are given. """ - from sklearn import covariance - _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) - whitened_recording = si.whiten( + whitened_recording = whiten( recording, regularize=True, regularize_kwargs={"method": "ShrunkCovariance", "shrinkage": 0.3}, @@ -296,35 +393,37 @@ def test_whiten_regularisation_norm(self): eps=1e-8, ) - test_covar = self.covar_from_whitening_mat(whitened_recording, eps=1e-8) + test_cov_mat = self.cov_mat_from_whitening_mat(whitened_recording, eps=1e-8) + # Compute covariance matrix using sklearn directly and compare. X = recording.get_traces()[:-1, :] X = X - np.mean(X, axis=0) + cov_mat = sklearn_covariance.ShrunkCovariance(shrinkage=0.3, assume_centered=True).fit(X).covariance_ - covar = covariance.ShrunkCovariance(shrinkage=0.3, assume_centered=True).fit(X).covariance_ - - assert np.allclose(test_covar, covar, rtol=0, atol=1e-4) - - # TODO: insert test whitened recording is white + assert np.allclose(test_cov_mat, cov_mat, rtol=0, atol=1e-4) def test_local_vs_global_whiten(self): - # Make test data with 4 channels, known covar between all - # do with radius = 2. compute manually. Will need to change channel locations - # check matches well. - + """ + Generate a set of channels each separated by y_dist. Set the + radius_um to just above y_dist such that only neighbouring + channels are considered for whitening. Test that whitening + is correct for the first pair and last pair. + """ _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) y_dist = 2 - recording.set_channel_locations([ - [0.0, 0], - [0.0, y_dist * 1], - [0.0, y_dist * 2], - ]) + recording.set_channel_locations( + [ + [0.0, 0], + [0.0, y_dist * 1], + [0.0, y_dist * 2], + ] + ) results = {"global": None, "local": None} for mode in ["global", "local"]: - whitened_recording = si.whiten( + whitened_recording = whiten( recording, apply_mean=True, num_chunks_per_segment=1, @@ -338,28 +437,27 @@ def test_local_vs_global_whiten(self): assert results["local"]._kwargs["W"][0][2] == 0.0 assert results["global"]._kwargs["W"][0][2] != 0.0 - # TEST - whitened_data = results["local"].get_traces() + # Parse out the data into two pairs of channels + # from which the local variance was computed. + whitened_data = results["local"].get_traces() set_1 = whitened_data[:, :2] - np.mean(whitened_data[:, :2], axis=0) set_2 = whitened_data[:, 1:] - np.mean(whitened_data[:, 1:], axis=0) - assert np.allclose( - np.eye(2), set_1.T@set_1 / set_1.shape[0], - rtol=0, atol=1e-2 - ) - assert np.allclose( - np.eye(2), set_2.T@set_2 / set_2.shape[0], - rtol=0, atol=1e-2 - ) - # TODO: own function - X = whitened_data - np.mean(whitened_data, axis=0) - covar_ = X.T@X - X.shape[0] - assert not np.allclose(np.eye(3), covar_, rtol=0, atol=1e-2) + # Check that the pairs of close channels + # whitened together are white + assert np.allclose(np.eye(2), self.compute_cov_mat(set_1), rtol=0, atol=1e-2) + assert np.allclose(np.eye(2), self.compute_cov_mat(set_2), rtol=0, atol=1e-2) + + # Check that the data overall is not white + assert not np.allclose(np.eye(3), self.compute_cov_mat(whitened_data), rtol=0, atol=1e-2) def test_passed_W_and_M(self): """ - TODO: Need options make clear same whitening matrix for all segments. Is this realistic? + Check that passing W (whitening matrix) and M (means) is + sucessfully propagated to the relevant segments and stored + on the kwargs. It is assumed if this is true, they will + be used for the actual whitening computation. """ num_chan = 4 recording = self.get_empty_custom_recording(2, num_chan, dtype=np.float32) @@ -367,28 +465,23 @@ def test_passed_W_and_M(self): test_W = np.random.normal(size=(num_chan, num_chan)) test_M = np.random.normal(size=num_chan) - whitened_recording = si.whiten( - recording, - W=test_W, - M=test_M - ) + whitened_recording = whiten(recording, W=test_W, M=test_M) for seg_idx in [0, 1]: - assert np.array_equal( - whitened_recording._recording_segments[seg_idx].W, - test_W - ) - assert np.array_equal( - whitened_recording._recording_segments[seg_idx].M, - test_M - ) + assert np.array_equal(whitened_recording._recording_segments[seg_idx].W, test_W) + assert np.array_equal(whitened_recording._recording_segments[seg_idx].M, test_M) assert whitened_recording._kwargs["W"] == test_W.tolist() assert whitened_recording._kwargs["M"] == test_M.tolist() def test_whiten_general(self, create_cache_folder): """ + Perform some general tests on the whitening functionality. + First, perform smoke test that `compute_whitening_matrix` is running, + check recording output datatypes are as expected. Check that + saving preseves datatype, `int_scale` is propagated, and + regularisation reduces the norm. """ cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) @@ -423,5 +516,6 @@ def test_whiten_general(self, create_cache_folder): assert rec4._kwargs["M"] is None # test regularization : norm should be smaller - W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) - assert np.linalg.norm(W1) > np.linalg.norm(W2) + if HAS_SKLEARN: + W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) + assert np.linalg.norm(W1) > np.linalg.norm(W2) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index f94386e3bf..153407422a 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -47,7 +47,8 @@ class WhitenRecording(BasePreprocessor): of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV regularize_kwargs : {'method' : 'GraphicalLassoCV'} Dictionary of the parameters that could be provided to the method of sklearn, if - the covariance matrix needs to be regularized. + the covariance matrix needs to be regularized. Note that sklearn covariance methods + that are implemented as functions, not classes, are not supported. **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns @@ -202,23 +203,21 @@ def compute_whitening_matrix( if data.dtype.kind == "f": median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square if median_data_sqr < 1 and median_data_sqr > 0: - if eps is None: + if eps is None: # TODO: I dont think this will ever be triggered because if eps is None is set above. eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data if mode == "global": - U, S, Ut = np.linalg.svd(cov, full_matrices=True) - W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut + W = compute_whitening_from_covariance(cov, eps) # TODO: carefully check elif mode == "local": assert radius_um is not None n = cov.shape[0] distances = get_channel_distances(recording) - W = np.zeros((n, n), dtype="float64") # TODO: should fix to float32 for consistency + W = np.zeros((n, n), dtype="float32") for c in range(n): (inds,) = np.nonzero(distances[c, :] <= radius_um) cov_local = cov[inds, :][:, inds] - U, S, Ut = np.linalg.svd(cov_local, full_matrices=True) - W_local = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut + W_local = compute_whitening_from_covariance(cov_local, eps) # TODO: carefully check W[inds, c] = W_local[c == inds] else: raise ValueError(f"compute_whitening_matrix : wrong mode {mode}") @@ -226,10 +225,22 @@ def compute_whitening_matrix( return W, M -def compute_covariance_matrix( - recording, apply_mean, regularize, regularize_kwargs, random_chunk_kwargs -): # TODO: check order - """ """ +def compute_whitening_from_covariance(cov, eps): + """ + Compute the whitening matrix from the covariance + matrix using ZCA whitening approach. Note the `eps` + ensures division by zero is not possible and regularises. + """ + U, S, Ut = np.linalg.svd(cov, full_matrices=True) + W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut + return W + + +def compute_covariance_matrix(recording, apply_mean, regularize, regularize_kwargs, random_chunk_kwargs): + """ + Compute the covariance matrix from randomly sampled data chunsk. + See `compute_whitening_matrix()` for parameters. + """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) random_data = random_data.astype(np.float32) @@ -248,22 +259,17 @@ def compute_covariance_matrix( cov = cov / data.shape[0] else: cov = compute_sklearn_covariance_matrix(data, regularize_kwargs) - # import sklearn.covariance - # method = regularize_kwargs.pop("method") - # regularize_kwargs["assume_centered"] = True - # estimator_class = getattr(sklearn.covariance, method) - # estimator = estimator_class(**regularize_kwargs) - # estimator.fit(data) - # cov = estimator.covariance_ - - return data, cov, M # TODO: rename data - -# TODO: do we want to fix assume centered here or directly use `apply_mean`? + return data, cov, M def compute_sklearn_covariance_matrix(data, regularize_kwargs): + """ + Estimate the covariance matrix using scikit-learn functions. + Note that sklearn covariance methods that are implemented + as functions, not classes, are not supported. + """ import sklearn.covariance if "assume_centered" in regularize_kwargs and not regularize_kwargs["assume_centered"]: @@ -279,10 +285,5 @@ def compute_sklearn_covariance_matrix(data, regularize_kwargs): return cov -# 1) factor out to own function -# 2) test function is simple way -# 3) monkeypatch compute covariance to check function is called and returned -# 4) check norm is smaller - # function for API whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")