Skip to content

Commit

Permalink
Fix imports
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Jan 4, 2024
1 parent fb3c72f commit f79ad81
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions gpax/models/sigp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import numpyro
import numpyro.distributions as dist

from . import ExactGP

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]


Expand All @@ -27,9 +29,8 @@ def __init__(self,
noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_prior_dist: Optional[dist.Distribution] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None,
sigma_x_prior_dist: Optional[dist.Distribution] = None,

) -> None:
sigma_x_prior_dist: Optional[dist.Distribution] = None
) -> None:
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, noise_prior_dist, lengthscale_prior_dist)
super(siGP, self).__init__(*args)
self.sigma_x_prior_dist = sigma_x_prior_dist
Expand Down Expand Up @@ -71,7 +72,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:

def _sample_x(self, X):
if self.sigma_x_prior_dist is not None:
sigma_x_dist = self.sigma_x_prior_dist
sigma_x_dist = self.sigma_x_prior_dist
else:
sigma_x_dist = dist.HalfNormal(1)
sigma_x = numpyro.sample("sigma_x", sigma_x_dist)
Expand Down Expand Up @@ -125,4 +126,4 @@ def _predict(

def _print_summary(self):
samples = self.get_samples(1)
numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})
numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})

0 comments on commit f79ad81

Please sign in to comment.