Performance & runtime improvements to info-theoretic acquisition functions (1/N) #2748
Conversation
@sdaulton has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Thanks! It seems like
Codecov Report: All modified and coverable lines are covered by tests ✅

@@ Coverage Diff @@
##             main    #2748   +/-   ##
=======================================
  Coverage   99.99%   99.99%
=======================================
  Files         203      203
  Lines       18690    18689       -1
=======================================
- Hits        18689    18688       -1
  Misses          1        1

View full report in Codecov by Sentry.
@sdaulton for sure! I currently observe similar things for JES, but I'm not sure whether the found points are actually higher in acquisition function value or not (for either LogEI or JES).
That would be interesting to see |
Hi Carl! This seems like a decent improvement. Just a few comments in-line
botorch/utils/sampling.py
@@ -1008,13 +1012,17 @@ def optimize_posterior_samples(
            negate the objective or otherwise transform the output.
        return_transformed: A boolean indicating whether to return the transformed
            or non-transformed samples.
        suggested_points
Let's complete the docstring here
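A possible completion of the entry, mirroring the style of the surrounding docstring (the wording here is my assumption, not the final text of the PR):

```
        suggested_points: An optional `k x d`-dim Tensor of suggested starting
            points (e.g. previously found optima) to include in the candidate
            set alongside the randomly drawn samples.
```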
botorch/utils/sampling.py
    bounds=bounds, n=round(raw_samples * frac_random), q=1
).squeeze(-2)
if suggested_points is not None:
    from botorch.optim.initializers import sample_truncated_normal_perturbations
Is this a local import because it leads to cyclical dependencies? If so, we could move sample_truncated_normal_perturbations under utils (assuming it doesn't depend on other code in optim).
I think so! Let me check.
botorch/utils/sampling.py
candidate_set = draw_sobol_samples(
    bounds=bounds, n=round(raw_samples * frac_random), q=1
).squeeze(-2)
if suggested_points is not None:
If suggested_points is None, we end up with a candidate_set of size smaller than raw_samples. Should we make sure we always use raw_samples points?
Ah, good catch. Absolutely
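One way to guarantee a candidate set of exactly raw_samples points is to top up with random (e.g. Sobol) draws when suggested points are present. A minimal sketch (variable names and shapes are illustrative, not the PR's final code):

```python
import torch
from torch.quasirandom import SobolEngine

# Illustrative setup: 8 total candidates in a 2-D unit box, 3 suggested points.
raw_samples, d = 8, 2
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
suggested_points = torch.rand(3, d)

# Fill the remainder of the budget with scrambled Sobol points so that
# len(candidate_set) == raw_samples regardless of how many suggestions exist.
n_random = raw_samples - len(suggested_points)
sobol = SobolEngine(dimension=d, scramble=True).draw(n_random)
random_points = bounds[0] + (bounds[1] - bounds[0]) * sobol
candidate_set = torch.cat([suggested_points, random_points], dim=0)
print(candidate_set.shape)  # torch.Size([8, 2])
```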
botorch/utils/sampling.py
weights = (
    candidate_queries - candidate_queries.mean(dim=-1, keepdim=True)
) / candidate_queries.std(dim=-1, keepdim=True)
eta = options.get("eta", 2.0)
weights = torch.exp(eta * weights)

# weights can be more than 2D, which is not supported by torch.multinomial
# the argsort picks out the indices that are nonzero, i.e. those that are drawn
# (without replacement, so we will always have num_restarts nonzero ones)
idx = (
    Multinomial(num_restarts, probs=weights)
    .sample()
    .argsort(descending=True)[..., :num_restarts]
)
This seems like very similar logic to initialize_q_batch. Would it make sense to re-use that?
It is extremely similar, but torch.multinomial does not support 3D multinomial sampling, which is why I re-used the logic (see the comment); I did initially try initialize_q_batch. Literally all the duplicated logic is to accommodate fully Bayesian models (otherwise I could have just used ThompsonSampling and optimize_acqf), but the extra dim that fully Bayesian models generate really throws it off.
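The limitation and the workaround can be sketched as follows (shapes are illustrative; this is a sketch of the idea, not the PR's exact code):

```python
import torch
from torch.distributions import Multinomial

# torch.multinomial only accepts 1-D or 2-D probability inputs, so a fully
# Bayesian weight tensor with an extra leading model dimension fails directly.
weights = torch.rand(4, 3, 100)  # (num_mcmc_samples, batch, num_candidates)
failed = False
try:
    torch.multinomial(weights, num_samples=5)
except RuntimeError:
    failed = True
print("3-D input rejected:", failed)

# Workaround: the Multinomial distribution broadcasts over leading batch
# dims. Sample per-category counts and argsort them descending so the drawn
# indices come first, then keep the top num_restarts.
num_restarts = 5
counts = Multinomial(num_restarts, probs=weights).sample()  # (4, 3, 100)
idx = counts.argsort(descending=True)[..., :num_restarts]   # (4, 3, 5)
print(idx.shape)
```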
idcs = batched_multinomial(
    weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n
).permute(-1, *range(len(batch_shape)))
idcs = boltzmann_sample(
Cool, this deduplicated a bunch of repeated logic here.
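For context, the Boltzmann-weighting idea that boltzmann_sample centralizes looks roughly like this (the function name comes from the diff above; this body is an approximation, not BoTorch's exact implementation):

```python
import torch

def boltzmann_probs(vals: torch.Tensor, eta: float = 2.0) -> torch.Tensor:
    # Standardize the acquisition values per batch, then exponentiate with
    # temperature eta so higher-valued candidates get larger probabilities.
    z = (vals - vals.mean(dim=-1, keepdim=True)) / vals.std(dim=-1, keepdim=True)
    return torch.softmax(eta * z, dim=-1)

vals = torch.randn(4, 100)  # e.g. one row per fully Bayesian model draw
probs = boltzmann_probs(vals)
idx = torch.multinomial(probs, num_samples=5)  # 2-D probs are supported here
print(idx.shape)  # torch.Size([4, 5])
```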
A series of changes directed towards improving the performance of PES & JES, as well as their multi-objective counterparts.
Motivation
As pointed out by @SebastianAment in this paper, the BoTorch variant of JES, and to an extent PES, is brutally slow and suspiciously ill-performing. To bring them up to their potential, I've added a series of performance improvements:
1. Improvement to get_optimal_samples and optimize_posterior_samples: As this is an integral part of their efficiency, I've added suggestions (similar to sample_around_best) to optimize_posterior_samples.
Marginal runtime improvement in acquisition optimization (sampling time practically unchanged):
[figure: runtime benchmark]
Substantial performance improvement:
[figure: performance benchmark]
2. Added initializer to acquisition function optimization: Similar to KG, ES methods have sensible suggestions for acquisition function optimization in the form of the sampled optima. This drastically reduces the time of acquisition function optimization, which could on occasion take 30+ seconds when num_restarts was large (>4). Benchmarking INC.
2b. Multi-objective support for initializer: By renaming arguments of the multi-objective variants, we get consistency and initializer support for the MO variants.
3. Enabled gradient-based optimization for PES: The current implementation contains a while-loop which forces the gradients to be recursively computed. This commonly causes NaN gradients, which is why the recommended option in the tutorial is "with_grad": False. One detach() alleviates this issue, enabling gradient-based optimization.
NOTE: this has NOT been ablated, since the non-grad optimization is extremely computationally demanding.
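The kind of fix a single detach() provides can be illustrated with a toy loop (a hypothetical illustration, not the actual PES code): carrying state through the autograd graph across iterations builds a deep recursive graph, whereas detaching the carried state keeps backprop local to one step while the output remains differentiable with respect to the inputs.

```python
import torch

def refine(x: torch.Tensor, steps: int = 50) -> torch.Tensor:
    # Iterative refinement loop; without detach(), gradients would be
    # backpropagated recursively through all `steps` iterations.
    state = torch.zeros_like(x)
    for _ in range(steps):
        # detach() cuts the backward graph through earlier iterations
        state = torch.tanh(state.detach() + x)
    return state

x = torch.tensor([0.5], requires_grad=True)
refine(x).sum().backward()
print(torch.isfinite(x.grad).all().item())  # True
```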
Test Plan
Unit tests and benchmarking.
Related PRs
First of a couple!
Bonus: while benchmarking, I had issues repro'ing the LogEI performance initially. I found that sample_around_best made LogEI worse on Mich5. All experiments are otherwise a repro of the settings used in the LogEI paper.
[figure: LogEI benchmark]