Add option to choose between BNN and regular NN
ziatdinovmax authored Feb 2, 2024
1 parent 8e97fb8 commit 672eae8
Showing 1 changed file with 26 additions and 12 deletions.
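In short, this commit threads a new nn_prior flag through viDKL: True keeps the Bayesian NN (priors over the network weights, MAP point estimates), while False (the new default) trains a deterministic NN by maximum likelihood. A minimal construction sketch, not part of the commit itself:

    from gpax.models.vidkl import viDKL

    # Bayesian NN feature extractor: priors over weights, MAP estimates.
    dkl_bnn = viDKL(input_dim=1, z_dim=2, nn_prior=True)

    # Regular deterministic NN feature extractor, trained by MLE (default).
    dkl_nn = viDKL(input_dim=1, z_dim=2, nn_prior=False)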
38 changes: 26 additions & 12 deletions gpax/models/vidkl.py
@@ -16,7 +16,7 @@
 import numpyro.distributions as dist
 from numpyro.infer import SVI, Trace_ELBO
 from numpyro.infer.autoguide import AutoDelta, AutoNormal
-from numpyro.contrib.module import random_haiku_module
+from numpyro.contrib.module import random_haiku_module, haiku_module
 from jax import jit
 import haiku as hk

@@ -68,7 +68,7 @@ class viDKL(ExactGP):

     def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
                  kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
+                 nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, nn_prior: bool = False,
                  latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
                  guide: str = 'delta', **kwargs
                  ) -> None:
@@ -77,6 +77,7 @@ def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: st
             raise NotImplementedError("Select guide between 'delta' and 'normal'")
         nn_module = nn if nn else MLP
         self.nn_module = hk.transform(lambda x: nn_module(z_dim)(x))
+        self.nn_prior = nn_prior
         self.kernel_dim = z_dim
         self.data_dim = (input_dim,) if isinstance(input_dim, int) else input_dim
         self.latent_prior = latent_prior
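For context, nn is expected to be a Haiku module class whose constructor takes the latent dimension, since viDKL wraps it as hk.transform(lambda x: nn_module(z_dim)(x)). A sketch of a custom extractor (CustomMLP is a hypothetical example, not part of gpax):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    class CustomMLP(hk.Module):
        """Hypothetical drop-in for the default MLP feature extractor."""
        def __init__(self, z_dim: int, name: str = None):
            super().__init__(name=name)
            self.z_dim = z_dim

        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            x = jax.nn.tanh(hk.Linear(64)(x))
            return hk.Linear(self.z_dim)(x)  # latent features for the GP kernel

    # Passed as: viDKL(input_dim=8, z_dim=2, nn=CustomMLP)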
@@ -87,9 +88,13 @@ def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: st
     def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
         """DKL probabilistic model"""
         # NN part
-        feature_extractor = random_haiku_module(
-            "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim),
-            prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal()))
+        if self.nn_prior: # MAP
+            feature_extractor = random_haiku_module(
+                "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim),
+                prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal()))
+        else: # MLE
+            feature_extractor = haiku_module(
+                "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim))
         z = feature_extractor(X)
         if self.latent_prior: # Sample latent variable
             z = self.latent_prior(z)
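The two branches differ only in how the Haiku network's weights enter the probabilistic model. A self-contained sketch of the contrast, mirroring the numpyro contrib.module calls above (hk.nets.MLP here is a stand-in for gpax's own module):

    import haiku as hk
    import numpyro.distributions as dist
    from numpyro.contrib.module import haiku_module, random_haiku_module

    net = hk.transform(lambda x: hk.nets.MLP([32, 2])(x))

    def model_bnn(X):
        # random_haiku_module lifts every weight into a numpyro sample site,
        # here with Cauchy priors on biases ("b...") and Normal on the rest.
        f = random_haiku_module(
            "feature_extractor", net, input_shape=(1, X.shape[-1]),
            prior=lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal())
        return f(X)

    def model_nn(X):
        # haiku_module instead registers the weights as one learnable param
        # site, so SVI optimizes them directly (maximum likelihood).
        f = haiku_module("feature_extractor", net, input_shape=(1, X.shape[-1]))
        return f(X)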
@@ -134,14 +139,23 @@ def single_fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             y=y,
             **kwargs
         )
+
         params, _, losses = svi.run(rng_key, num_steps, progress_bar=progress_bar)
-        # Get DKL parameters from the guide
-        params_map = svi.guide.median(params)
-        # Get NN weights
-        nn_params = get_haiku_dict(params_map)
-        # Get GP kernel hyperparameters
-        kernel_params = {k: v for (k, v) in params_map.items()
-                         if not k.startswith("feature_extractor")}
+
+        # Get DKL trained parameters from the guide
+        if self.nn_prior: # MAP
+            params_map = svi.guide.median(params)
+            # Get NN weights
+            nn_params = get_haiku_dict(params_map)
+            # Get GP kernel hyperparameters
+            kernel_params = {k: v for (k, v) in params_map.items()
+                             if not k.startswith("feature_extractor")}
+        else: # MLE
+            # Get NN weights
+            nn_params = params["feature_extractor$params"]
+            # Get kernel parameters from the guide
+            kernel_params = svi.guide.median(params)
+
         return nn_params, kernel_params, losses

     def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
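The returned nn_params plug straight back into the transformed Haiku module, e.g. to embed new inputs into the latent space the GP kernel operates on. A hedged usage sketch (single_fit's full signature is cut off above, so num_steps and progress_bar as keyword arguments are assumptions; the data is a toy placeholder, and rng=None assumes no stochastic layers in the net):

    import jax.numpy as jnp
    from jax import random
    from gpax.models.vidkl import viDKL

    X = jnp.linspace(0, 1, 50)[:, None]       # toy training inputs
    y = jnp.sin(10.0 * X).squeeze()           # toy training targets
    X_new = jnp.linspace(1, 2, 20)[:, None]   # toy test inputs

    model = viDKL(input_dim=1, z_dim=2, nn_prior=False)  # MLE path
    nn_params, kernel_params, losses = model.single_fit(
        random.PRNGKey(0), X, y, num_steps=500, progress_bar=False)

    # hk.transform'd modules are applied as (params, rng, inputs).
    z_new = model.nn_module.apply(nn_params, None, X_new)  # shape (20, z_dim)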
