
Commit

Fix linting errors in tests/
stes committed Oct 27, 2024
1 parent b0151d6 commit e228180
Showing 7 changed files with 13 additions and 16 deletions.
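
Most of the changes below follow one pattern: a local variable that is assigned but never read again is replaced with the throwaway name `_`. The commit message does not name the linter, but this is the usual fix for a flake8-style "local variable is assigned to but never used" warning (F841); treat the rule name as an assumption here. A minimal sketch of the before/after pattern, using hypothetical helper names:

```python
# Illustrative sketch only (not code from this repository).

def before(distribution, reference_idx):
    # Flagged by the linter: `positive_idx` is assigned but never read.
    positive_idx = distribution.sample_conditional(reference_idx)


def after(distribution, reference_idx):
    # Fixed: the call still executes (the test only checks that it does not raise),
    # but the unused result is bound to `_`, the conventional "intentionally discarded" name.
    _ = distribution.sample_conditional(reference_idx)
```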
2 changes: 1 addition & 1 deletion tests/test_criterions.py
@@ -293,7 +293,7 @@ def _sample_dist_matrices(seed):
 
 
 @pytest.mark.parametrize("seed", [42, 4242, 424242])
-def test_infonce(seed):
+def test_infonce_check_output_parts(seed):
     pos_dist, neg_dist = _sample_dist_matrices(seed)
 
     ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
6 changes: 3 additions & 3 deletions tests/test_distributions.py
@@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"):
     continuous = torch.randn(N, d).to(device)
 
     rand = torch.from_numpy(np.random.randint(0, N, (n,))).to(device)
-    qidx = discrete[rand].to(device)
+    _ = discrete[rand].to(device)
     query = continuous[rand] + 0.1 * torch.randn(n, d).to(device)
     query = query.to(device)
 
@@ -173,7 +173,7 @@ def test_mixed():
                                             discrete, continuous)
 
     reference_idx = distribution.sample_prior(10)
-    positive_idx = distribution.sample_conditional(reference_idx)
+    _ = distribution.sample_conditional(reference_idx)
 
     # The conditional distribution p(· | disc, cont) should yield
     # samples where the label exactly matches the reference sample.
@@ -193,7 +193,7 @@ def test_continuous(benchmark):
     def _test_distribution(dist):
         distribution = dist(continuous)
         reference_idx = distribution.sample_prior(10)
-        positive_idx = distribution.sample_conditional(reference_idx)
+        _ = distribution.sample_conditional(reference_idx)
        return distribution
 
     distribution = _test_distribution(
5 changes: 2 additions & 3 deletions tests/test_load.py
@@ -122,7 +122,7 @@ def generate_numpy_confounder(filename, dtype):
 
 
 @register("npz")
-def generate_numpy_path(filename, dtype):
+def generate_numpy_path_2(filename, dtype):
     A = np.arange(1000, dtype=dtype).reshape(10, 100)
     np.savez(filename, array=A, other_data="test")
     loaded_A = cebra_load.load(pathlib.Path(filename))
@@ -415,7 +415,7 @@ def generate_csv_path(filename, dtype):
 
 @register_error("csv")
 def generate_csv_empty_file(filename, dtype):
-    with open(filename, "w") as creating_new_csv_file:
+    with open(filename, "w") as _:
         pass
     _ = cebra_load.load(filename)
 
@@ -616,7 +616,6 @@ def generate_pickle_invalid_key(filename, dtype):
 
 @register_error("pkl", "p")
 def generate_pickle_no_array(filename, dtype):
-    A = np.arange(1000, dtype=dtype).reshape(10, 100)
     with open(filename, "wb") as f:
         pickle.dump({"A": "test_1", "B": "test_2"}, f)
     _ = cebra_load.load(filename)
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -155,7 +155,7 @@ def test_version_check(version, raises):
         cebra.models.model._check_torch_version(raise_error=True)
 
 
-def test_version_check():
+def test_version_check_2():
     raises = not cebra.models.model._check_torch_version(raise_error=False)
     if raises:
         assert len(cebra.models.get_options("*dropout*")) == 0
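
The renames in tests/test_models.py and tests/test_load.py fix a different problem: two functions with the same name in one module mean the second definition silently rebinds the name, so only one of them is ever collected and run. Flake8 typically reports this as F811 (redefinition of an unused name); the rule number is an assumption, since the commit does not state it. A simplified sketch of why the shadowed test never runs:

```python
# Illustrative sketch only, with made-up assertions.

def test_version_check():
    assert 1 + 1 == 2          # never runs: the definition below rebinds the name


def test_version_check():      # redefinition: pytest collects only this one
    assert 2 + 2 == 4

# Renaming the second function (here, to test_version_check_2) lets pytest
# collect and run both tests, and removes the redefinition warning.
```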
2 changes: 1 addition & 1 deletion tests/test_plot.py
@@ -239,7 +239,7 @@ def test_compare_models():
     _ = cebra_plot.compare_models(models, labels=long_labels, ax=ax)
     with pytest.raises(ValueError, match="Invalid.*labels"):
         invalid_labels = copy.deepcopy(labels)
-        ele = invalid_labels.pop()
+        _ = invalid_labels.pop()
         invalid_labels.append(["a"])
         _ = cebra_plot.compare_models(models, labels=invalid_labels, ax=ax)
 
5 changes: 1 addition & 4 deletions tests/test_sklearn.py
@@ -276,7 +276,6 @@ def test_api(estimator, check):
         pytest.skip(f"Model architecture {estimator.model_architecture} "
                     f"requires longer input sizes than 20 samples.")
 
-    success = True
     exception = None
     num_successful = 0
     total_runs = 0
@@ -334,7 +333,6 @@ def test_sklearn(model_architecture, device):
     y_c1 = np.random.uniform(0, 1, (1000, 5))
     y_c1_s2 = np.random.uniform(0, 1, (800, 5))
     y_c2 = np.random.uniform(0, 1, (1000, 2))
-    y_c2_s2 = np.random.uniform(0, 1, (800, 2))
     y_d = np.random.randint(0, 10, (1000,))
     y_d_s2 = np.random.randint(0, 10, (800,))
 
@@ -817,7 +815,6 @@ def test_sklearn_full(model_architecture, device, pad_before_transform):
     X = np.random.uniform(0, 1, (1000, 50))
     y_c1 = np.random.uniform(0, 1, (1000, 5))
     y_c2 = np.random.uniform(0, 1, (1000, 2))
-    y_d = np.random.randint(0, 10, (1000,))
 
     # time contrastive
     cebra_model.fit(X)
@@ -883,7 +880,7 @@ def test_sklearn_resampling_model_not_yet_supported(model_architecture, device):
 
     with pytest.raises(ValueError):
         cebra_model.fit(X, y_c1)
-        output = cebra_model.transform(X)
+        _ = cebra_model.transform(X)
 
 
 def _iterate_actions():
7 changes: 4 additions & 3 deletions tests/test_solver.py
@@ -99,11 +99,12 @@ def test_single_session(data_name, loader_initfunc, solver_initfunc):
 @pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc",
                          single_session_tests)
 def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc):
-    return  # TODO
+
+    pytest.skip("Not yet supported")
 
     loader = _get_loader(data_name, loader_initfunc)
     model = _make_model(loader.dataset)
-    behavior_model = _make_behavior_model(loader.dataset)
+    behavior_model = _make_behavior_model(loader.dataset)  # noqa: F841
 
     criterion = cebra.models.InfoNCE()
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -171,7 +172,7 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc):
 
 @pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc",
                          multi_session_tests)
-def test_multi_session(data_name, loader_initfunc, solver_initfunc):
+def test_multi_session_2(data_name, loader_initfunc, solver_initfunc):
     loader = _get_loader(data_name, loader_initfunc)
     criterion = cebra.models.InfoNCE()
     model = nn.ModuleList(
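
Two further idioms appear in the tests/test_solver.py hunk: replacing a bare `return  # TODO` with `pytest.skip(...)`, so the unfinished test is reported as skipped rather than counted as a quiet pass, and a per-line `# noqa: F841` to keep a deliberately unused assignment without triggering the unused-variable warning. A minimal sketch of both, with made-up test bodies:

```python
# Illustrative sketch only.
import pytest


def test_not_yet_supported():
    # Reported as SKIPPED in the test summary instead of silently passing.
    pytest.skip("Not yet supported")


def test_keeps_unused_assignment():
    # The per-line "# noqa: F841" suppresses the "assigned but never used"
    # warning while keeping the assignment in place for future use.
    placeholder = object()  # noqa: F841
```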
