diff --git a/gpax/models/gp.py b/gpax/models/gp.py index f249377..98f943e 100644 --- a/gpax/models/gp.py +++ b/gpax/models/gp.py @@ -251,7 +251,7 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: return self.mcmc.get_samples(group_by_chain=chain_dim) def get_mvn_posterior( - self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float + self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, use_cholesky: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior @@ -267,13 +267,24 @@ def get_mvn_posterior( k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs) k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0) k_XX = self.kernel(self.X_train, self.X_train, params, noise, **kwargs) - # compute the predictive covariance and mean - K_xx_inv = jnp.linalg.inv(k_XX) - cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) - mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) + + # Compute the predictive covariance and mean + # since K_xx is symmetric positive-definite, we can use the more efficient and + # stable Cholesky decomposition instead of matrix inversion + + if use_cholesky: + K_xx_cho = jax.scipy.linalg.cho_factor(k_XX) + cov = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T)) + mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, y_residual)) + else: + K_xx_inv = jnp.linalg.inv(k_XX) + cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) + mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) + if self.mean_fn is not None: args = [X_new, params] if self.mean_fn_prior else [X_new] mean += self.mean_fn(*args).squeeze() + return mean, cov def _predict( @@ -283,11 +294,12 @@ def _predict( params: Dict[str, jnp.ndarray], n: int, noiseless: bool = False, + use_cholesky: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Prediction with a single sample of GP parameters""" # Get the predictive mean and covariance - y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs) + y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, use_cholesky, **kwargs) # draw samples from the posterior predictive for a given set of parameters y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled @@ -304,10 +316,11 @@ def _predict_in_batches( predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, noiseless: bool = False, device: Type[jaxlib.xla_extension.Device] = None, + use_cholesky: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: if predict_fn is None: - predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs) + predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, use_cholesky, **kwargs) def predict_batch(Xi): out1, out2 = predict_fn(Xi) @@ -333,6 +346,7 @@ def predict_in_batches( predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, noiseless: bool = False, device: Type[jaxlib.xla_extension.Device] = None, + use_cholesky: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -342,7 +356,7 @@ def predict_in_batches( to avoid a memory overflow """ y_pred, y_sampled = self._predict_in_batches( - rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs + rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, use_cholesky, **kwargs ) y_pred = jnp.concatenate(y_pred, 0) y_sampled = jnp.concatenate(y_sampled, -1) @@ -357,6 +371,7 @@ def predict( filter_nans: bool = False, noiseless: bool = False, device: Type[jaxlib.xla_extension.Device] = None, + use_cholesky: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -391,7 +406,7 @@ def predict( samples = jax.device_put(samples, device) num_samples = len(next(iter(samples.values()))) vmap_args = (jra.split(rng_key, num_samples), samples) - predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs)) + predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, use_cholesky, **kwargs)) y_means, y_sampled = predictive(vmap_args) if filter_nans: y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()] diff --git a/tests/test_gp.py b/tests/test_gp.py index 25edadf..7d3d223 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -169,6 +169,21 @@ def test_get_mvn_posterior_noiseless(): assert_array_equal(mean1, mean2) assert onp.count_nonzero(cov1 - cov2) > 0 +def test_get_mvn_posterior_cholesky(): + X, y = get_dummy_data(unsqueeze=True) + X_test, _ = get_dummy_data(unsqueeze=True) + params = {"k_length": jnp.array([1.0]), + "k_scale": jnp.array(1.0), + "noise": jnp.array(0.1)} + m = ExactGP(1, 'RBF') + m.X_train = X + m.y_train = y + mean, cov = m.get_mvn_posterior(X_test, params, use_cholesky=True) + assert isinstance(mean, jnp.ndarray) + assert isinstance(cov, jnp.ndarray) + assert_equal(mean.shape, (X_test.shape[0],)) + assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0])) + def test_single_sample_prediction(): rng_key = get_keys()[0]