Skip to content

Commit

Permalink
Continue expanding tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 7, 2024
1 parent 915d744 commit a625cbf
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 115 deletions.
297 changes: 203 additions & 94 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from spikeinterface.core import generate_recording
from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix
from spikeinterface.preprocessing.whiten import compute_covariance_matrix
from spikeinterface.preprocessing.whiten import compute_sklearn_covariance_matrix

import spikeinterface.full as si # TODO: is this a bad idea? remove!


class CustomRecording(BaseRecording):
"""
""" """

"""
def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dtype):
BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype)

Expand All @@ -32,14 +31,10 @@ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dty
"sampling_frequency": sampling_frequency,
}

# TODO
# 1) save covariance matrix
# 2)


class CustomRecordingSegment(BaseRecordingSegment):
"""
"""
""" """

def __init__(self, num_samples, num_channels, sampling_frequency):
self.num_samples = num_samples
self.num_channels = num_channels
Expand Down Expand Up @@ -67,94 +62,208 @@ def get_num_samples(self):

# TODO: return random chunks scaled vs unscaled

@pytest.mark.parametrize("eps", [1e-8, 1])
@pytest.mark.parametrize("num_segments", [1, 2])
@pytest.mark.parametrize("dtype", [np.float32]) # np.int16
def test_compute_whitening_matrix(eps, num_segments, dtype):
"""
"""
num_channels = 3

recording = CustomRecording(
durations=[10, 10] if num_segments == 2 else [10], # will auto-fill zeros
num_channels=num_channels,
channel_ids=np.arange(num_channels),
sampling_frequency=30000,
dtype=dtype
)
num_samples = recording.get_num_samples(segment_index=0)

# 1) setup the data with known mean and covariance.
mean_1 = mean_2 = np.zeros(num_channels) # TODO: diferent tests
# mean_2 = np.arange(num_channels)

# Covariances for simulated data. Keep the off-diagonals no larger than the
# variances, for realism and stability (positive semi-definiteness). A better
# way may be to just get some random data and compute x.T @ x.
cov_1 = np.array(
[[1, 0.5, 0],
[0.5, 1, -0.25],
[0, -0.25, 1]]
)
seg_1_data = np.random.multivariate_normal(mean_1, cov_1, recording.get_num_samples(segment_index=0))
seg_1_data = seg_1_data.astype(dtype)

recording._recording_segments[0].set_data(seg_1_data)
assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work."

if num_segments == 2:

class TestWhiten:

def get_float_test_data(self, num_segments, dtype, mean=None, covar=None):
    """
    Build a 3-channel test recording whose first segment contains samples
    drawn from a multivariate normal distribution with known statistics.

    Only segment 0 is filled here. Any further segment is left as created
    by `get_empty_custom_recording` (presumably zero-filled, per the
    comment there -- confirm against `CustomRecording`).

    Parameters
    ----------
    num_segments : int
        Forwarded to `get_empty_custom_recording`.
    dtype : numpy dtype
        Requested data type. NOTE(review): this argument is currently
        overwritten with `np.float32` below, so the int16 branch is never
        taken.
    mean : array-like or None
        Per-channel means; defaults to zeros.
    covar : array-like or None
        Channel covariance matrix; defaults to a fixed 3 x 3 matrix.

    Returns
    -------
    tuple
        `(mean, covar, recording)` -- the distribution parameters that
        were used and the recording holding the sampled data.

    Raises
    ------
    ValueError
        If `dtype` is neither float32 nor int16.
    """
    num_channels = 3
    # NOTE(review): the caller-supplied dtype is discarded here.
    dtype = np.float32

    mean = np.zeros(num_channels) if mean is None else mean
    covar = np.array([[1, 0.5, 0], [0.5, 1, -0.25], [0, -0.25, 1]]) if covar is None else covar

    recording = self.get_empty_custom_recording(num_segments, num_channels, dtype)

    num_samples = recording.get_num_samples(segment_index=0)
    samples = np.random.multivariate_normal(mean, covar, num_samples)

    if dtype == np.int16:
        first_segment_data = np.round(samples * 32767).astype(np.int16)
    elif dtype == np.float32:
        first_segment_data = samples.astype(dtype)
    else:
        raise ValueError("dtype must be float32 or int16")

    first_segment = recording._recording_segments[0]
    first_segment.set_data(first_segment_data)

    # Sanity-check that the recording really serves the data just set.
    traces = recording.get_traces(segment_index=0)
    assert np.array_equal(traces, first_segment_data), "segment 1 test setup did not work."

    return mean, covar, recording

def get_empty_custom_recording(self, num_segments, num_channels, dtype):
    """
    Create a `CustomRecording` with one or two segments and no data set.

    Parameters
    ----------
    num_segments : int
        If exactly 2, two segments are created; any other value gives one.
    num_channels : int
        Number of channels; channel ids are `0 .. num_channels - 1`.
    dtype : numpy dtype
        Data type passed through to the recording.

    Returns
    -------
    CustomRecording
        Recording sampled at 30000 with a duration value of 10 per segment.
    """
    # will auto-fill zeros
    durations = [10, 10] if num_segments == 2 else [10]

    recording_kwargs = dict(
        durations=durations,
        num_channels=num_channels,
        channel_ids=np.arange(num_channels),
        sampling_frequency=30000,
        dtype=dtype,
    )
    return CustomRecording(**recording_kwargs)

