Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Acquisition function optimization #65

Merged
merged 7 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gpax/acquisition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .acquisition import UCB, EI, POI, UE, Thompson, KG
from .batch_acquisition import qEI, qPOI, qUCB, qKG
from .optimize import optimize_acq

__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"]
__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG", "optimize_acq"]
97 changes: 97 additions & 0 deletions gpax/acquisition/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
optimize.py
==============

Optimize continuous acquisition functions

Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Type, Callable, Union, List, Tuple

import jax.numpy as jnp
import jax.random as jra
import numpy as onp

from ..models.gp import ExactGP


def optimize_acq(rng_key: jnp.ndarray,
model: Type[ExactGP],
acq_fn: Callable,
num_initial_guesses: int,
lower_bound: Union[List, Tuple, float, onp.ndarray, jnp.ndarray],
upper_bound: Union[List, Tuple, float, onp.ndarray, jnp.ndarray],
**kwargs) -> jnp.ndarray:
"""
Optimizes an acquisition function for a given Gaussian Process model using the JAXopt library.

This function finds the point that maximizes the acquisition function within the specified bounds.
It uses L-BFGS-B algorithm through ScipyBoundedMinimize from JAXopt.

Args:
rng_key: A JAX random key for stochastic processes.
model: The Gaussian Process model to be used.
acq_fn: The acquisition function to be maximized.
num_initial_guesses: Number of random initial guesses for the optimization.
lower_bound: Lower bounds for the optimization.
upper_bound: Upper bounds for the optimization.
**kwargs: Additional keyword arguments to be passed to the acquisition function.

Returns:
Parameter(s) that maximize the acquisition function within the specified bounds.

Note:
Ensure JAXopt is installed to use this function (`pip install jaxopt`).
The acquisition function is minimized using its negative value to find the maximum.

Examples:

Optimize EI given a trained GP model for 1D problem

>>> acq_fn = gpax.acquisition.EI
>>> num_initial_guesses = 10
>>> lower_bound = -2.0
>>> upper_bound = 2.0
>>> x_next = gpax.acquisition.optimize_acq(
>>> rng_key, gp_model, acq_fn,
>>> num_initial_guesses, lower_bound, upper_bound,
>>> maximize=False, noiseless=True)
"""

try:
import jaxopt # noqa: F401
except ImportError as e:
raise ImportError(
"You need to install `jaxopt` to be able to use this feature. "
"It can be installed with `pip install jaxopt`."
) from e

def acq(x):
x = jnp.array([x])
x = x[None] if x.ndim == 0 else x
obj = -acq_fn(rng_key, model, x, **kwargs)
return jnp.reshape(obj, ())

lower_bound = ensure_array(lower_bound)
upper_bound = ensure_array(upper_bound)

initial_guesses = jra.uniform(
rng_key, shape=(num_initial_guesses, lower_bound.shape[0]),
minval=lower_bound, maxval=upper_bound)
initial_acq_vals = acq_fn(rng_key, model, initial_guesses, **kwargs)
best_initial_guess = initial_guesses[initial_acq_vals.argmax()].squeeze()

minimizer = jaxopt.ScipyBoundedMinimize(fun=acq, method='l-bfgs-b')
result = minimizer.run(best_initial_guess, bounds=(lower_bound, upper_bound))

return result.params


def ensure_array(x):
if not isinstance(x, jnp.ndarray):
if isinstance(x, (list, tuple, float, onp.ndarray)):
x = jnp.array([x]) if isinstance(x, float) else jnp.array(x)
else:
raise TypeError(f"Expected input to be a list, tuple, float, or jnp.ndarray, got {type(x)} instead.")
return x
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ matplotlib>=3.1.1
jax>=0.2.21
numpyro>=0.8.0
dm-haiku>=0.0.5
jaxopt>0.8.0
typing-extensions>=4.4.0
2 changes: 0 additions & 2 deletions tests/test_acq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as onp
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpy.testing import assert_equal, assert_

sys.path.insert(0, "../gpax/")
Expand Down
36 changes: 36 additions & 0 deletions tests/test_optimize_acq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
from numpy.testing import assert_

sys.path.insert(0, "../gpax/")

from gpax.models.gp import ExactGP
from gpax.acquisition.optimize import optimize_acq
from gpax.acquisition.acquisition import UCB, EI
from gpax.utils import get_keys


def get_inputs():
X = onp.random.uniform(-2, 2, size=(4,))
y = X**3
return X, y


@pytest.mark.parametrize("acq_fn", [UCB, EI])
def test_optimize_acq(acq_fn):
lower_bound = -2.0
upper_bound = 2.0
num_initial_guesses = 3
key1, key2 = get_keys()
X, y = get_inputs()
model = ExactGP(1, 'RBF')
model.fit(key1, X, y, num_warmup=50, num_samples=50)
x_next = optimize_acq(
key2, model, acq_fn, num_initial_guesses, lower_bound, upper_bound)
assert_(isinstance(x_next, jnp.ndarray))




Loading