Update hskgp.py
ziatdinovmax authored Nov 16, 2023

1 parent 3e3e893 commit dffaaab
Showing 1 changed file with 16 additions and 36 deletions.
gpax/models/hskgp.py: 52 changes (16 additions, 36 deletions)
@@ -48,8 +48,7 @@ class VarNoiseGP(ExactGP):
             Optional priors over noise mean function
         noise_lengthscale_prior_dist:
             Optional custom prior distribution over noise kernel lengthscale. Defaults to LogNormal(0, 1).
-    Examples:
 
     Examples:
 
         Use two different kernels with default priors for main and noise processes
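To make the docstring's "two different kernels" example concrete, here is a rough usage sketch (not part of the diff). The constructor arguments (input_dim, kernel, noise_kernel) and the get_keys/fit/predict calls follow common gpax conventions for ExactGP-style models; treat them as assumptions rather than the exact VarNoiseGP API, and assume VarNoiseGP is exposed at the package top level.

    import gpax
    import numpy as np

    # Hypothetical toy data; shapes follow the usual (N, D) inputs / (N,) targets convention
    X_train = np.linspace(0, 1, 20)[:, None]
    y_train = np.sin(10 * X_train).squeeze() + 0.1 * np.random.randn(20)

    rng_key, rng_key_predict = gpax.utils.get_keys()       # two PRNG keys, one for fit and one for predict
    model = gpax.VarNoiseGP(input_dim=1, kernel='RBF',     # main-process kernel
                            noise_kernel='Matern')         # noise-process kernel (assumed argument name)
    model.fit(rng_key, X_train, y_train)                   # MCMC inference over both GPs
    y_mean, y_sampled = model.predict(rng_key_predict, X_train)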
@@ -162,7 +161,7 @@ def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]:
         return {"k_noise_length": noise_length, "k_noise_scale": noise_scale}
 
     def get_mvn_posterior(
-        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *arg, **kwargs
+        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
         Returns parameters (mean and cov) of multivariate normal posterior
@@ -186,17 +185,6 @@ def get_mvn_posterior(
         mean += self.mean_fn(*args).squeeze()
 
         # Noise GP part
-        predicted_log_var = self.get_noise_mvn_posterior(X_new, params, **kwargs)
-        predicted_var = jnp.exp(predicted_log_var)
-
-        # Return the main GP's predictive mean and combined (main + noise) covariance matrix
-        return mean, cov + jnp.diag(predicted_var)
-
-    def get_noise_mvn_posterior(self,
-                                X_new: jnp.ndarray,
-                                params: Dict[str, jnp.ndarray],
-                                **kwargs
-                                ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Compute noise kernel matrices
         k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0)
         k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs)
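For orientation (not part of the diff): the noise GP models the log of the observation variance, so its posterior mean is exponentiated and added to the diagonal of the main GP's predictive covariance. A minimal self-contained sketch with toy arrays:

    import jax.numpy as jnp

    # Toy stand-ins for the quantities assembled in get_mvn_posterior
    cov = 0.5 * jnp.eye(3)                                      # main GP predictive covariance (3 test points)
    predicted_log_var = jnp.array([-2.0, -1.0, 0.0])            # noise GP posterior mean of the log-variance
    combined_cov = cov + jnp.diag(jnp.exp(predicted_log_var))   # heteroskedastic predictive covariance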
@@ -210,30 +198,22 @@ def get_noise_mvn_posterior(self,
         if self.noise_mean_fn is not None:
             args = [X_new, params] if self.noise_mean_fn_prior else [X_new]
             predicted_log_var += jnp.log(self.noise_mean_fn(*args)).squeeze()
 
-        #k_pp_noise = self.noise_kernel(X_new, X_new, params, 0, **kwargs)
-        #cov_noise = k_pp_noise - jnp.matmul(k_pX_noise, jnp.matmul(K_xx_noise_inv, jnp.transpose(k_pX_noise)))
-
-        return predicted_log_var
-
-    # def get_data_var_samples(self):
-    #     """Returns samples with inferred (training) data variance - aka noise"""
-    #     samples = self.mcmc.get_samples()
-    #     log_var = samples["log_var"]
-    #     if self.noise_mean_fn is not None:
-    #         if self.noise_mean_fn_prior is not None:
-    #             mean_ = jax.vmap(self.noise_mean_fn, in_axes=(None, 0))(self.X_train.squeeze(), samples)
-    #         else:
-    #             mean_ = self.noise_mean_fn(self.X_train.squeeze())
-    #         log_var += jnp.log(mean_)
-    #     return jnp.exp(samples["log_var"])
-
-    def get_data_var_samples(self, **kwargs):
+        predicted_noise_variance = jnp.exp(predicted_log_var)
+
+        # Return the main GP's predictive mean and combined (main + noise) covariance matrix
+        return mean, cov + jnp.diag(predicted_noise_variance)
+
+    def get_data_var_samples(self):
         """Returns samples with inferred (training) data variance - aka noise"""
-        predict_ = lambda p: self.get_noise_mvn_posterior(self.X_train, p, **kwargs)
         samples = self.mcmc.get_samples()
-        predicted_log_var = jax.vmap(predict_)(samples)
-        return jnp.exp(predicted_log_var)
+        log_var = samples["log_var"]
+        if self.noise_mean_fn is not None:
+            if self.noise_mean_fn_prior is not None:
+                mean_ = jax.vmap(self.noise_mean_fn, in_axes=(None, 0))(self.X_train.squeeze(), samples)
+            else:
+                mean_ = self.noise_mean_fn(self.X_train.squeeze())
+            log_var += jnp.log(mean_)
+        return jnp.exp(log_var)
 
     def _print_summary(self):
         samples = self.get_samples(1)
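After fitting, the inferred per-point noise can be inspected via the rewritten get_data_var_samples. A short post-fit sketch (model refers to the hypothetical instance from the earlier example; the shape comment and the averaging over MCMC samples are assumptions about how one might summarize the output):

    noise_var_samples = model.get_data_var_samples()   # roughly (num_mcmc_samples, num_train_points)
    noise_var_mean = noise_var_samples.mean(0)          # posterior-mean noise variance per training point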