def covar_from_whitening_mat(self, whitened_recording, eps):
    """
    Recover the covariance matrix from a stored whitening matrix.

    The whitening matrix is computed as the inverse square root of the
    covariance matrix (Sigma, called 'S' below) plus some `eps` for
    regularisation. Here that process is inverted to get the covariance
    matrix back from the whitening matrix, for testing purposes. This
    allows the entire whitening workflow to be tested rather than a
    sub-function only.

    Parameters
    ----------
    whitened_recording : object
        Object exposing the whitening matrix as `_kwargs["W"]`.
    eps : float
        Regularisation value used when the whitening matrix was computed.

    Returns
    -------
    np.ndarray
        The reconstructed covariance matrix.
    """
    whitening_matrix = whitened_recording._kwargs["W"]
    left_vecs, singular_values, right_vecs_t = np.linalg.svd(whitening_matrix)

    # Undo the inverse square root, then remove the regulariser.
    covar_eigenvalues = (1 / singular_values) ** 2 - eps

    S = left_vecs @ np.diag(covar_eigenvalues) @ right_vecs_t
    return S

@pytest.mark.parametrize("eps", [1e-8, 1])
@pytest.mark.parametrize("dtype", [np.float32, np.int16])
def test_compute_covariance_matrix(self, dtype, eps):
    """
    Check the covariance matrix used by `whiten` against the known
    covariance of the simulated data, for each parametrized `eps`.

    The covariance is recovered from the stored whitening matrix (see
    `covar_from_whitening_mat`) and compared with the covariance the test
    data were sampled from. When `eps` is small, the whitened traces are
    additionally checked to have (approximately) identity covariance.

    Parameters
    ----------
    dtype : numpy dtype
        Requested data type, forwarded to `get_float_test_data`.
    eps : float
        Regularisation value passed to `whiten`. It was previously
        overwritten with a hard-coded 1e-8 inside the test, which made the
        `eps=1` parametrization a duplicate of the first case.
    """
    _, covar, recording = self.get_float_test_data(num_segments=1, dtype=dtype)

    whitened_recording = si.whiten(
        recording,
        apply_mean=True,
        regularize=False,
        num_chunks_per_segment=1,
        chunk_size=recording.get_num_samples(segment_index=0) - 1,
        eps=eps,
    )

    test_covar = self.covar_from_whitening_mat(whitened_recording, eps)

    assert np.allclose(test_covar, covar, rtol=0, atol=0.01)

    # With eps == 1 the regulariser is of the same order as the variances,
    # so the output is not expected to be white; only check the small-eps case.
    if eps != 1:
        X = whitened_recording.get_traces()
        X = X - np.mean(X, axis=0)
        S = X.T @ X / X.shape[0]

        assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4)

def test_compute_covariance_matrix_float_2_segments(self):
""" """
eps = 1e-8
mean, covar, recording = self.get_float_test_data(num_segments=2, dtype=np.float32)

recording._recording_segments[1].set_data(
np.zeros((num_samples, num_channels))
np.zeros((recording.get_num_samples(segment_index=0), recording.get_num_channels()))
)

