Skip to content

Commit

Permalink
working on ppw gpr worker
Browse files Browse the repository at this point in the history
  • Loading branch information
jtwhite79 committed Nov 28, 2024
1 parent aeff898 commit 4a440e5
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 22 deletions.
14 changes: 11 additions & 3 deletions autotest/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2746,7 +2746,6 @@ def pypestworker_test():
p.wait()
finish = datetime.now()
print("all done, took",(finish-start).total_seconds())


m_d2 = m_d+"_base"
start2 = datetime.now()
Expand Down Expand Up @@ -3178,6 +3177,14 @@ def gpr_zdt1_test():
assert psum.obj_1.min() < 0.05


def gpr_zdt1_ppw():
    """Run the GPR python worker from inside the zdt1 GPR template directory.

    Changes into ``zdt1_gpr_template``, calls ``pyemu.helpers.gpr_pyworker``
    for ``zdt1.pst`` against ``localhost:4004``, and then returns to the
    directory the function was called from.
    """
    t_d = "zdt1_gpr_template"
    pst_name = "zdt1.pst"
    # remember the starting directory so that a failure inside the worker
    # cannot leave the process stranded in the template directory
    start_d = os.getcwd()
    os.chdir(t_d)
    try:
        pyemu.helpers.gpr_pyworker(pst_name, "localhost", 4004)
    finally:
        os.chdir(start_d)


if __name__ == "__main__":
#ppu_geostats_test(".")
#gpr_compare_invest()
Expand All @@ -3188,8 +3195,9 @@ def gpr_zdt1_test():
# sys.path.insert(0,t_d)
# from forward_run import helper as frun
# ppw_worker(0,case,t_d,"localhost",4004,frun)
pypestworker_test()
#gpr_zdt1_test()
#pypestworker_test()
gpr_zdt1_test()

#while True:
# thresh_pars_test()
#obs_ensemble_quantile_test()
Expand Down
106 changes: 87 additions & 19 deletions pyemu/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
pass

import pyemu
from pyemu.utils.os_utils import run, start_workers
from pyemu.utils.os_utils import run, start_workers,PyPestWorker


class Trie:
Expand Down Expand Up @@ -4287,12 +4287,25 @@ def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_
gpst.observation_data.loc[stdobs,"weight"] = 0.0
gpst.pestpp_options = pst.pestpp_options
gpst.prior_information = pst.prior_information.copy()
lines = [line[4:] for line in inspect.getsource(gpr_forward_run).split("\n")][1:]

#lines = [line[4:] for line in inspect.getsource(gpr_forward_run).split("\n")][1:]
frun_lines = inspect.getsource(gpr_forward_run)
getfxn_lines = inspect.getsource(get_gpr_model_dict)
emulfxn_lines = inspect.getsource(emulate_with_gpr)
with open(os.path.join(gpr_t_d, "forward_run.py"), 'w') as f:

for line in lines:
f.write(line + "\n")
f.write("\n")
for import_name in ["pandas as pd","os","pickle","numpy as np"]:
f.write("import {0}\n".format(import_name))
for line in getfxn_lines:
f.write(line)
f.write("\n")
for line in emulfxn_lines:
f.write(line)
f.write("\n")
for line in frun_lines:
f.write(line)
f.write("if __name__ == '__main__':\n")
f.write(" gpr_forward_run()\n")


