Skip to content

Commit

Permalink
misc testing
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloprf committed Nov 22, 2024
1 parent 876775b commit 53e010c
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 225 deletions.
4 changes: 1 addition & 3 deletions src/mitim_modules/portals/PORTALStools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

def selectSurrogate(output, surrogateOptions, CGYROrun=False):

print(
f'\t- Selecting surrogate options for "{output}" to be run'
)
print(f'\t- Selecting surrogate options for "{output}" to be run')

if output is not None:
# If it's a target, just linear
Expand Down
28 changes: 6 additions & 22 deletions src/mitim_tools/opt_tools/BOTORCHtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,8 @@
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.models.utils import validate_input_scaling
from botorch.models.utils.gpytorch_modules import (
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
)
from botorch.utils.types import DEFAULT
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.models.exact_gp import ExactGP
from gpytorch.module import Module
from torch import Tensor

from linear_operator.operators import CholLinearOperator, DiagLinearOperator
Expand Down Expand Up @@ -136,7 +129,6 @@ def __init__(
batch_shape=self._aug_batch_shape, variables=variables
)


"""
-----------------------------------------------------------------------
GP Kernel - Covariance
Expand Down Expand Up @@ -562,23 +554,15 @@ def __init__(self, *gp_models):
def prepareToGenerateCommons(self):
self.models[0].input_transform.tf1.flag_to_store = True
# Make sure that this ModelListGP evaluation is fresh
if (
"parameters_combined"
in self.models[0].input_transform.tf1.surrogate_parameters
):
del self.models[0].input_transform.tf1.surrogate_parameters[
"parameters_combined"
]
if ("surrogate_parameters" in self.models[0].input_transform.tf1.__dict__) and \
("parameters_combined" in self.models[0].input_transform.tf1.surrogate_parameters):
del self.models[0].input_transform.tf1.surrogate_parameters["parameters_combined"]

def cold_startCommons(self):
self.models[0].input_transform.tf1.flag_to_store = False
if (
"parameters_combined"
in self.models[0].input_transform.tf1.surrogate_parameters
):
del self.models[0].input_transform.tf1.surrogate_parameters[
"parameters_combined"
]
if ("surrogate_parameters" in self.models[0].input_transform.tf1.__dict__) and \
("parameters_combined" in self.models[0].input_transform.tf1.surrogate_parameters):
del self.models[0].input_transform.tf1.surrogate_parameters["parameters_combined"]

def transform_inputs(self, X):
self.prepareToGenerateCommons()
Expand Down
Loading

0 comments on commit 53e010c

Please sign in to comment.