Skip to content

Commit

Permalink
Improvements of get_optimal_samples and optimize_posterior_samples to
Browse files Browse the repository at this point in the history
    improve performance and runtime of PES/JES
  • Loading branch information
hvarfner committed Feb 21, 2025
1 parent 4dbe092 commit 3ae01cb
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
12 changes: 11 additions & 1 deletion botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def get_optimal_samples(
posterior_transform: ScalarizedPosteriorTransform | None = None,
objective: MCAcquisitionObjective | None = None,
return_transformed: bool = False,
options: dict | None = None,
) -> tuple[Tensor, Tensor]:
"""Draws sample paths from the posterior and maximizes the samples using GD.
Expand All @@ -551,7 +552,8 @@ def get_optimal_samples(
objective: An MCAcquisitionObjective, used to negate the objective or otherwise
transform sample outputs. Cannot be combined with `posterior_transform`.
return_transformed: If True, return the transformed samples.
options: Options for generation of initial candidates, passed to
gen_batch_initial_conditions.
Returns:
The optimal input locations and corresponding outputs, x* and f*.
Expand All @@ -576,12 +578,20 @@ def get_optimal_samples(
sample_transform = None

paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
suggested_points = prune_inferior_points(
model=model,
X=model.train_inputs[0],
posterior_transform=posterior_transform,
objective=objective,
)
optimal_inputs, optimal_outputs = optimize_posterior_samples(
paths=paths,
bounds=bounds,
raw_samples=raw_samples,
num_restarts=num_restarts,
sample_transform=sample_transform,
return_transformed=return_transformed,
suggested_points=suggested_points,
options=options,
)
return optimal_inputs, optimal_outputs
48 changes: 36 additions & 12 deletions botorch/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@
BotorchTensorDimensionError,
InfeasibilityError,
)
from botorch.utils.transforms import standardize
from botorch.exceptions.warnings import UserInputWarning
from botorch.sampling.qmc import NormalQMCEngine

from botorch.utils.transforms import normalize, unnormalize
from botorch.utils.transforms import normalize, standardize, unnormalize
from scipy.spatial import Delaunay, HalfspaceIntersection
from torch import LongTensor, Tensor
from torch.distributions import Normal
Expand Down Expand Up @@ -1000,10 +999,12 @@ def sparse_to_dense_constraints(
def optimize_posterior_samples(
paths: GenericDeterministicModel,
bounds: Tensor,
raw_samples: int = 1024,
num_restarts: int = 20,
raw_samples: int = 2048,
num_restarts: int = 4,
sample_transform: Callable[[Tensor], Tensor] | None = None,
return_transformed: bool = False,
suggested_points: Tensor | None = None,
options: dict | None = None,
) -> tuple[Tensor, Tensor]:
r"""Cheaply maximizes posterior samples by random querying followed by
gradient-based optimization using SciPy's L-BFGS-B routine.
Expand All @@ -1012,19 +1013,27 @@ def optimize_posterior_samples(
paths: Random Fourier Feature-based sample paths from the GP
bounds: The bounds on the search space.
raw_samples: The number of samples with which to query the samples initially.
Raw samples are cheap to evaluate, so this should ideally be set much higher
than num_restarts.
num_restarts: The number of points selected for gradient-based optimization.
Should be set low relative to the number of raw
sample_transform: A callable transform of the sample outputs (e.g.
MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
negate the objective or otherwise transform the output.
return_transformed: A boolean indicating whether to return the transformed
or non-transformed samples.
suggested_points: Tensor of suggested input locations that are high-valued.
These are more densely evaluated during the sampling phase of optimization.
options: Options for generation of initial candidates, passed to
gen_batch_initial_conditions.
Returns:
A two-element tuple containing:
- X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
- f_opt: A `num_optima x [batch_size] x m`-dim, optionally
`num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
"""
options = {} if options is None else options

def path_func(x) -> Tensor:
res = paths(x)
Expand All @@ -1033,21 +1042,35 @@ def path_func(x) -> Tensor:

return res.squeeze(-1)

candidate_set = unnormalize(
SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
bounds=bounds,
)
# queries all samples on all candidates - output shape
# raw_samples * num_optima * num_models
frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
candidate_set = draw_sobol_samples(
bounds=bounds, n=round(raw_samples * frac_random), q=1
).squeeze(-2)
if frac_random < 1:
perturbed_suggestions = sample_truncated_normal_perturbations(
X=suggested_points,
n_discrete_points=round(raw_samples * (1 - frac_random)),
sigma=options.get("sample_around_best_sigma", 1e-2),
bounds=bounds,
)
candidate_set = torch.cat((candidate_set, perturbed_suggestions))

candidate_queries = path_func(candidate_set)
argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
X_top_k = candidate_set[argtop_k, :]
idx = boltzmann_sample(
function_values=candidate_queries.unsqueeze(-1),
num_samples=num_restarts,
eta=options.get("eta", 5.0),
replacement=False,
)
ics = candidate_set[idx, :]

# to avoid circular import, the import occurs here
from botorch.generation.gen import gen_candidates_scipy

X_top_k, f_top_k = gen_candidates_scipy(
X_top_k,
ics,
path_func,
lower_bounds=bounds[0],
upper_bounds=bounds[1],
Expand Down Expand Up @@ -1101,8 +1124,9 @@ def boltzmann_sample(
eta *= temp_decrease
weights = torch.exp(eta * norm_weights)

# squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
return batched_multinomial(
weights=weights, num_samples=num_samples, replacement=replacement
weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
)


Expand Down
29 changes: 22 additions & 7 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
UnsupportedError,
)
from botorch.models import SingleTaskGP
from botorch.utils.test_helpers import get_fully_bayesian_model, get_model
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultivariateNormal

Expand Down Expand Up @@ -413,17 +414,14 @@ def test_project_to_sample_points(self):


class TestGetOptimalSamples(BotorchTestCase):
def test_get_optimal_samples(self):
dims = 3
dtype = torch.float64
def _test_get_optimal_samples_base(self, model):
dims = model.train_inputs[0].shape[1]
dtype = model.train_targets.dtype
batch_shape = model.batch_shape
for_testing_speed_kwargs = {"raw_samples": 20, "num_restarts": 2}
num_optima = 7
batch_shape = (3,)

bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
model = SingleTaskGP(train_X=X, train_Y=Y)
posterior_transform = ScalarizedPosteriorTransform(
weights=torch.ones(1, dtype=dtype)
)
Expand All @@ -438,6 +436,7 @@ def test_get_optimal_samples(self):
num_optima=num_optima,
**for_testing_speed_kwargs,
)

correct_X_shape = (num_optima,) + batch_shape + (dims,)
correct_f_shape = (num_optima,) + batch_shape + (1,)
self.assertEqual(X_opt_def.shape, correct_X_shape)
Expand Down Expand Up @@ -519,6 +518,22 @@ def test_get_optimal_samples(self):
**for_testing_speed_kwargs,
)

def test_optimal_samples(self):
dims = 3
dtype = torch.float64
X = torch.rand(4, dims, dtype=dtype)
Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
model = get_model(train_X=X, train_Y=Y)
self._test_get_optimal_samples_base(model)
fully_bayesian_model = get_fully_bayesian_model(
train_X=X,
train_Y=Y,
num_models=4,
standardize_model=True,
infer_noise=True,
)
self._test_get_optimal_samples_base(fully_bayesian_model)


class TestPreferenceUtils(BotorchTestCase):
def test_repeat_to_match_aug_dim(self):
Expand Down
3 changes: 1 addition & 2 deletions test/optim/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@
sample_perturbed_subset_dims,
sample_points_around_best,
sample_q_batches_from_polytope,
sample_truncated_normal_perturbations,
transform_constraints,
transform_inter_point_constraint,
transform_intra_point_constraint,
)
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize
from botorch.utils.sampling import manual_seed, unnormalize
from botorch.utils.testing import (
_get_max_violation_of_bounds,
_get_max_violation_of_constraints,
Expand Down

0 comments on commit 3ae01cb

Please sign in to comment.