Skip to content

Commit

Permalink
Add a warning if data is not normalized
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Jan 4, 2024
1 parent f79ad81 commit 59f28d3
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion gpax/models/sigp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Created by Maxim Ziatdinov (email: [email protected])
"""

import warnings
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
Expand All @@ -19,7 +20,9 @@


class siGP(ExactGP):

"""
Gaussian process with uncertain inputs
"""
def __init__(self,
input_dim: int,
kernel: Union[str, kernel_fn_type],
Expand All @@ -39,6 +42,15 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
"""
Gaussian process model for uncertain (stochastic) inputs
"""
if not (X.max() == 1 and X.min() == 0):
warnings.warn(
"The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is "
"normalized to (0, 1), which is not be the case for your data. Therefore, the default prior "
"may not be optimal for your case. Consider passing custom prior for sigma_x. For example, "
"`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly "
"or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper",
UserWarning,
)
# Initialize mean function at zeros
f_loc = jnp.zeros(X.shape[0])

Expand Down

0 comments on commit 59f28d3

Please sign in to comment.