Cholesky decomposition for ExactGP class #115

Open · wants to merge 1 commit into base: main
Added the ability to use Cholesky decomposition instead of the naive inverse for the ExactGP class. Added a test for Cholesky decomposition.
mjbajwa committed Sep 13, 2024
commit 924ee820eaa4cbc00504aa9000c491358d367e9b
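Reviewer note: a minimal sketch of how the new `use_cholesky` flag could be exercised, mirroring the test added in this commit. The import path is assumed from the file touched below (gpax/models/gp.py); the toy data and the tolerances are illustrative and not part of the PR.

```python
import jax.numpy as jnp
from gpax.models.gp import ExactGP  # assumed import path

# Hypothetical 1D toy data in the (n, 1) shape convention used by the tests
X = jnp.linspace(0.0, 1.0, 8)[:, None]
y = jnp.sin(2.0 * jnp.pi * X).squeeze()
X_test = jnp.linspace(0.0, 1.0, 16)[:, None]

# A single set of kernel hyperparameters, as in the added test
params = {"k_length": jnp.array([1.0]),
          "k_scale": jnp.array(1.0),
          "noise": jnp.array(0.1)}

m = ExactGP(1, "RBF")
m.X_train, m.y_train = X, y

# New path: Cholesky factorization of k_XX
mean_chol, cov_chol = m.get_mvn_posterior(X_test, params, use_cholesky=True)
# Existing path (default): explicit matrix inverse
mean_inv, cov_inv = m.get_mvn_posterior(X_test, params, use_cholesky=False)

# The two paths should agree up to numerical error
assert jnp.allclose(mean_chol, mean_inv, atol=1e-4)
assert jnp.allclose(cov_chol, cov_inv, atol=1e-4)
```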
33 changes: 24 additions & 9 deletions gpax/models/gp.py
@@ -251,7 +251,7 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
return self.mcmc.get_samples(group_by_chain=chain_dim)

def get_mvn_posterior(
- self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
+ self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, use_cholesky: bool = False, **kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
@@ -267,13 +267,24 @@ def get_mvn_posterior(
k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs)
k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0)
k_XX = self.kernel(self.X_train, self.X_train, params, noise, **kwargs)
- # compute the predictive covariance and mean
- K_xx_inv = jnp.linalg.inv(k_XX)
- cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
- mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
+
+ # Compute the predictive covariance and mean
+ # since K_xx is symmetric positive-definite, we can use the more efficient and
+ # stable Cholesky decomposition instead of matrix inversion
+
+ if use_cholesky:
+     K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
+     cov = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
+     mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, y_residual))
+ else:
+     K_xx_inv = jnp.linalg.inv(k_XX)
+     cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
+     mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))

if self.mean_fn is not None:
args = [X_new, params] if self.mean_fn_prior else [X_new]
mean += self.mean_fn(*args).squeeze()

return mean, cov

def _predict(
@@ -283,11 +294,12 @@ def _predict(
params: Dict[str, jnp.ndarray],
n: int,
noiseless: bool = False,
use_cholesky: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Prediction with a single sample of GP parameters"""
# Get the predictive mean and covariance
- y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs)
+ y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, use_cholesky, **kwargs)
# draw samples from the posterior predictive for a given set of parameters
y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
return y_mean, y_sampled
@@ -304,10 +316,11 @@ def _predict_in_batches(
predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
noiseless: bool = False,
device: Type[jaxlib.xla_extension.Device] = None,
use_cholesky: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
if predict_fn is None:
- predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs)
+ predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, use_cholesky, **kwargs)

def predict_batch(Xi):
out1, out2 = predict_fn(Xi)
@@ -333,6 +346,7 @@ def predict_in_batches(
predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
noiseless: bool = False,
device: Type[jaxlib.xla_extension.Device] = None,
use_cholesky: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
@@ -342,7 +356,7 @@ def predict_in_batches(
to avoid a memory overflow
"""
y_pred, y_sampled = self._predict_in_batches(
- rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs
+ rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, use_cholesky, **kwargs
)
y_pred = jnp.concatenate(y_pred, 0)
y_sampled = jnp.concatenate(y_sampled, -1)
@@ -357,6 +371,7 @@ def predict(
filter_nans: bool = False,
noiseless: bool = False,
device: Type[jaxlib.xla_extension.Device] = None,
use_cholesky: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
@@ -391,7 +406,7 @@ def predict(
samples = jax.device_put(samples, device)
num_samples = len(next(iter(samples.values())))
vmap_args = (jra.split(rng_key, num_samples), samples)
- predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
+ predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, use_cholesky, **kwargs))
y_means, y_sampled = predictive(vmap_args)
if filter_nans:
y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
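The two code paths above compute the same posterior mean and covariance; the Cholesky branch simply replaces products with the explicit inverse of k_XX by solves against its Cholesky factor. A self-contained sanity check of that equivalence on a hypothetical symmetric positive-definite matrix (illustrative only, not part of this PR):

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Toy SPD matrix standing in for k_XX
A = jnp.array([[2.0, 0.5, 0.1],
               [0.5, 1.5, 0.3],
               [0.1, 0.3, 1.0]])
b = jnp.array([1.0, 2.0, 3.0])

# Existing path: explicit inverse
x_inv = jnp.matmul(jnp.linalg.inv(A), b)

# New path: factor once, then reuse the factor for every solve
c = cho_factor(A)
x_chol = cho_solve(c, b)

print(jnp.allclose(x_inv, x_chol, atol=1e-6))  # True
```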
15 changes: 15 additions & 0 deletions tests/test_gp.py
@@ -169,6 +169,21 @@ def test_get_mvn_posterior_noiseless():
assert_array_equal(mean1, mean2)
assert onp.count_nonzero(cov1 - cov2) > 0

def test_get_mvn_posterior_cholesky():
X, y = get_dummy_data(unsqueeze=True)
X_test, _ = get_dummy_data(unsqueeze=True)
params = {"k_length": jnp.array([1.0]),
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1)}
m = ExactGP(1, 'RBF')
m.X_train = X
m.y_train = y
mean, cov = m.get_mvn_posterior(X_test, params, use_cholesky=True)
assert isinstance(mean, jnp.ndarray)
assert isinstance(cov, jnp.ndarray)
assert_equal(mean.shape, (X_test.shape[0],))
assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))


def test_single_sample_prediction():
rng_key = get_keys()[0]