From 7465939d4393034a4d68d9b7a813573eac5a2573 Mon Sep 17 00:00:00 2001 From: "Julio A. Peraza" <52050407+JulioAPeraza@users.noreply.github.com> Date: Tue, 9 Apr 2024 11:38:14 -0400 Subject: [PATCH] Add an inflation factor to correct for multiple contrasts in Stouffer's combination test (#117) * Add correction term for multiple contrasts in Stouffer's combination test * Update RTD yml * Update .readthedocs.yml * Update combination.py * Update .readthedocs.yml * Update setup.cfg * Update testing.yml * Update testing.yml * Run black * Make sure solutions and symbols match * Update combination.py --- pymare/estimators/combination.py | 76 +++++++++++++++++++++++--- pymare/tests/test_combination_tests.py | 37 +++++++++++++ 2 files changed, 105 insertions(+), 8 deletions(-) diff --git a/pymare/estimators/combination.py b/pymare/estimators/combination.py index 9bad45b..6f9f41a 100644 --- a/pymare/estimators/combination.py +++ b/pymare/estimators/combination.py @@ -111,17 +111,77 @@ class StoufferCombinationTest(CombinationTest): """ # Maps Dataset attributes onto fit() args; see BaseEstimator for details. - _dataset_attr_map = {"z": "y", "w": "v"} - - def fit(self, z, w=None): - """Fit the estimator to z-values, optionally with weights.""" - return super().fit(z, w=w) - - def p_value(self, z, w=None): + _dataset_attr_map = {"z": "y", "w": "n", "g": "v"} + + def _inflation_term(self, z, w, g): + """Calculate the variance inflation term for each group. + + This term is used to adjust the variance of the combined z-score when + multiple sample come from the same study. + + Parameters + ---------- + z : :obj:`numpy.ndarray` of shape (n, d) + Array of z-values. + w : :obj:`numpy.ndarray` of shape (n, d) + Array of weights. + g : :obj:`numpy.ndarray` of shape (n, d) + Array of group labels. + + Returns + ------- + sigma : float + The variance inflation term. + """ + # Only center if the samples are not all the same, to prevent division by zero + # when calculating the correlation matrix. + # This centering is problematic for N=2 + all_samples_same = np.all(np.equal(z, z[0]), axis=0).all() + z = z if all_samples_same else z - z.mean(0) + + # Use the value from one feature, as all features have the same groups and weights + groups = g[:, 0] + weights = w[:, 0] + + # Loop over groups + unique_groups = np.unique(groups) + + sigma = 0 + for group in unique_groups: + group_indices = np.where(groups == group)[0] + group_z = z[group_indices] + + # For groups with only one sample the contribution to the summand is 0 + n_samples = len(group_indices) + if n_samples < 2: + continue + + # Calculate the within group correlation matrix and sum the non-diagonal elements + corr = np.corrcoef(group_z, rowvar=True) + upper_indices = np.triu_indices(n_samples, k=1) + non_diag_corr = corr[upper_indices] + w_i, w_j = weights[upper_indices[0]], weights[upper_indices[1]] + + sigma += (2 * w_i * w_j * non_diag_corr).sum() + + return sigma + + def fit(self, z, w=None, g=None): + """Fit the estimator to z-values, optionally with weights and groups.""" + return super().fit(z, w=w, g=g) + + def p_value(self, z, w=None, g=None): """Calculate p-values.""" if w is None: w = np.ones_like(z) - cz = (z * w).sum(0) / np.sqrt((w**2).sum(0)) + + # Calculate the variance inflation term, sum of non-diagonal elements of sigma. + sigma = self._inflation_term(z, w, g) if g is not None else 0 + + # The sum of diagonal elements of sigma is given by (w**2).sum(0). + variance = (w**2).sum(0) + sigma + + cz = (z * w).sum(0) / np.sqrt(variance) return ss.norm.sf(cz) diff --git a/pymare/tests/test_combination_tests.py b/pymare/tests/test_combination_tests.py index 36e2918..9f5f7e1 100644 --- a/pymare/tests/test_combination_tests.py +++ b/pymare/tests/test_combination_tests.py @@ -42,3 +42,40 @@ def test_combination_test_from_dataset(Cls, data, mode, expected): results = est.summary() z = ss.norm.isf(results.p) assert np.allclose(z, expected, atol=1e-5) + + +def test_stouffer_adjusted(): + """Test StoufferCombinationTest with weights and groups.""" + # Test with weights and groups + data = np.array( + [ + [2.1, 0.7, -0.2, 4.1, 3.8], + [1.1, 0.2, 0.4, 1.3, 1.5], + [-0.6, -1.6, -2.3, -0.8, -4.0], + [2.5, 1.7, 2.1, 2.3, 2.5], + [3.1, 2.7, 3.1, 3.3, 3.5], + [3.6, 3.2, 3.6, 3.8, 4.0], + ] + ) + weights = np.tile(np.array([4, 3, 4, 10, 15, 10]), (data.shape[1], 1)).T + groups = np.tile(np.array([0, 0, 1, 2, 2, 2]), (data.shape[1], 1)).T + + results = StoufferCombinationTest("directed").fit(z=data, w=weights, g=groups).params_ + z = ss.norm.isf(results["p"]) + + z_expected = np.array([5.00088912, 3.70356943, 4.05465924, 5.4633001, 5.18927878]) + assert np.allclose(z, z_expected, atol=1e-5) + + # Test with weights and no groups. Limiting cases. + # Limiting case 1: all correlations are one. + n_maps_l1 = 5 + common_sample = np.array([2.1, 0.7, -0.2]) + data_l1 = np.tile(common_sample, (n_maps_l1, 1)) + groups_l1 = np.tile(np.array([0, 0, 0, 0, 0]), (data_l1.shape[1], 1)).T + + results_l1 = StoufferCombinationTest("directed").fit(z=data_l1, g=groups_l1).params_ + z_l1 = ss.norm.isf(results_l1["p"]) + + sigma_l1 = n_maps_l1 * (n_maps_l1 - 1) # Expected inflation term + z_expected_l1 = n_maps_l1 * common_sample / np.sqrt(n_maps_l1 + sigma_l1) + assert np.allclose(z_l1, z_expected_l1, atol=1e-5)