Skip to content

Commit

Permalink
Update tests for BBI(2)
Browse files Browse the repository at this point in the history
  • Loading branch information
achamma committed Jul 8, 2024
1 parent 5e6df55 commit 1e8f393
Showing 1 changed file with 25 additions and 84 deletions.
109 changes: 25 additions & 84 deletions hidimstat/test/test_BBI.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def _generate_data(n_samples=100, n_features=10, prob_type="regression",
return X, y, grps, list_nominal


def test_BBI_DNN():
def test_BBI_condDNN():

X, y, grps, list_nominal = _generate_data()
X, y, _, list_nominal = _generate_data()
# Compute importance with residuals
bbi_res = BlockBasedImportance(
estimator=None,
Expand Down Expand Up @@ -128,85 +128,26 @@ def test_BBI_DNN():
results_samp = bbi_samp.compute_importance()
assert len(results_samp["pval"]) == X.shape[1]

# def test_BBI_samplingRF():

# X, y, grps, list_nominal = _generate_data()
# # Permutation Method
# bbi_perm = BlockBasedImportance(
# estimator=None,
# importance_estimator="Mod_RF",
# do_hyper=True,
# dict_hyper=None,
# conditional=False,
# group_stacking=False,
# prob_type="regression",
# k_fold=2,
# list_nominal=list_nominal,
# n_jobs=10,
# verbose=0,
# n_perm=100,
# )
# bbi_perm.fit(X, y)
# results_perm = bbi_perm.compute_importance()
# assert len(results_perm["pval"]) == X.shape[1]
# # Conditional Method
# bbi_cond = BlockBasedImportance(
# estimator='RF',
# importance_estimator="Mod_RF",
# do_hyper=True,
# dict_hyper=None,
# conditional=True,
# group_stacking=False,
# prob_type="regression",
# k_fold=k_fold,
# list_nominal=list_nominal,
# n_jobs=10,
# verbose=0,
# n_perm=100,
# )
# bbi_cond.fit(X, y)
# results_cond = bbi_cond.compute_importance()
# pvals_cond = -np.log10(results_cond["pval"] + 1e-5)
# assert len(pvals_cond) == X.shape[1]


# def test_BBI_residuals():
# # Permutation Method
# bbi_perm = BlockBasedImportance(
# estimator='RF',
# importance_estimator=None,
# do_hyper=True,
# dict_hyper=None,
# conditional=False,
# group_stacking=False,
# prob_type="regression",
# k_fold=k_fold,
# list_nominal=list_nominal,
# n_jobs=10,
# verbose=0,
# n_perm=100,
# )
# bbi_perm.fit(X, y)
# results_perm = bbi_perm.compute_importance()
# pvals_perm = -np.log10(results_perm["pval"] + 1e-10)
# assert len(pvals_perm) == X.shape[1]

# # Conditional Method
# bbi_cond = BlockBasedImportance(
# estimator='RF',
# importance_estimator=None,
# do_hyper=True,
# dict_hyper=None,
# conditional=True,
# group_stacking=False,
# prob_type="regression",
# k_fold=k_fold,
# list_nominal=list_nominal,
# n_jobs=10,
# verbose=0,
# n_perm=100,
# )
# bbi_cond.fit(X, y)
# results_cond = bbi_cond.compute_importance()
# pvals_cond = -np.log10(results_cond["pval"] + 1e-5)
# assert len(pvals_cond) == X.shape[1]

def test_BBI_permDNN():

X, y, _, list_nominal = _generate_data()
# Compute importance with sampling RF
bbi_perm = BlockBasedImportance(
estimator=None,
importance_estimator="Mod_RF",
do_hyper=True,
dict_hyper=None,
conditional=False,
group_stacking=False,
prob_type="regression",
k_fold=2,
list_nominal=list_nominal,
n_jobs=10,
verbose=0,
n_perm=100,
)
bbi_perm.fit(X, y)
results_perm = bbi_perm.compute_importance()
assert len(results_perm["pval"]) == X.shape[1]

0 comments on commit 1e8f393

Please sign in to comment.