diff --git a/gpax/models/hskgp.py b/gpax/models/hskgp.py index cb797e2..f25f170 100644 --- a/gpax/models/hskgp.py +++ b/gpax/models/hskgp.py @@ -48,8 +48,7 @@ class VarNoiseGP(ExactGP): Optional priors over noise mean function noise_lengthscale_prior_dist: Optional custom prior distribution over noise kernel lengthscale. Defaults to LogNormal(0, 1). - - Examples: + Examples: Use two different kernels with default priors for main and noise processes @@ -162,7 +161,7 @@ def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]: return {"k_noise_length": noise_length, "k_noise_scale": noise_scale} def get_mvn_posterior( - self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *arg, **kwargs + self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior @@ -186,17 +185,6 @@ def get_mvn_posterior( mean += self.mean_fn(*args).squeeze() # Noise GP part - predicted_log_var = self.get_noise_mvn_posterior(X_new, params, **kwargs) - predicted_var = jnp.exp(predicted_log_var) - - # Return the main GP's predictive mean and combined (main + noise) covariance matrix - return mean, cov + jnp.diag(predicted_var) - - def get_noise_mvn_posterior(self, - X_new: jnp.ndarray, - params: Dict[str, jnp.ndarray], - **kwargs - ) -> Tuple[jnp.ndarray, jnp.ndarray]: # Compute noise kernel matrices k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0) k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs) @@ -210,30 +198,22 @@ def get_noise_mvn_posterior(self, if self.noise_mean_fn is not None: args = [X_new, params] if self.noise_mean_fn_prior else [X_new] predicted_log_var += jnp.log(self.noise_mean_fn(*args)).squeeze() - - #k_pp_noise = self.noise_kernel(X_new, X_new, params, 0, **kwargs) - #cov_noise = k_pp_noise - jnp.matmul(k_pX_noise, jnp.matmul(K_xx_noise_inv, jnp.transpose(k_pX_noise))) - - return predicted_log_var - - # def get_data_var_samples(self): - # """Returns samples with inferred (training) data variance - aka noise""" - # samples = self.mcmc.get_samples() - # log_var = samples["log_var"] - # if self.noise_mean_fn is not None: - # if self.noise_mean_fn_prior is not None: - # mean_ = jax.vmap(self.noise_mean_fn, in_axes=(None, 0))(self.X_train.squeeze(), samples) - # else: - # mean_ = self.noise_mean_fn(self.X_train.squeeze()) - # log_var += jnp.log(mean_) - # return jnp.exp(samples["log_var"]) - - def get_data_var_samples(self, **kwargs): + predicted_noise_variance = jnp.exp(predicted_log_var) + + # Return the main GP's predictive mean and combined (main + noise) covariance matrix + return mean, cov + jnp.diag(predicted_noise_variance) + + def get_data_var_samples(self): """Returns samples with inferred (training) data variance - aka noise""" - predict_ = lambda p: self.get_noise_mvn_posterior(self.X_train, p, **kwargs) samples = self.mcmc.get_samples() - predicted_log_var = jax.vmap(predict_)(samples) - return jnp.exp(predicted_log_var) + log_var = samples["log_var"] + if self.noise_mean_fn is not None: + if self.noise_mean_fn_prior is not None: + mean_ = jax.vmap(self.noise_mean_fn, in_axes=(None, 0))(self.X_train.squeeze(), samples) + else: + mean_ = self.noise_mean_fn(self.X_train.squeeze()) + log_var += jnp.log(mean_) + return jnp.exp(samples["log_var"]) def _print_summary(self): samples = self.get_samples(1)