Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Jan 5, 2024
1 parent c2a56dc commit 06627e3
Showing 1 changed file with 51 additions and 1 deletion.
52 changes: 51 additions & 1 deletion tests/test_uigp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,54 @@ def test_fit():
X, y = get_dummy_data()
m = UIGP(1, 'RBF')
m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
assert m.mcmc is not None
assert m.mcmc is not None


def test_fit_with_custom_sigma_x_prior():
rng_key = get_keys()[0]
X, y = get_dummy_data()
m = UIGP(1, 'RBF', sigma_x_prior_dist=dist.HalfNormal(0.55))
m.fit(rng_key, X, y, num_warmup=10, num_samples=10)
assert m.mcmc is not None


def test_get_mvn_posterior():
X, y = get_dummy_data()
X_test, _ = get_dummy_data()
X = X[:, None]
X_test = X_test[:, None]
params = {"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1),
"k_noise_length": jnp.array(0.5),
"sigma_x": jnp.array(0.3),
"X_prime": jnp.array(X + 0.1)
}
m = UIGP(1, 'RBF')
m.X_train = X
m.y_train = y
mean, cov = m.get_mvn_posterior(X_test, params)
assert isinstance(mean, jnp.ndarray)
assert isinstance(cov, jnp.ndarray)
assert_equal(mean.shape, (X_test.shape[0],))
assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


@pytest.mark.parametrize("noiseless", [True, False])
def test_predict(noiseless):
key = get_keys()[0]
X, y = get_dummy_data()
X_test, _ = get_dummy_data()
X = X[:, None]
X_test = X_test[:, None]
params = {"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1),
"k_noise_length": jnp.array(0.5),
"sigma_x": jnp.array(0.3),
"X_prime": jnp.array(X + 0.1)
}
m = UIGP(1, 'RBF')
m.X_train = X
m.y_train = y
m._predict(key, X_test, params, 5, noiseless)

0 comments on commit 06627e3

Please sign in to comment.