Skip to content

Commit

Permalink
Allow passing jitter as kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Feb 7, 2024
1 parent 01fcbd6 commit e2d8c40
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions gpax/models/sparse_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def __init__(self, input_dim: int, kernel: str,
super(viSparseGP, self).__init__(*args)
self.Xu = None

def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, **kwargs: float) -> None:
def model(self,
X: jnp.ndarray,
y: jnp.ndarray = None,
Xu: jnp.ndarray = None,
**kwargs: float) -> None:
if Xu is not None:
Xu = numpyro.param("Xu", Xu)
# Initialize mean function at zeros
Expand All @@ -77,8 +81,8 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, *
if self.mean_fn_prior is not None:
args += [self.mean_fn_prior()]
f_loc += self.mean_fn(*args).squeeze()
# Xompute kernel between inducing points
Kuu = self.kernel(Xu, Xu, kernel_params)
# Compute kernel between inducing points
Kuu = self.kernel(Xu, Xu, kernel_params, **kwargs)
# Cholesky decomposition
Luu = cholesky(Kuu).T
# Compute kernel between inducing and training points
Expand Down Expand Up @@ -177,7 +181,7 @@ def get_mvn_posterior(
y_residual -= self.mean_fn(*args).squeeze()

# Compute self- and cross-covariance matrices
Kuu = self.kernel(self.Xu, self.Xu, params)
Kuu = self.kernel(self.Xu, self.Xu, params, **kwargs)
Luu = cholesky(Kuu, lower=True)
Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0)

Expand All @@ -199,7 +203,7 @@ def get_mvn_posterior(
Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:]
mean = (Linv_W_Dinv_y.T @ Linv_Ws).squeeze()

Kss = self.kernel(X_new, X_new, params, noise_p)
Kss = self.kernel(X_new, X_new, params, noise_p, **kwargs)
Qss = Ws.T @ Ws
cov = Kss - Qss + Linv_Ws.T @ Linv_Ws

Expand Down

0 comments on commit e2d8c40

Please sign in to comment.