diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index c7dd1b226c..c036e5020c 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -8,6 +8,9 @@ import spikeinterface.full as si # TOOD: is this a bad idea? remove! +################################################# +# Custom Recording - TODO: get feedback and move +################################################# class CustomRecording(BaseRecording): """ """ @@ -31,9 +34,25 @@ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dty "sampling_frequency": sampling_frequency, } + def set_data(self, data, segment_index=0): + + if data.shape[0] != self.get_num_samples(segment_index=segment_index): + raise ValueError("The first dimension must be the same size as " + "the number of samples.") + + if data.shape[1] != self.get_num_channels(): + raise ValueError("The second dimension of the data must be the same " + "size as the number of channels.") + + if data.dtype != self.dtype: + raise ValueError("The dtype of the data must match the recording dtype.") + + self._recording_segments[segment_index].data = data + class CustomRecordingSegment(BaseRecordingSegment): - """ """ + """ + """ def __init__(self, num_samples, num_channels, sampling_frequency): self.num_samples = num_samples @@ -44,10 +63,6 @@ def __init__(self, num_samples, num_channels, sampling_frequency): self.t_start = None self.time_vector = None - def set_data(self, data): - # TODO: do some checks - self.data = data - def get_traces( self, start_frame: int | None = None, @@ -59,18 +74,21 @@ def get_traces( def get_num_samples(self): return self.num_samples +################################################# +# Test Class +################################################# -# TODO: return random cghuns scaled vs unscled +class TestWhiten: + """ -class TestWhiten: + """ def get_float_test_data(self, 
num_segments, dtype, mean=None, covar=None): """ mention the segment thing """ num_channels = 3 - dtype = np.float32 if mean is None: mean = np.zeros(num_channels) @@ -90,7 +108,7 @@ def get_float_test_data(self, num_segments, dtype, mean=None, covar=None): else: raise ValueError("dtype must be float32 or int16") - recording._recording_segments[0].set_data(seg_1_data) + recording.set_data(seg_1_data) assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work." return mean, covar, recording @@ -117,46 +135,79 @@ def covar_from_whitening_mat(self, whitened_recording, eps): workflow to be tested rather than subfunction only. """ W = whitened_recording._kwargs["W"] + U, D, Vt = np.linalg.svd(W) D_new = (1 / D) ** 2 - eps S = U @ np.diag(D_new) @ Vt return S - @pytest.mark.parametrize("eps", [1e-8, 1]) + ################################################################################### + # Tests + ################################################################################### + @pytest.mark.parametrize("dtype", [np.float32, np.int16]) - def test_compute_covariance_matrix(self, dtype, eps): - """ """ + def test_compute_covariance_matrix(self, dtype): + """ + + """ eps = 1e-8 mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=dtype) whitened_recording = si.whiten( recording, - apply_mean=True, + apply_mean=False, regularize=False, num_chunks_per_segment=1, chunk_size=recording.get_num_samples(segment_index=0) - 1, eps=eps, + dtype=np.float32, ) - test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + if dtype == np.float32: + test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + assert np.allclose(test_covar, covar, rtol=0, atol=0.01) - assert np.allclose(test_covar, covar, rtol=0, atol=0.01) + # TODO: OWN FUNCTION + X = whitened_recording.get_traces() + X = X - np.mean(X, axis=0) + S = X.T @ X / X.shape[0] - if eps != 1: - X = whitened_recording.get_traces() - X = X 
- np.mean(X, axis=0) - S = X.T @ X / X.shape[0] + assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4) - assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4) + def test_non_default_eps(self): + """ + + """ + eps = 1 + mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + + whitened_recording = si.whiten( + recording, + apply_mean=False, + regularize=False, + num_chunks_per_segment=1, + chunk_size=recording.get_num_samples(segment_index=0) - 1, + eps=eps, + ) + test_covar = self.covar_from_whitening_mat(whitened_recording, eps) + assert np.allclose(test_covar, covar, rtol=0, atol=0.01) def test_compute_covariance_matrix_float_2_segments(self): - """ """ + """ + + """ eps = 1e-8 mean, covar, recording = self.get_float_test_data(num_segments=2, dtype=np.float32) - recording._recording_segments[1].set_data( - np.zeros((recording.get_num_samples(segment_index=0), recording.get_num_channels())) + all_zero_data = np.zeros( + (recording.get_num_samples(segment_index=0), recording.get_num_channels()), + dtype=np.float32, + ) + + recording.set_data( + all_zero_data, + segment_index=1, ) whitened_recording = si.whiten( @@ -175,7 +226,9 @@ def test_compute_covariance_matrix_float_2_segments(self): @pytest.mark.parametrize("apply_mean", [True, False]) def test_apply_mean(self, apply_mean): + """ + """ means = np.array([10, 20, 30]) eps = 1e-8 @@ -198,14 +251,16 @@ def test_apply_mean(self, apply_mean): else: assert np.allclose(np.diag(test_covar), means**2, rtol=0, atol=5) - breakpoint() # TODO: check whitened data is cov identity even when apply_mean=False... + # TODO: insert test cov is white function + # breakpoint() # TODO: check whitened data is cov identity even when apply_mean=False... def test_compute_sklearn_covariance_matrix(self): """ TODO: assume centered is fixed to True Test some random stuff - # TODO: this is not appropraite for all covariance functions. 
only one with the fit method! e.g. does not work with leodit_wolf + # TODO: this is not appropriate for all covariance functions. + only one with the fit method! e.g. does not work with ledoit_wolf """ from sklearn import covariance @@ -224,7 +279,9 @@ def test_compute_sklearn_covariance_matrix(self): assert np.allclose(test_cov, cov) def test_whiten_regularisation_norm(self): - """ """ + """ + + """ from sklearn import covariance _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) @@ -248,62 +305,123 @@ def test_whiten_regularisation_norm(self): assert np.allclose(test_covar, covar, rtol=0, atol=1e-4) + # TODO: insert test whitened recording is white + def test_local_vs_global_whiten(self): # Make test data with 4 channels, known covar between all # do with radius = 2. compute manually. Will need to change channel locations # check matches well. - pass + + _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32) + + y_dist = 2 + recording.set_channel_locations([ + [0.0, 0], + [0.0, y_dist * 1], + [0.0, y_dist * 2], + ]) + + results = {"global": None, "local": None} + + for mode in ["global", "local"]: + whitened_recording = si.whiten( + recording, + apply_mean=True, + num_chunks_per_segment=1, + chunk_size=recording.get_num_samples(segment_index=0) - 1, + eps=1e-8, + mode=mode, + radius_um=y_dist + 1e-01, + ) + results[mode] = whitened_recording + + assert results["local"]._kwargs["W"][0][2] == 0.0 + assert results["global"]._kwargs["W"][0][2] != 0.0 + + # TEST + whitened_data = results["local"].get_traces() + + set_1 = whitened_data[:, :2] - np.mean(whitened_data[:, :2], axis=0) + set_2 = whitened_data[:, 1:] - np.mean(whitened_data[:, 1:], axis=0) + + assert np.allclose( + np.eye(2), set_1.T@set_1 / set_1.shape[0], + rtol=0, atol=1e-2 + ) + assert np.allclose( + np.eye(2), set_2.T@set_2 / set_2.shape[0], + rtol=0, atol=1e-2 + ) + # TODO: own function + X = whitened_data - np.mean(whitened_data, axis=0) + covar_ = 
X.T@X / X.shape[0] + assert not np.allclose(np.eye(3), covar_, rtol=0, atol=1e-2) def test_passed_W_and_M(self): - pass + """ + TODO: Need options make clear same whitening matrix for all segments. Is this realistic? + """ + num_chan = 4 + recording = self.get_empty_custom_recording(2, num_chan, dtype=np.float32) - def test_all_kwargs(self): - pass + test_W = np.random.normal(size=(num_chan, num_chan)) + test_M = np.random.normal(size=num_chan) - def test_saved_to_disk(self): - # check attributes are saved properly - pass + whitened_recording = si.whiten( + recording, + W=test_W, + M=test_M + ) + + for seg_idx in [0, 1]: + assert np.array_equal( + whitened_recording._recording_segments[seg_idx].W, + test_W + ) + assert np.array_equal( + whitened_recording._recording_segments[seg_idx].M, + test_M + ) + assert whitened_recording._kwargs["W"] == test_W.tolist() + assert whitened_recording._kwargs["M"] == test_M.tolist() -def test_whiten(create_cache_folder): - cache_folder = create_cache_folder - rec = generate_recording(num_channels=4, seed=2205) + def test_whiten_general(self, create_cache_folder): + """ - print(rec.get_channel_locations()) - random_chunk_kwargs = {} - W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) - # print(W) - # print(M) + """ + cache_folder = create_cache_folder + rec = generate_recording(num_channels=4, seed=2205) - with pytest.raises(AssertionError): - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None) - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25) - # W must be sparse - np.sum(W == 0) == 6 + random_chunk_kwargs = {} + W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) - rec2 = whiten(rec) - rec2.save(verbose=False) + with pytest.raises(AssertionError): + W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, 
apply_mean=False, radius_um=None) + W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25) - # test dtype - rec_int = scale(rec2, dtype="int16") - rec3 = whiten(rec_int, dtype="float16") - rec3 = rec3.save(folder=cache_folder / "rec1") - assert rec3.get_dtype() == "float16" + # W must be sparse + assert np.sum(W == 0) == 6 - # test parallel - rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2) - np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0)) + rec2 = whiten(rec) + rec2.save(verbose=False) - with pytest.raises(AssertionError): - rec4 = whiten(rec_int, dtype=None) - rec4 = whiten(rec_int, dtype=None, int_scale=256) - assert rec4.get_dtype() == "int16" - assert rec4._kwargs["M"] is None + # test dtype + rec_int = scale(rec2, dtype="int16") + rec3 = whiten(rec_int, dtype="float16") + rec3 = rec3.save(folder=cache_folder / "rec1") + assert rec3.get_dtype() == "float16" - # test regularization : norm should be smaller - W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) - assert np.linalg.norm(W1) > np.linalg.norm(W2) + # test parallel + rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2) + np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0)) + with pytest.raises(AssertionError): + rec4 = whiten(rec_int, dtype=None) # int_scale should be applied + rec4 = whiten(rec_int, dtype=None, int_scale=256) + assert rec4.get_dtype() == "int16" + assert rec4._kwargs["M"] is None -# if __name__ == "__main__": -# test_compute_whitening_matrix() + # test regularization : norm should be smaller + W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) + assert np.linalg.norm(W1) > np.linalg.norm(W2) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index f430608543..f94386e3bf 
100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -213,7 +213,7 @@ def compute_whitening_matrix( assert radius_um is not None n = cov.shape[0] distances = get_channel_distances(recording) - W = np.zeros((n, n), dtype="float64") + W = np.zeros((n, n), dtype="float64") # TODO: should fix to float32 for consistency for c in range(n): (inds,) = np.nonzero(distances[c, :] <= radius_um) cov_local = cov[inds, :][:, inds] @@ -231,7 +231,7 @@ def compute_covariance_matrix( ): # TODO: check order """ """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) - random_data = random_data.astype(np.float64) + random_data = random_data.astype(np.float32) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} @@ -248,7 +248,6 @@ def compute_covariance_matrix( cov = cov / data.shape[0] else: cov = compute_sklearn_covariance_matrix(data, regularize_kwargs) - breakpoint() # import sklearn.covariance # method = regularize_kwargs.pop("method") # regularize_kwargs["assume_centered"] = True