gpst.control_data.noptmax = 0
gpst.model_command = "python forward_run.py"
Expand All @@ -4307,26 +4320,81 @@ def prep_for_gpr(pst_fname,input_fnames,output_fnames,gpr_t_d="gpr_template",gp_
gpst.write(os.path.join(gpr_t_d, gpst_fname), version=2)


def gpr_forward_run():
"""the function to evaluate a set of inputs thru the GPR emulators.\
This function gets added programmatically to the forward run process"""
import os
import pandas as pd
import numpy as np
def get_gpr_model_dict(mdf):
    """Load the pickled emulators listed in a model-info dataframe.

    Args:
        mdf (pandas.DataFrame): model information with `output_name` and
            `model_fname` columns.  Each `model_fname` entry is the path
            of a pickle file holding the emulator for that output.

    Returns:
        dict: the unpickled object for every row of `mdf`, keyed by
        `output_name`.  Empty if `mdf` has no rows.

    Note:
        The caller's `mdf` is used as passed; nothing is re-read from disk
        other than the pickle files themselves.  Unpickling can execute
        arbitrary code, so only point this at model files from a trusted
        source.

    """
    import pickle

    gpr_model_dict = {}
    for output_name, model_fname in zip(mdf.output_name, mdf.model_fname):
        # the context manager closes the file handle as soon as the model
        # is loaded.  pickle imports whatever modules the stored object
        # needs, so no explicit emulator-class import is required here.
        with open(model_fname, "rb") as f:
            gpr_model_dict[output_name] = pickle.load(f)
    return gpr_model_dict


def emulate_with_gpr(input_df, mdf, gpr_model_dict):
    """Evaluate each supplied emulator at the current input values.

    Args:
        input_df (pandas.DataFrame): inputs with a `parval1` column.  The
            column values, in row order, are passed to every emulator as
            a single sample (one row, one column per input).
        mdf (pandas.DataFrame): model information whose row labels are the
            output names.  Modified in place: the `sim` and `sim_std`
            columns are reset to NaN and then filled.
        gpr_model_dict (dict): output name -> already-loaded emulator
            exposing `predict(X, return_std=True)`, which returns a
            (prediction, standard deviation) pair.

    Returns:
        pandas.DataFrame: `mdf`, with `sim` and `sim_std` filled for every
        key of `gpr_model_dict`

    """
    mdf.loc[:, "sim"] = np.nan
    mdf.loc[:, "sim_std"] = np.nan
    # every emulator sees the same input vector, so build it once
    x = np.atleast_2d(input_df.parval1.values)
    # use the emulators that were passed in rather than re-reading the
    # pickle files from disk on every call
    for output_name, gaussian_process in gpr_model_dict.items():
        sim = gaussian_process.predict(x, return_std=True)
        mdf.loc[output_name, "sim"] = sim[0]
        mdf.loc[output_name, "sim_std"] = sim[1]
    return mdf


def gpr_pyworker(pst,host,port,input_df=None,mdf=None):
    """Answer run requests with GPR emulator output through a `PyPestWorker`.

    Args:
        pst: passed straight through to `PyPestWorker`
        host: passed straight through to `PyPestWorker`
        port: passed straight through to `PyPestWorker`
        input_df (pandas.DataFrame): emulator inputs, indexed by parameter
            name, with a `parval1` column.  If None, "gpr_input.csv" is
            read from the current directory.
        mdf (pandas.DataFrame): emulator model information.  If None,
            "gprmodel_info.csv" is read from the current directory.

    Note:
        The emulators are loaded once, before the worker is created.  Each
        parameter set obtained from the worker is emulated and the
        resulting observation values are sent back; the function returns
        as soon as `get_parameters()` yields None.

    """
    import os
    import pickle
    import numpy as np
    import pandas as pd

    if input_df is None:
        input_df = pd.read_csv("gpr_input.csv",index_col=0)
    if mdf is None:
        mdf = pd.read_csv("gprmodel_info.csv",index_col=0)
    emulators = get_gpr_model_dict(mdf)

    worker = PyPestWorker(pst,host,port,verbose=False)
    parameters = worker.get_parameters()
    if parameters is None:
        # nothing was ever requested
        return

    # observation table in the order the worker reports observation names
    obs = worker._pst.observation_data.copy()
    obs = obs.loc[worker.obs_names,:]
    # only parameters that are also emulator inputs get passed along
    par = worker._pst.parameter_data.copy()
    is_input = par.parnme.isin(input_df.index.values)
    usepar = par.loc[is_input,"parnme"].values

    while parameters is not None:
        parameters = parameters.loc[usepar]
        indf = input_df.copy()
        indf.loc[parameters.index,"parval1"] = parameters.values
        simdf = emulate_with_gpr(indf,mdf,emulators)
        # each emulated output has a companion "<name>_gprstd" observation
        std_names = [name + "_gprstd" for name in simdf.index]
        obs.loc[simdf.index,"obsval"] = simdf.sim.values
        obs.loc[std_names,"obsval"] = simdf.sim_std.values
        worker.send_observations(obs.obsval.values)
        parameters = worker.get_parameters()



def gpr_forward_run():
    """Evaluate one set of inputs through the GPR emulators using files.

    Reads "gpr_input.csv" and "gprmodel_info.csv" from the current
    directory, loads the emulators, evaluates them, and writes the
    `output_name`, `sim` and `sim_std` columns to "gpr_output.csv".
    This function gets added programmatically to the forward run process.

    Returns:
        pandas.DataFrame: the model information dataframe returned by
        `emulate_with_gpr`

    """
    import pandas as pd

    inputs = pd.read_csv("gpr_input.csv", index_col=0)
    model_info = pd.read_csv("gprmodel_info.csv", index_col=0)
    emulators = get_gpr_model_dict(model_info)
    results = emulate_with_gpr(inputs, model_info, emulators)

    out_cols = ["output_name", "sim", "sim_std"]
    results.loc[:, out_cols].to_csv("gpr_output.csv", index=False)
    return results

def randrealgen_optimized(nreal, tol=1e-7, max_samples=1000000):
"""
Expand Down

0 comments on commit 4a440e5

Please sign in to comment.