Skip to content

Commit

Permalink
TST: increase coverage, part I, stats.py (#133)
Browse files Browse the repository at this point in the history
* TST: add test for _find_stat_fun
* ENH, FIX: return ttest_1samp for one group
* TST: move threshold testing, add test for n jobs
* STY: thanks, hound!
* TST: test for paired True and False
* TST: a test for ANOVA analytical threshold
* TST, FIX: forgot to indent the test
* STY: make hound happy
* TST: test tails
* STY: add short docstrings to the tests
* STY: simplify argument passing
* TST: more liberal tests
  • Loading branch information
mmagnuski authored Nov 7, 2023
1 parent 18bc826 commit a9e3952
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 53 deletions.
46 changes: 0 additions & 46 deletions borsar/cluster/tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from borsar.utils import has_numba, _get_test_data_dir
from borsar.cluster.utils import create_fake_data_for_cluster_test
from borsar.cluster.label import find_clusters, _cluster_3d_numpy
from borsar.cluster.stats import (_compute_threshold_via_permutations,
_find_stat_fun, _compute_threshold)
from borsar.cluster import construct_adjacency_matrix, cluster_based_regression


Expand Down Expand Up @@ -719,47 +717,3 @@ def remove_diagnoal_clusters(clusters, adjacency=None):
clusters, _ = find_clusters(
test_data, threshold=0.5, filter_fun_post=remove_diagnoal_clusters)
assert (clusters[0] == data_after_removing_diagonal_clusters).all()


def test_compute_threshold_via_permutations():
"""Make sure that threshold computed through permutations is correct.
Check that the threshold computed through permutations/randomization
on data that fulfills assumptions of analytical tests is sufficiently
close to the analytical threshold.
"""
n_groups = 2

for paired in [False, True]:
if paired:
n_obs = [101, 101]
data = [np.random.randn(n_obs[0])]
data.append(data[0] + np.random.randn(n_obs[0]))
else:
n_obs = [102, 100]
data = [np.random.randn(n) for n in n_obs]

analytical_threshold = _compute_threshold(
data=data, threshold=None, p_threshold=0.05, paired=paired,
one_sample=False)

stat_fun = _find_stat_fun(
n_groups, paired=paired, tail='both')

permutation_threshold = (
_compute_threshold_via_permutations(
data, paired=paired, tail='both', stat_fun=stat_fun,
n_permutations=2_000, progress=False
)
)

avg_perm = np.abs(permutation_threshold).mean()
error = analytical_threshold - avg_perm

print('paired:', paired)
print('analytical_threshold:', analytical_threshold)
print('permutation threshold:', permutation_threshold)
print('average permutation threshold:', avg_perm)
print('difference:', error)

assert np.abs(error) < 0.15
13 changes: 7 additions & 6 deletions borsar/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def stat_fun(*args):
fval, _ = f_oneway(*args)
return fval
return stat_fun
else:
elif n_groups == 2:
if paired:
from scipy.stats import ttest_rel

Expand All @@ -119,12 +119,13 @@ def stat_fun(*args):
return tval
return stat_fun
else:
from scipy.stats import ttest_ind

def stat_fun(*args):
tval, _ = ttest_ind(*args, equal_var=False)
return tval
from mne.stats import ttest_ind_no_p as stat_fun
return stat_fun
else:
# one group
from mne.stats import ttest_1samp_no_p as stat_fun
return stat_fun



# FIXME: streamline/simplify permutation reshaping and transposing
Expand Down
170 changes: 169 additions & 1 deletion borsar/tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import time
import numpy as np
import pandas as pd

import statsmodels.api as sm
from borsar.stats import compute_regression_t, format_pvalue
from statsmodels.stats.anova import AnovaRM

from borsar.stats import (compute_regression_t, format_pvalue, _find_stat_fun,
_compute_threshold_via_permutations,
_compute_threshold)


def test_compute_regression_t():
Expand Down Expand Up @@ -86,3 +92,165 @@ def test_format_p_value():
assert format_pvalue(0.00025) == 'p < 0.001'
assert format_pvalue(0.000015) == 'p < 0.0001'
assert format_pvalue(0.000000000015) == 'p < 1e-10'


def test_find_stat_fun():
from scipy.stats import (ttest_rel, ttest_ind, ttest_1samp, f_oneway)

data1 = np.random.rand(2, 10, 20)

# independent samples t test
stat_fun = _find_stat_fun(n_groups=2, paired=False, tail='both')
t_val = stat_fun(data1[0], data1[1])
t_val_good, _ = ttest_ind(data1[0], data1[1])
np.testing.assert_almost_equal(t_val, t_val_good)

# paired samples t test
stat_fun = _find_stat_fun(n_groups=2, paired=True, tail='both')
t_val = stat_fun(data1[0], data1[1])
t_val_good, _ = ttest_rel(data1[0], data1[1])
np.testing.assert_almost_equal(t_val, t_val_good)

# one sample t test
stat_fun = _find_stat_fun(n_groups=1, paired=False, tail='both')
t_val = stat_fun(data1[0])
t_val_good, _ = ttest_1samp(data1[0], 0)
np.testing.assert_almost_equal(t_val, t_val_good)

# independent ANOVA
data2 = np.random.rand(3, 10, 20)
stat_fun = _find_stat_fun(n_groups=3, paired=False, tail='pos')
f_val = stat_fun(data2[0], data2[1], data2[2])
f_val_good, _ = f_oneway(data2[0], data2[1], data2[2])
np.testing.assert_almost_equal(f_val, f_val_good)

# paired ANOVA
data3 = np.random.rand(15, 4)
stat_fun = _find_stat_fun(n_groups=3, paired=True, tail='pos')
f_val = stat_fun(*data3.T).item()

subj = np.tile(np.arange(15)[:, None], [1, 4])
group = np.tile(np.arange(4)[None, :], [15, 1])
df = pd.DataFrame(data={'val': data3.ravel(), 'subj': subj.ravel(),
'rep': group.ravel()})
res = AnovaRM(data=df, depvar='val', subject='subj', within=['rep']).fit()
f_val_good = res.anova_table.loc['rep', 'F Value']

np.testing.assert_almost_equal(f_val, f_val_good)


def test_compute_threshold_via_permutations():
"""Make sure that threshold computed through permutations is correct.
Check that the threshold computed through permutations/randomization
on data that fulfills assumptions of analytical tests is sufficiently
close to the analytical threshold.
"""
n_groups = 2

for paired in [False, True]:
if paired:
n_obs = [101, 101]
data = [np.random.randn(n_obs[0])]
data.append(data[0] + np.random.randn(n_obs[0]))
else:
n_obs = [102, 100]
data = [np.random.randn(n) for n in n_obs]

analytical_threshold = _compute_threshold(
data=data, threshold=None, p_threshold=0.05, paired=paired,
one_sample=False)

stat_fun = _find_stat_fun(
n_groups, paired=paired, tail='both')

permutation_threshold = (
_compute_threshold_via_permutations(
data, paired=paired, tail='both', stat_fun=stat_fun,
n_permutations=2_000, progress=False
)
)

avg_perm = np.abs(permutation_threshold).mean()
error = analytical_threshold - avg_perm

print('paired:', paired)
print('analytical_threshold:', analytical_threshold)
print('permutation threshold:', permutation_threshold)
print('average permutation threshold:', avg_perm)
print('difference:', error)

assert np.abs(error) < 0.15


def test_compute_threshold_via_permutations_n_jobs():
'''Thresholds computed with different number of jobs should be similar.'''
data = [np.random.randn(12, 10, 10), np.random.randn(12, 10, 10)]
for paired in [True, False]:
stat_fun = _find_stat_fun(2, paired=paired, tail='both')

pos_thresh, neg_thresh = _compute_threshold_via_permutations(
data, paired=paired, tail='both', stat_fun=stat_fun,
n_permutations=1_000, progress=False
)
pos_thresh_jobs, neg_thresh_jobs = _compute_threshold_via_permutations(
data, paired=paired, tail='both', stat_fun=stat_fun,
n_permutations=1_000, progress=False, n_jobs=2
)

# most differences are < 0.5, but some are larger
# (independent permutation runs)
assert (np.abs(pos_thresh - pos_thresh_jobs) < 0.5).mean() > 0.9
assert (np.abs(neg_thresh - neg_thresh_jobs) < 0.5).mean() > 0.9


def test_compute_anova_analytical_threshold():
'''Make sure that analytical threshold for ANOVA is correct.'''

# the critical F value for alpha = 0.05 was taken from:
# https://www.dummies.com/article/business-careers-money/business/accounting/calculation-analysis/how-to-find-the-critical-values-for-an-anova-hypothesis-using-the-f-table-146050/
# v1 = 6
# v2 = 4
# v1 is the "numerator degrees of freedom"
# (for example n_groups – 1)
# v2 is the "denominator degrees of freedom"
# (for example total n_obs - n_groups )
# in the case of chosen v1 and v2: 7 groups, 11 obs

alpha = 0.05
F_critical = 6.16

val_lists = [[5.12], [3.23, 4.1], [5.5], [4.2, 4.81], [12.2], [6.5, 7.8],
[4.5, 5.2]]
data = [np.array(x) for x in val_lists]

thresh = _compute_threshold(
data, threshold=None, p_threshold=alpha, paired=False,
one_sample=False
)

assert f'{F_critical:.2f}' == f'{thresh:.2f}'


def test_permutation_threshold_tails():
'''Making sure that one-tail thresholds are more liberal than two-tail.'''
n_groups = 2
paired = True

stat_fun = _find_stat_fun(n_groups, paired=paired, tail='both')
data = np.random.rand(n_groups, 12, 10, 10)

args = dict(stat_fun=stat_fun, n_permutations=500, progress=False)
pos_thresh, neg_thresh = _compute_threshold_via_permutations(
data, paired=True, tail='both', **args
)

neg_thresh2 = _compute_threshold_via_permutations(
data, paired=True, tail='neg', **args
)
pos_thresh2 = _compute_threshold_via_permutations(
data, paired=True, tail='pos', **args
)

assert (neg_thresh < neg_thresh2).mean() > 0.92
assert (pos_thresh > pos_thresh2).mean() > 0.92

0 comments on commit a9e3952

Please sign in to comment.