diff --git a/gpax/models/sigp.py b/gpax/models/sigp.py index e76382c..a029577 100644 --- a/gpax/models/sigp.py +++ b/gpax/models/sigp.py @@ -13,6 +13,8 @@ import numpyro import numpyro.distributions as dist +from . import ExactGP + kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] @@ -27,9 +29,8 @@ def __init__(self, noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior_dist: Optional[dist.Distribution] = None, lengthscale_prior_dist: Optional[dist.Distribution] = None, - sigma_x_prior_dist: Optional[dist.Distribution] = None, - - ) -> None: + sigma_x_prior_dist: Optional[dist.Distribution] = None + ) -> None: args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, noise_prior_dist, lengthscale_prior_dist) super(siGP, self).__init__(*args) self.sigma_x_prior_dist = sigma_x_prior_dist @@ -71,7 +72,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: def _sample_x(self, X): if self.sigma_x_prior_dist is not None: - sigma_x_dist = self.sigma_x_prior_dist + sigma_x_dist = self.sigma_x_prior_dist else: sigma_x_dist = dist.HalfNormal(1) sigma_x = numpyro.sample("sigma_x", sigma_x_dist) @@ -125,4 +126,4 @@ def _predict( def _print_summary(self): samples = self.get_samples(1) - numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k}) \ No newline at end of file + numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})