Commit 7fa0349

Continue extending whitening tests.

JoeZiminski committed Nov 11, 2024
1 parent a625cbf commit 7fa0349

Showing 2 changed files with 186 additions and 69 deletions.

250 changes: 184 additions & 66 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
@@ -8,6 +8,9 @@

import spikeinterface.full as si  # TODO: is this a bad idea? remove!

#################################################
# Custom Recording - TODO: get feedback and move
#################################################

class CustomRecording(BaseRecording):
""" """
@@ -31,9 +34,25 @@ def __init__(self, durations, num_channels, channel_ids, sampling_frequency, dtype):
"sampling_frequency": sampling_frequency,
}

def set_data(self, data, segment_index=0):

if data.shape[0] != self.get_num_samples(segment_index=segment_index):
raise ValueError("The first dimension must be the same size as"
"the number of samples.")

if data.shape[1] != self.get_num_channels():
raise ValueError("The second dimension of the data be the same"
"size as the number of channels.")

if data.dtype != self.dtype:
raise ValueError("The dtype of the data must match the recording dtype.")

self._recording_segments[segment_index].data = data
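# Minimal usage sketch (values below are illustrative only; the constructor
# signature is taken from the class above):
#
# rec = CustomRecording(
#     durations=[1.0],
#     num_channels=3,
#     channel_ids=[0, 1, 2],
#     sampling_frequency=30000.0,
#     dtype=np.float32,
# )
# rec.set_data(np.zeros((rec.get_num_samples(segment_index=0), 3), dtype=np.float32))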


class CustomRecordingSegment(BaseRecordingSegment):
""" """
"""
"""

def __init__(self, num_samples, num_channels, sampling_frequency):
self.num_samples = num_samples
@@ -44,10 +63,6 @@ def __init__(self, num_samples, num_channels, sampling_frequency):
self.t_start = None
self.time_vector = None

def get_traces(
self,
start_frame: int | None = None,
@@ -59,18 +74,21 @@ def get_traces(
def get_num_samples(self):
return self.num_samples

#################################################
# Test Class
#################################################

# TODO: return random chunks, scaled vs unscaled
class TestWhiten:
"""
Tests for the `whiten` preprocessing step.
"""

def get_float_test_data(self, num_segments, dtype, mean=None, covar=None):
"""
Generate a recording of the requested dtype whose data has a known mean and
covariance (both returned alongside the recording). Note that only the first
segment is filled with data; callers requesting more segments must set the
remaining segments themselves.
"""
num_channels = 3

if mean is None:
mean = np.zeros(num_channels)
@@ -90,7 +108,7 @@ def get_float_test_data(self, num_segments, dtype, mean=None, covar=None):
else:
raise ValueError("dtype must be float32 or int16")

recording.set_data(seg_1_data)
assert np.array_equal(recording.get_traces(segment_index=0), seg_1_data), "segment 1 test setup did not work."

return mean, covar, recording
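# The actual data generation is collapsed in this diff view. As a sketch of the
# general approach (not necessarily the exact implementation), float data with a
# known mean and covariance can be drawn from a multivariate normal:
#
# num_samples = recording.get_num_samples(segment_index=0)
# seg_1_data = np.random.multivariate_normal(mean, covar, size=num_samples).astype(np.float32)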
@@ -117,46 +135,79 @@ def covar_from_whitening_mat(self, whitened_recording, eps):
workflow to be tested rather than subfunction only.
"""
W = whitened_recording._kwargs["W"]

U, D, Vt = np.linalg.svd(W)
D_new = (1 / D) ** 2 - eps
S = U @ np.diag(D_new) @ Vt

return S
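# Why this works, as a sketch: it is assumed here that `whiten` builds the
# whitening matrix from an eigendecomposition of the estimated covariance,
# W = U @ diag(1 / sqrt(lambda + eps)) @ U.T (a ZCA-style construction). W is
# then symmetric positive definite, its singular values are 1 / sqrt(lambda + eps),
# and inverting that mapping recovers the covariance eigenvalues:
#
# cov = np.cov(data.T)
# lam, U = np.linalg.eigh(cov)
# W = U @ np.diag(1 / np.sqrt(lam + eps)) @ U.T
# U2, D, Vt = np.linalg.svd(W)
# recovered = U2 @ np.diag((1 / D) ** 2 - eps) @ Vt
# assert np.allclose(recovered, cov)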

@pytest.mark.parametrize("eps", [1e-8, 1])
###################################################################################
# Tests
###################################################################################

@pytest.mark.parametrize("dtype", [np.float32, np.int16])
def test_compute_covariance_matrix(self, dtype, eps):
""" """
def test_compute_covariance_matrix(self, dtype):
"""
"""
eps = 1e-8
mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=dtype)

whitened_recording = si.whiten(
recording,
apply_mean=False,
regularize=False,
num_chunks_per_segment=1,
chunk_size=recording.get_num_samples(segment_index=0) - 1,
eps=eps,
dtype=np.float32,
)

if dtype == np.float32:
test_covar = self.covar_from_whitening_mat(whitened_recording, eps)
assert np.allclose(test_covar, covar, rtol=0, atol=0.01)

# TODO: OWN FUNCTION
X = whitened_recording.get_traces()
X = X - np.mean(X, axis=0)
S = X.T @ X / X.shape[0]
assert np.allclose(S, np.eye(recording.get_num_channels()), rtol=0, atol=1e-4)
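# A possible factoring of the whiteness check above into a helper, as the TODO
# suggests (a sketch only; the name and placement are hypothetical, not part of
# this commit):
#
# def assert_traces_are_white(self, whitened_recording, atol=1e-4):
#     X = whitened_recording.get_traces()
#     X = X - np.mean(X, axis=0)
#     S = X.T @ X / X.shape[0]
#     assert np.allclose(S, np.eye(X.shape[1]), rtol=0, atol=atol)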
def test_non_default_eps(self):
"""
"""
eps = 1
mean, covar, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)

whitened_recording = si.whiten(
recording,
apply_mean=False,
regularize=False,
num_chunks_per_segment=1,
chunk_size=recording.get_num_samples(segment_index=0) - 1,
eps=eps,
)
test_covar = self.covar_from_whitening_mat(whitened_recording, eps)
assert np.allclose(test_covar, covar, rtol=0, atol=0.01)

def test_compute_covariance_matrix_float_2_segments(self):
""" """
"""
"""
eps = 1e-8
mean, covar, recording = self.get_float_test_data(num_segments=2, dtype=np.float32)

all_zero_data = np.zeros(
(recording.get_num_samples(segment_index=0), recording.get_num_channels()),
dtype=np.float32,
)

recording.set_data(
all_zero_data,
segment_index=1,
)

whitened_recording = si.whiten(
@@ -175,7 +226,9 @@ def test_compute_covariance_matrix_float_2_segments(self):

@pytest.mark.parametrize("apply_mean", [True, False])
def test_apply_mean(self, apply_mean):
"""
"""
means = np.array([10, 20, 30])

eps = 1e-8
@@ -198,14 +251,16 @@ def test_apply_mean(self, apply_mean):
else:
assert np.allclose(np.diag(test_covar), means**2, rtol=0, atol=5)

# TODO: insert test cov is white function
# breakpoint()  # TODO: check whitened data is cov identity even when apply_mean=False...
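# Why the diagonal approximates means**2 when apply_mean=False (a reasoning
# sketch, not part of the test): without mean subtraction the estimator computes
# second moments E[x_i**2] = mean_i**2 + var_i, so when the variances are small
# relative to mean_i**2 the recovered diagonal sits close to means**2, hence the
# loose atol=5 above.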

def test_compute_sklearn_covariance_matrix(self):
"""
Test the sklearn covariance matrix computation.

TODO: assumes `assume_centered` is fixed to True.
TODO: this is not appropriate for all covariance functions, only ones with
the `fit` method, e.g. it does not work with `ledoit_wolf`.
"""
from sklearn import covariance

@@ -224,7 +279,9 @@ def test_compute_sklearn_covariance_matrix(self):
assert np.allclose(test_cov, cov)
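# Background for the TODO in the docstring, as an illustrative sketch: sklearn
# exposes both estimator classes with a `fit` method and plain functions, and
# only the former fit the pattern used here. For example:
#
# from sklearn import covariance
# X = np.random.randn(1000, 3)
# cov_class = covariance.LedoitWolf().fit(X).covariance_  # class with .fit()
# cov_func, _shrinkage = covariance.ledoit_wolf(X)         # plain function, no .fit()
# assert np.allclose(cov_class, cov_func)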

def test_whiten_regularisation_norm(self):
""" """
"""
"""
from sklearn import covariance

_, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)
@@ -248,62 +305,123 @@ def test_whiten_regularisation_norm(self):

assert np.allclose(test_covar, covar, rtol=0, atol=1e-4)

# TODO: insert test whitened recording is white

def test_local_vs_global_whiten(self):
# Make test data with 3 channels and known covariance between all channels.
# Set the channel locations so that with a small radius_um only neighbouring
# channels are whitened together, then check that the local and global
# whitening matrices differ as expected.

_, _, recording = self.get_float_test_data(num_segments=1, dtype=np.float32)

y_dist = 2
recording.set_channel_locations([
[0.0, 0],
[0.0, y_dist * 1],
[0.0, y_dist * 2],
])

results = {"global": None, "local": None}

for mode in ["global", "local"]:
whitened_recording = si.whiten(
recording,
apply_mean=True,
num_chunks_per_segment=1,
chunk_size=recording.get_num_samples(segment_index=0) - 1,
eps=1e-8,
mode=mode,
radius_um=y_dist + 1e-01,
)
results[mode] = whitened_recording

assert results["local"]._kwargs["W"][0][2] == 0.0
assert results["global"]._kwargs["W"][0][2] != 0.0

# Check the locally whitened data: each neighbouring pair of channels should be
# white, but the full three-channel covariance should not be the identity.
whitened_data = results["local"].get_traces()

set_1 = whitened_data[:, :2] - np.mean(whitened_data[:, :2], axis=0)
set_2 = whitened_data[:, 1:] - np.mean(whitened_data[:, 1:], axis=0)

assert np.allclose(
np.eye(2), set_1.T @ set_1 / set_1.shape[0],
rtol=0, atol=1e-2
)
assert np.allclose(
np.eye(2), set_2.T @ set_2 / set_2.shape[0],
rtol=0, atol=1e-2
)
# TODO: own function
X = whitened_data - np.mean(whitened_data, axis=0)
covar_ = X.T @ X / X.shape[0]
assert not np.allclose(np.eye(3), covar_, rtol=0, atol=1e-2)
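# A sketch of the locality assumption being tested (illustrative, not the
# library implementation): in "local" mode a channel only contributes to the
# whitening of channels within `radius_um`, so with the locations above only
# neighbouring channels mix:
#
# locs = np.array([[0.0, 0.0], [0.0, 2.0], [0.0, 4.0]])
# dists = np.linalg.norm(locs[:, None, :] - locs[None, :, :], axis=2)
# within_radius = dists <= 2.0 + 1e-01
# # within_radius[0, 2] is False, matching the W[0][2] == 0.0 assertion above.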

def test_passed_W_and_M(self):
"""
Check that a user-supplied whitening matrix `W` and mean `M` are attached to
every recording segment and stored in the recording kwargs.

TODO: need to make clear in the options that the same whitening matrix is
used for all segments. Is this realistic?
"""
num_chan = 4
recording = self.get_empty_custom_recording(2, num_chan, dtype=np.float32)

test_W = np.random.normal(size=(num_chan, num_chan))
test_M = np.random.normal(size=num_chan)

whitened_recording = si.whiten(
recording,
W=test_W,
M=test_M
)

for seg_idx in [0, 1]:
assert np.array_equal(
whitened_recording._recording_segments[seg_idx].W,
test_W
)
assert np.array_equal(
whitened_recording._recording_segments[seg_idx].M,
test_M
)

assert whitened_recording._kwargs["W"] == test_W.tolist()
assert whitened_recording._kwargs["M"] == test_M.tolist()

def test_whiten_general(self, create_cache_folder):
"""
General checks of the whiten() workflow: local vs. global modes, dtype
handling, saving in parallel, int_scale and regularisation.
"""
cache_folder = create_cache_folder
rec = generate_recording(num_channels=4, seed=2205)

random_chunk_kwargs = {}
W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None)

with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None)
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25)

# W must be sparse
assert np.sum(W == 0) == 6

rec2 = whiten(rec)
rec2.save(verbose=False)

# test dtype
rec_int = scale(rec2, dtype="int16")
rec3 = whiten(rec_int, dtype="float16")
rec3 = rec3.save(folder=cache_folder / "rec1")
assert rec3.get_dtype() == "float16"

# test parallel
rec_par = rec3.save(folder=cache_folder / "rec_par", n_jobs=2)
np.testing.assert_array_equal(rec3.get_traces(segment_index=0), rec_par.get_traces(segment_index=0))

with pytest.raises(AssertionError):
rec4 = whiten(rec_int, dtype=None)  # int_scale should be applied
rec4 = whiten(rec_int, dtype=None, int_scale=256)
assert rec4.get_dtype() == "int16"
assert rec4._kwargs["M"] is None

# test regularization: norm should be smaller
W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True)
assert np.linalg.norm(W1) > np.linalg.norm(W2)
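# Rough intuition for the norm check (a sketch of the reasoning, not part of the
# test): regularised covariance estimators shrink the estimate towards a scaled
# identity, lifting the smallest eigenvalues; since W scales as
# 1 / sqrt(eigenvalue + eps), lifting the small eigenvalues shrinks the largest
# entries of W and therefore its norm.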