Skip to content

Commit

Permalink
Working version but grad calculation extremely slow
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloprf committed Nov 22, 2024
1 parent 53e010c commit 5342c6d
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 28 deletions.
2 changes: 0 additions & 2 deletions src/mitim_tools/opt_tools/BOTORCHtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def __init__(
)




# TODO: Allow subsetting of other covar modules
if outcome_transform is not None:
self.outcome_transform = outcome_transform
Expand Down
15 changes: 1 addition & 14 deletions src/mitim_tools/opt_tools/STEPtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,6 @@ def fit_step(self, avoidPoints=None, fitWithTrainingDataIfContains=None):
with open(self.fileOutputs, "a") as f:
f.write(f" (took total of {txt_time})")

embed()
x = torch.rand(10_000, self.train_X.shape[-1]).to(self.dfT)
with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_gp64.prof") as s:
self.GP["combined_model"].gpmodel.posterior(x)

def _fit_multioutput_model(self):

surrogateOptions = self.surrogateOptions["selectSurrogate"]('AllMITIM', self.surrogateOptions)
Expand Down Expand Up @@ -359,7 +354,7 @@ def defineFunctions(self, scalarized_objective):
I create this so that, upon reading a pickle, I re-call it. Otherwise, it is very heavy to store lambdas
"""

self.evaluators = {"GP": self.GP["combined_model"]}
self.evaluators = {"GP": self.GP["mo_model"]}

# **************************************************************************************************
# Objective (multi-objective model -> single objective residual)
Expand Down Expand Up @@ -442,14 +437,6 @@ def residual(Y, X = None):
)
)


embed()
x = torch.rand(64, self.train_X.shape[-1]).to(self.dfT)
with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_acq64.prof") as s:
self.evaluators["acq_function"](x)



# **************************************************************************************************
# Quick function to return components (I need this for ROOT too, since I need the components)
# **************************************************************************************************
Expand Down
1 change: 0 additions & 1 deletion src/mitim_tools/opt_tools/SURROGATEtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,6 @@ def perform_model_fit(self, mll):
-mll.forward(mll.model(*mll.model.train_inputs), mll.model.train_targets)
.detach()
]
embed()

def callback(x, y, mll=mll):
track_fval.append(y.fval)
Expand Down
25 changes: 14 additions & 11 deletions src/mitim_tools/opt_tools/optimizers/BOTORCHoptim.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import types
import botorch
import random
from mitim_tools.opt_tools import OPTtools
Expand Down Expand Up @@ -37,6 +36,7 @@ def findOptima(fun, optimization_params = {}, writeTrajectory=False):
"sample_around_best": True,
"disp": 50 if read_verbose_level() == 5 else False,
"seed": fun.seed,
"maxiter": 100,
}

"""
Expand Down Expand Up @@ -64,16 +64,19 @@ def __call__(self, x, *args, **kwargs):
seq_message = f'({"sequential" if sequential_q else "joint"}) ' if q>1 else ''
print(f"\t\t- Optimizing using optimize_acqf: {q = } {seq_message}, {num_restarts = }, {raw_samples = }")

with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
x_opt, _ = botorch.optim.optimize_acqf(
acq_function=fun_opt,
bounds=fun.bounds_mod,
raw_samples=raw_samples,
q=q,
sequential=sequential_q,
num_restarts=num_restarts,
options=options,
)

#with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
#with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_opt.prof") as s:
x_opt, _ = botorch.optim.optimize_acqf(
acq_function=fun_opt,
bounds=fun.bounds_mod,
raw_samples=raw_samples,
q=q,
sequential=sequential_q,
num_restarts=num_restarts,
options=options,
)
embed()

acq_evaluated = torch.Tensor(acq_evaluated)

Expand Down

0 comments on commit 5342c6d

Please sign in to comment.