-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a warning if data is not normalized
- Loading branch information
1 parent
f79ad81
commit 59f28d3
Showing
1 changed file
with
13 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
Created by Maxim Ziatdinov (email: [email protected]) | ||
""" | ||
|
||
import warnings | ||
from typing import Callable, Dict, Optional, Tuple, Union | ||
|
||
import jax.numpy as jnp | ||
|
@@ -19,7 +20,9 @@ | |
|
||
|
||
class siGP(ExactGP): | ||
|
||
""" | ||
Gaussian process with uncertain inputs | ||
""" | ||
def __init__(self, | ||
input_dim: int, | ||
kernel: Union[str, kernel_fn_type], | ||
|
@@ -39,6 +42,15 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: | |
""" | ||
Gaussian process model for uncertain (stochastic) inputs | ||
""" | ||
if not (X.max() == 1 and X.min() == 0): | ||
warnings.warn( | ||
"The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is " | ||
"normalized to (0, 1), which is not be the case for your data. Therefore, the default prior " | ||
"may not be optimal for your case. Consider passing custom prior for sigma_x. For example, " | ||
"`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly " | ||
"or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper", | ||
UserWarning, | ||
) | ||
# Initialize mean function at zeros | ||
f_loc = jnp.zeros(X.shape[0]) | ||
|
||
|