From f79ad812466ac262babb5dae2734380dec7899b8 Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Wed, 3 Jan 2024 16:44:52 -0800
Subject: [PATCH] Fix imports

---
 gpax/models/sigp.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/gpax/models/sigp.py b/gpax/models/sigp.py
index e76382c..a029577 100644
--- a/gpax/models/sigp.py
+++ b/gpax/models/sigp.py
@@ -13,6 +13,8 @@
 import numpyro
 import numpyro.distributions as dist
 
+from . import ExactGP
+
 kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
 
 
@@ -27,9 +29,8 @@ def __init__(self,
                  noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  noise_prior_dist: Optional[dist.Distribution] = None,
                  lengthscale_prior_dist: Optional[dist.Distribution] = None,
-                 sigma_x_prior_dist: Optional[dist.Distribution] = None,
-
-    ) -> None:
+                 sigma_x_prior_dist: Optional[dist.Distribution] = None
+                 ) -> None:
         args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, noise_prior_dist, lengthscale_prior_dist)
         super(siGP, self).__init__(*args)
         self.sigma_x_prior_dist = sigma_x_prior_dist
@@ -71,7 +72,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
 
     def _sample_x(self, X):
         if self.sigma_x_prior_dist is not None:
-            sigma_x_dist = self.sigma_x_prior_dist 
+            sigma_x_dist = self.sigma_x_prior_dist
         else:
             sigma_x_dist = dist.HalfNormal(1)
         sigma_x = numpyro.sample("sigma_x", sigma_x_dist)
@@ -125,4 +126,4 @@ def _predict(
 
     def _print_summary(self):
         samples = self.get_samples(1)
-        numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})
\ No newline at end of file
+        numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})