Skip to content

Commit

Permalink
Added ES optimization initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
hvarfner committed Feb 21, 2025
1 parent ae56adf commit 520aad7
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 106 deletions.
300 changes: 194 additions & 106 deletions botorch/optim/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from botorch.acquisition import analytic, monte_carlo, multi_objective
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import (
_get_value_function,
qKnowledgeGradient,
Expand Down Expand Up @@ -471,6 +472,90 @@ def gen_batch_initial_conditions(
return batch_initial_conditions


def gen_optimal_input_initial_conditions(
acq_function: AcquisitionFunction,
bounds: Tensor,
q: int,
num_restarts: int,
raw_samples: int,
fixed_features: dict[int, float] | None = None,
options: dict[str, bool | float | int] | None = None,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
):
device = bounds.device
if not hasattr(acq_function, "optimal_inputs"):
raise AttributeError(
"gen_optimal_input_initial_conditions can only be used with "
"an AcquisitionFunction that has an optimal_inputs attribute."
)
frac_random: float = options.get("frac_random", 0.0)
if not 0 <= frac_random <= 1:
raise ValueError(
f"frac_random must take on values in (0,1). Value: {frac_random}"
)

batch_limit = options.get("batch_limit")
num_optima = acq_function.optimal_inputs.shape[:-1].numel()
suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
num_random = round(raw_samples * frac_random)
if num_random > 0:
X_rnd = sample_q_batches_from_polytope(
n=num_random,
q=q,
bounds=bounds,
n_burnin=options.get("n_burnin", 10000),
n_thinning=options.get("n_thinning", 32),
equality_constraints=equality_constraints,
inequality_constraints=inequality_constraints,
)
X = torch.cat((X, X_rnd))

if num_random < raw_samples:
X_perturbed = sample_points_around_best(
acq_function=acq_function,
n_discrete_points=q * (raw_samples - num_random),
sigma=options.get("sample_around_best_sigma", 1e-2),
bounds=bounds,
best_X=suggestions,
)
X_perturbed = X_perturbed.view(
raw_samples - num_random, q, bounds.shape[-1]
).cpu()
X = torch.cat((X, X_perturbed))

if options.get("sample_around_best", False):
X_best = sample_points_around_best(
acq_function=acq_function,
n_discrete_points=q * raw_samples,
sigma=options.get("sample_around_best_sigma", 1e-3),
bounds=bounds,
)
X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
X = torch.cat((X, X_best))

with torch.no_grad():
if batch_limit is None:
batch_limit = X.shape[0]
# Evaluate the acquisition function on `X_rnd` using `batch_limit`
# sized chunks.
acq_vals = torch.cat(
[
acq_function(x_.to(device=device)).cpu()
for x_ in X.split(split_size=batch_limit, dim=0)
],
dim=0,
)
idx = boltzmann_sample(
function_values=acq_vals,
num_samples=num_restarts,
eta=options.get("eta", 2.0),
)
# set the respective initial conditions to the sampled optimizers
return X[idx]


def gen_one_shot_kg_initial_conditions(
acq_function: qKnowledgeGradient,
bounds: Tensor,
Expand Down Expand Up @@ -605,59 +690,59 @@ def gen_one_shot_hvkg_initial_conditions(
) -> Tensor | None:
r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.
This function generates initial conditions for optimizing one-shot HVKG using
the hypervolume maximizing set (of fixed size) under the posterior mean.
Intutively, the hypervolume maximizing set of the fantasized posterior mean
will often be close to a hypervolume maximizing set under the current posterior
mean. This function uses that fact to generate the initial conditions
for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
options) of the restarts are generated by learning the hypervolume maximizing sets
under the current posterior mean, where each hypervolume maximizing set is
obtained from maximizing the hypervolume from a different starting point. Given
a hypervolume maximizing set, the `q` candidate points are selected using to the
standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
as well as all `q` candidate points are chosen according to the standard
initialization strategy in `gen_batch_initial_conditions`.
Args:
acq_function: The qKnowledgeGradient instance to be optimized.
bounds: A `2 x d` tensor of lower and upper bounds for each column of
task features.
q: The number of candidates to consider.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of raw samples to consider in the initialization
heuristic.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
options: Options for initial condition generation. These contain all
settings for the standard heuristic initialization from
`gen_batch_initial_conditions`. In addition, they contain
`frac_random` (the fraction of fully random fantasy points),
`num_inner_restarts` and `raw_inner_samples` (the number of random
restarts and raw samples for solving the posterior objective
maximization problem, respectively) and `eta` (temperature parameter
for sampling heuristic from posterior objective maximizers).
inequality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
equality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
Returns:
A `num_restarts x q' x d` tensor that can be used as initial conditions
for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
of points (candidate points plus fantasy points).
Example:
>>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
>>> bounds = torch.tensor([[0., 0.], [1., 1.]])
>>> Xinit = gen_one_shot_hvkg_initial_conditions(
>>> qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
>>> options={"frac_random": 0.25},
>>> )
This function generates initial conditions for optimizing one-shot HVKG using
the hypervolume maximizing set (of fixed size) under the posterior mean.
Intutively, the hypervolume maximizing set of the fantasized posterior mean
will often be close to a hypervolume maximizing set under the current posterior
mean. This function uses that fact to generate the initial conditions
for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
options) of the restarts are generated by learning the hypervolume maximizing sets
under the current posterior mean, where each hypervolume maximizing set is
obtained from maximizing the hypervolume from a different starting point. Given
a hypervolume maximizing set, the `q` candidate points are selected using to the
standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
as well as all `q` candidate points are chosen according to the standard
initialization strategy in `gen_batch_initial_conditions`.
Args:
acq_function: The qKnowledgeGradient instance to be optimized.
bounds: A `2 x d` tensor of lower and upper bounds for each column of
task features.
q: The number of candidates to consider.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of raw samples to consider in the initialization
heuristic.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
options: Options for initial condition generation. These contain all
settings for the standard heuristic initialization from
`gen_batch_initial_conditions`. In addition, they contain
`frac_random` (the fraction of fully random fantasy points),
`num_inner_restarts` and `raw_inner_samples` (the number of random
restarts and raw samples for solving the posterior objective
maximization problem, respectively) and `eta` (temperature parameter
for sampling heuristic from posterior objective maximizers).
inequality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
equality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
Returns:
A `num_restarts x q' x d` tensor that can be used as initial conditions
for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
of points (candidate points plus fantasy points).
gen_batch_initial_conditions Example:
>>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
>>> bounds = torch.tensor([[0., 0.], [1., 1.]])
>>> Xinit = gen_one_shot_hvkg_initial_conditions(
>>> qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
>>> options={"frac_random": 0.25},
>>> )
"""
from botorch.optim.optimize import optimize_acqf

Expand Down Expand Up @@ -1139,6 +1224,7 @@ def sample_points_around_best(
best_pct: float = 5.0,
subset_sigma: float = 1e-1,
prob_perturb: float | None = None,
best_X: Tensor | None = None,
) -> Tensor | None:
r"""Find best points and sample nearby points.
Expand All @@ -1157,60 +1243,62 @@ def sample_points_around_best(
An optional `n_discrete_points x d`-dim tensor containing the
sampled points. This is None if no baseline points are found.
"""
X = get_X_baseline(acq_function=acq_function)
if X is None:
return
with torch.no_grad():
try:
posterior = acq_function.model.posterior(X)
except AttributeError:
warnings.warn(
"Failed to sample around previous best points.",
BotorchWarning,
stacklevel=3,
)
if best_X is None:
X = get_X_baseline(acq_function=acq_function)
if X is None:
return
mean = posterior.mean
while mean.ndim > 2:
# take average over batch dims
mean = mean.mean(dim=0)
try:
f_pred = acq_function.objective(mean)
# Some acquisition functions do not have an objective
# and for some acquisition functions the objective is None
except (AttributeError, TypeError):
f_pred = mean
if hasattr(acq_function, "maximize"):
# make sure that the optimiztaion direction is set properly
if not acq_function.maximize:
f_pred = -f_pred
try:
# handle constraints for EHVI-based acquisition functions
constraints = acq_function.constraints
if constraints is not None:
neg_violation = -torch.stack(
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
).sum(dim=-1)
feas = neg_violation == 0
if feas.any():
f_pred[~feas] = float("-inf")
else:
# set objective equal to negative violation
f_pred = neg_violation
except AttributeError:
pass
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
# multi-objective
# find pareto set
is_pareto = is_non_dominated(f_pred)
best_X = X[is_pareto]
else:
if f_pred.shape[-1] == 1:
f_pred = f_pred.squeeze(-1)
n_best = max(1, round(X.shape[0] * best_pct / 100))
# the view() is to ensure that best_idcs is not a scalar tensor
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
best_X = X[best_idcs]
with torch.no_grad():
try:
posterior = acq_function.model.posterior(X)
except AttributeError:
warnings.warn(
"Failed to sample around previous best points.",
BotorchWarning,
stacklevel=3,
)
return
mean = posterior.mean
while mean.ndim > 2:
# take average over batch dims
mean = mean.mean(dim=0)
try:
f_pred = acq_function.objective(mean)
# Some acquisition functions do not have an objective
# and for some acquisition functions the objective is None
except (AttributeError, TypeError):
f_pred = mean
if hasattr(acq_function, "maximize"):
# make sure that the optimiztaion direction is set properly
if not acq_function.maximize:
f_pred = -f_pred
try:
# handle constraints for EHVI-based acquisition functions
constraints = acq_function.constraints
if constraints is not None:
neg_violation = -torch.stack(
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
).sum(dim=-1)
feas = neg_violation == 0
if feas.any():
f_pred[~feas] = float("-inf")
else:
# set objective equal to negative violation
f_pred = neg_violation
except AttributeError:
pass
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
# multi-objective
# find pareto set
is_pareto = is_non_dominated(f_pred)
best_X = X[is_pareto]
else:
if f_pred.shape[-1] == 1:
f_pred = f_pred.squeeze(-1)
n_best = max(1, round(X.shape[0] * best_pct / 100))
# the view() is to ensure that best_idcs is not a scalar tensor
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
best_X = X[best_idcs]

use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
n_trunc_normal_points = (
n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
Expand Down
7 changes: 7 additions & 0 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
AcquisitionFunction,
OneShotAcquisitionFunction,
)
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
qHypervolumeKnowledgeGradient,
)
from botorch.acquisition.predictive_entropy_search import qPredictiveEntropySearch
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.exceptions.errors import CandidateGenerationError
from botorch.exceptions.warnings import OptimizationWarning
Expand All @@ -33,6 +35,7 @@
gen_batch_initial_conditions,
gen_one_shot_hvkg_initial_conditions,
gen_one_shot_kg_initial_conditions,
gen_optimal_input_initial_conditions,
TGenInitialConditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
Expand Down Expand Up @@ -174,6 +177,10 @@ def get_ic_generator(self) -> TGenInitialConditions:
return gen_one_shot_kg_initial_conditions
elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
return gen_one_shot_hvkg_initial_conditions
elif isinstance(
self.acq_function, (qJointEntropySearch, qPredictiveEntropySearch)
):
return gen_optimal_input_initial_conditions
return gen_batch_initial_conditions


Expand Down
Loading

0 comments on commit 520aad7

Please sign in to comment.