Skip to content

Commit

Permalink
Add test for _sample_x with differnet number of features
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Jan 21, 2024
1 parent 525c386 commit c11a880
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/test_uigp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def get_dummy_data():
return jnp.array(X_prime), jnp.array(y)


@pytest.mark.parametrize("n_features", [1, 5])
def test_sample_x(n_features):
X = onp.random.randn(32, n_features)
m = UIGP(n_features, 'RBF')
with numpyro.handlers.seed(rng_seed=0):
X_prime = m._sample_x(X)
assert_(isinstance(X_prime, jnp.ndarray))
assert_(X_prime.shape[-1], n_features)


def test_fit():
rng_key = get_keys()[0]
X, y = get_dummy_data()
Expand Down

0 comments on commit c11a880

Please sign in to comment.