more dev and exp with gpr
jtwhite79 committed Nov 12, 2023
1 parent ea7edbd commit b5bb086
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions pyemu/utils/helpers.py
@@ -3805,7 +3805,7 @@ def get_current_prop(_cur_thresh):


def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_kernel=None,nverf=0,
-                 plot_fits=False):
+                 plot_fits=False,apply_standard_scalar=False):
"""helper function to setup a gaussian-process-regression emulator for outputs of interest. This
is primarily targeted at low-dimensional settings like those encountered in PESTPP-MOU
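
Note: a minimal sketch of how the updated signature might be called, assuming a hypothetical control file and hypothetical decision-variable/objective sweep CSVs (all file names and values below are illustrative, not from this commit):

    import pyemu

    # build a GPR emulator template dir, holding back 10 rows for
    # verification and standardizing inputs before training
    pyemu.utils.helpers.prep_for_gpr("mou.pst",
                                     input_fnames=["dv_pop.csv"],
                                     output_fnames=["obs_pop.csv"],
                                     gpr_t_d="gpr_template",
                                     nverf=10,
                                     plot_fits=True,
                                     apply_standard_scalar=True)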
@@ -3829,6 +3829,9 @@ def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Matern, ConstantKernel
+from sklearn.preprocessing import StandardScaler
+from sklearn.pipeline import Pipeline
+
import pickle
import inspect
pst = pyemu.Pst(pst_fname)
@@ -3916,12 +3919,13 @@ def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_


if gp_kernel is None:
-gp_kernel = ConstantKernel(constant_value=1.0,constant_value_bounds=(1e-8,1e8)) *\
-            RBF(length_scale=1000.0, length_scale_bounds=(1e-8, 1e8))
-#gp_kernel = Matern(length_scale=100.0, length_scale_bounds=(1e-4, 1e4), nu=4)
+#gp_kernel = ConstantKernel(constant_value=1.0,constant_value_bounds=(1e-8,1e8)) *\
+#            RBF(length_scale=1000.0, length_scale_bounds=(1e-8, 1e8))
+gp_kernel = Matern(length_scale=100.0, length_scale_bounds=(1e-4, 1e4), nu=0.5)

for hp in gp_kernel.hyperparameters:
print(hp)

cut = df.shape[0] - nverf
X_train = df.loc[:, input_names].values.copy()[:cut, :]
X_verf = df.loc[:, input_names].values.copy()[cut:, :]
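
Note: for context, a standalone sketch contrasting the two kernel choices the block above switches between; with nu=0.5 the Matern kernel reduces to the absolute-exponential kernel, a much rougher prior than the infinitely differentiable RBF (the numeric values are those hard-coded above):

    from sklearn.gaussian_process.kernels import RBF, Matern, ConstantKernel

    # previous default: constant-scaled, smooth RBF
    rbf_kernel = ConstantKernel(constant_value=1.0, constant_value_bounds=(1e-8, 1e8)) * \
                 RBF(length_scale=1000.0, length_scale_bounds=(1e-8, 1e8))
    # new default: Matern with nu=0.5, i.e. the exponential kernel
    matern_kernel = Matern(length_scale=100.0, length_scale_bounds=(1e-4, 1e4), nu=0.5)

    for hp in matern_kernel.hyperparameters:
        print(hp)  # a single length_scale hyperparameter with its bounds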
@@ -3937,11 +3941,16 @@ def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_
y_train = df.loc[:, output_name].values.copy()[:cut]
print("training GPR for {0} with {1} data points".format(output_name,y_train.shape[0]))
gaussian_process = GaussianProcessRegressor(kernel=gp_kernel, n_restarts_optimizer=20)
-gaussian_process.fit(X_train, y_train)
-print(output_name,"optimized kernel:",gaussian_process.kernel_)
+if not apply_standard_scalar:
+    print("WARNING: not applying StandardScaler transformation - user beware!")
+    pipeline = Pipeline([("gpr",gaussian_process)])
+else:
+    pipeline = Pipeline([("std_scalar", StandardScaler()), ("gpr", gaussian_process)])
+pipeline.fit(X_train, y_train)
+print(output_name,"optimized kernel:",pipeline["gpr"].kernel_)
if plot_fits:
print("...plotting fits for",output_name)
-predmean,predstd = gaussian_process.predict(df.loc[:, input_names].values.copy(), return_std=True)
+predmean,predstd = pipeline.predict(df.loc[:, input_names].values.copy(), return_std=True)
df.loc[:,"predmean"] = predmean
df.loc[:,"predstd"] = predstd
isverf = np.zeros_like(predmean)
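
Note: the scaling branch above, reduced to a self-contained sketch on synthetic data (the step names mirror the diff; the data and sizes are made up):

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import Matern
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    X_train = rng.uniform(0.0, 100.0, size=(50, 2))
    y_train = np.sin(X_train[:, 0] / 10.0) + rng.normal(scale=0.05, size=50)

    gpr = GaussianProcessRegressor(
        kernel=Matern(length_scale=100.0, length_scale_bounds=(1e-4, 1e4), nu=0.5),
        n_restarts_optimizer=20)
    # standardizing inputs keeps the length-scale optimization well conditioned
    pipeline = Pipeline([("std_scalar", StandardScaler()), ("gpr", gpr)])
    pipeline.fit(X_train, y_train)
    # Pipeline forwards predict kwargs to the final step, so return_std
    # still yields the posterior standard deviation alongside the mean
    predmean, predstd = pipeline.predict(X_train, return_std=True)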
@@ -3971,13 +3980,13 @@ def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_
if os.path.exists(os.path.join(gpr_t_d,model_fname)):
print("WARNING: model_fname '{0}' exists, overwriting...".format(model_fname))
with open(os.path.join(gpr_t_d,model_fname),'wb') as f:
-    pickle.dump(gaussian_process,f)
+    pickle.dump(pipeline,f)

model_fnames.append(model_fname)
if nverf > 0:
-pred_mean,pred_std = gaussian_process.predict(X_verf,return_std=True)
+pred_mean,pred_std = pipeline.predict(X_verf,return_std=True)
vdf = pd.DataFrame({"y_verf":y_verf,"y_pred":pred_mean,"y_pred_std":pred_std})
-verf_fname = os.path.join(gpr_t_d,os.path.split(pst_fname)[1]+"."+output_name+".verf.csv")
+verf_fname = os.path.join(gpr_t_d,"{0}_gpr_verf.csv".format(output_name))
vdf.to_csv(verf_fname)
print("saved ",output_fname,"verfication csv to",verf_fname)
mabs = np.abs(vdf.y_verf - vdf.y_pred).mean()
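
Note: a sketch of consuming the artifacts written above, unpickling a saved pipeline and summarizing its verification CSV; the output name "obj1" and the pickle file name are assumptions (the model_fname pattern is set outside the visible hunk), while the verification CSV name follows the "{output_name}_gpr_verf.csv" pattern above:

    import pickle
    import numpy as np
    import pandas as pd

    # hypothetical model_fname for an output named "obj1"
    with open("gpr_template/obj1_gpr.pickle", "rb") as f:
        pipeline = pickle.load(f)
    mean, std = pipeline.predict(np.array([[1.0, 2.0]]), return_std=True)

    # same summary statistic the helper computes after verification
    vdf = pd.read_csv("gpr_template/obj1_gpr_verf.csv", index_col=0)
    mabs = np.abs(vdf.y_verf - vdf.y_pred).mean()
    print("mean abs verification error:", mabs)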