_, test_cov, test_mean = compute_covariance_matrix(
recording,
apply_mean=True,
regularize=False,
regularize_kwargs={},
random_chunk_kwargs=dict(
whitened_recording = si.whiten(
recording,
apply_mean=True,
regularize=False,
regularize_kwargs={},
num_chunks_per_segment=1,
chunk_size=recording.get_num_samples(segment_index=0)-1,
chunk_size=recording.get_num_samples(segment_index=0) - 1,
eps=eps,
)
)

if num_segments == 1:
assert np.allclose(test_cov, cov_1, rtol=0, atol=0.01)
else:
assert np.allclose(test_cov, cov_1 / 2, rtol=0, atol=0.01)

# test_cov
# TODO: own test for mean

whitened_recording = si.whiten(
recording,
apply_mean=True,
regularize=False,
regularize_kwargs={},
num_chunks_per_segment=1,
chunk_size=recording.get_num_samples(segment_index=0) - 1,
eps=eps
)

W = whitened_recording._kwargs["W"]
U, S, Vt = np.linalg.svd(W)
S_ = (1 / S) ** 2 - eps
P = U @ np.diag(S_) @ Vt

if num_segments == 1:
assert np.allclose(P, cov_1, rtol=0, atol=0.01)
else:
assert np.allclose(P, cov_1 / 2, rtol=0, atol=0.01)

# TODO:
# 1) test int16, MVN is not going to work. Completely new test that just tests against X.T@X/n
# 2) test apply mean on / off and means
# 3) make clear eps is tested above
# 4) test regularisation (use existing approach). Maybe test directly against sklearn function
# 5) test local vs. global
# 6) monkeypatch regularisation and random kwargs to check they are passed correctly.
# 7) test radius and int scale in the simplest way
# 8) test W, M are saved correctly

test_covar = self.covar_from_whitening_mat(whitened_recording, eps)

assert np.allclose(test_covar, covar / 2, rtol=0, atol=0.01)

@pytest.mark.parametrize("apply_mean", [True, False])
def test_apply_mean(self, apply_mean):
    """
    Check the effect of `apply_mean` on the covariance used by `whiten`.

    The test data are sampled with large, non-zero channel means. With
    `apply_mean=True` the recovered matrix should match the known
    covariance. With `apply_mean=False` the means are not removed, so the
    diagonal of the recovered matrix is dominated by the squared means.

    Parameters
    ----------
    apply_mean : bool
        Passed straight through to `whiten`.
    """
    means = np.array([10, 20, 30])

    eps = 1e-8
    _, covar, recording = self.get_float_test_data(num_segments=1, dtype=np.float32, mean=means)

    whitened_recording = si.whiten(
        recording,
        apply_mean=apply_mean,
        regularize=False,
        regularize_kwargs={},
        num_chunks_per_segment=1,
        chunk_size=recording.get_num_samples(segment_index=0) - 1,
        eps=eps,
    )

    test_covar = self.covar_from_whitening_mat(whitened_recording, eps)

    if apply_mean:
        assert np.allclose(test_covar, covar, rtol=0, atol=0.01)
    else:
        # Loose tolerance: the diagonal is mean**2 plus the channel variance.
        assert np.allclose(np.diag(test_covar), means**2, rtol=0, atol=5)

    # A leftover `breakpoint()` call was removed here; it would stop or
    # fail any non-interactive test run.
    # TODO: check whitened data is cov identity even when apply_mean=False...

def test_compute_sklearn_covariance_matrix(self):
    """
    Check `compute_sklearn_covariance_matrix` against calling the
    corresponding sklearn estimator directly, for two estimators.

    The reference estimators are fit with `assume_centered=True`.

    TODO: confirm `assume_centered` is fixed to True in the wrapper.
    TODO: this is not appropriate for all covariance functions, only
    those with a `fit` method (e.g. it does not work with `ledoit_wolf`).
    TODO: rename `test_cov`.
    """
    from sklearn import covariance

    X = np.random.randn(100, 100)

    # (estimator class name, estimator keyword arguments)
    cases = [
        ("GraphicalLasso", {"alpha": 1, "enet_tol": 0.04}),
        ("ShrunkCovariance", {"shrinkage": 0.3}),
    ]

    for method, estimator_kwargs in cases:
        test_cov = compute_sklearn_covariance_matrix(X, {"method": method, **estimator_kwargs})

        estimator = getattr(covariance, method)(assume_centered=True, **estimator_kwargs)
        cov = estimator.fit(X).covariance_

        assert np.allclose(test_cov, cov)

def test_whiten_regularisation_norm(self):
    """
    Check that `whiten` with `regularize=True` uses the covariance given
    by the requested sklearn estimator.

    The covariance recovered from the whitening matrix is compared with a
    `ShrunkCovariance` fit made directly on the mean-subtracted traces.
    """
    from sklearn import covariance

    eps = 1e-8
    shrinkage = 0.3

    _, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)
    chunk_size = recording.get_num_samples(segment_index=0) - 1

    whitened_recording = si.whiten(
        recording,
        regularize=True,
        regularize_kwargs={"method": "ShrunkCovariance", "shrinkage": shrinkage},
        apply_mean=True,
        num_chunks_per_segment=1,
        chunk_size=chunk_size,
        eps=eps,
    )
    test_covar = self.covar_from_whitening_mat(whitened_recording, eps=eps)

    # Drop the last sample so the reference has chunk_size samples.
    traces = recording.get_traces()[:-1, :]
    centered = traces - np.mean(traces, axis=0)

    estimator = covariance.ShrunkCovariance(shrinkage=shrinkage, assume_centered=True)
    covar = estimator.fit(centered).covariance_

    assert np.allclose(test_covar, covar, rtol=0, atol=1e-4)

def test_local_vs_global_whiten(self):
# Make test data with 4 channels, known covar between all
# do with radius = 2. compute manually. Will need to change channel locations
# check matches well.
pass

def test_passed_W_and_M(self):
pass

def test_all_kwargs(self):
pass

def test_saved_to_disk(self):
# check attributes are saved properly
pass


def test_whiten(create_cache_folder):
cache_folder = create_cache_folder
Expand Down Expand Up @@ -196,5 +305,5 @@ def test_whiten(create_cache_folder):
assert np.linalg.norm(W1) > np.linalg.norm(W2)


#if __name__ == "__main__":
# test_compute_whitening_matrix()
# if __name__ == "__main__":
# test_compute_whitening_matrix()
Loading

0 comments on commit a625cbf

Please sign in to comment.