Skip to content

Commit

Permalink
Merge pull request #186 from NeuroDiffGym/sb/fix_save_load
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan authored Dec 23, 2022
2 parents 0f85aeb + f195b69 commit a3df9fd
Showing 1 changed file with 53 additions and 4 deletions.
57 changes: 53 additions & 4 deletions neurodiffeq/solvers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import inspect
import ast
import types
import random

# Is Dev mode
try:
Expand Down Expand Up @@ -81,6 +82,7 @@ def get_file(url, name):

def get_source(lambda_function):
lambda_text = ""

try:
source_lines, _ = inspect.getsourcelines(lambda_function)
lambda_text = "".join([line.strip() for line in source_lines])
Expand Down Expand Up @@ -201,6 +203,25 @@ def get_sample_solution2D(solver):
pass
return sample_solution_curve

def get_sample_solutionBundle1D(solver):
sample_solution_curve = []
try:
t = np.linspace(solver.r_min[0], solver.r_max[0], 10 *
(int(solver.r_max[0]-solver.r_min[0])))

values = [(random.random()*(solver.r_max[i]-solver.r_min[i]) + solver.r_min[i])*np.ones(len(t)) for i in range(1, len(solver.r_min))]
sample_solution = solver.get_solution()(t, *values)

if not isinstance(sample_solution, list):
sample_solution = [sample_solution]

for i in range(len(sample_solution)):
sample_solution[i] = sample_solution[i].cpu().detach().numpy().tolist()

sample_solution_curve = [t.tolist(), sample_solution]
except:
pass
return sample_solution_curve

def get_networks(solver):
networks = []
Expand Down Expand Up @@ -239,6 +260,7 @@ class SolverConfig():
ode_system = None
pde_system = None
nets = None
best_nets = None
optimizer = None
optimizer_params = None
train_generator = None
Expand Down Expand Up @@ -278,6 +300,10 @@ def save(self,
sample_solution = get_sample_solution1D(self)
elif self.__class__.__name__ == "Solver2D":
sample_solution = get_sample_solution2D(self)
elif self.__class__.__name__ == "BundleSolver1D":
sample_solution = get_sample_solutionBundle1D(self)
else:
sample_solution = None #Temp Fix to avoid error, until conditions made for all solvers

diff_equation_details = {
"equation": get_source(self.diff_eqs),
Expand All @@ -286,7 +312,7 @@ def save(self,
"generator": get_generator(self.generator),
"sample_solution": sample_solution,
"sample_loss": self.metrics_history['valid_loss'],
"criterion": get_source(self.criterion),
"loss_fn": get_source(self.loss_fn),
"networks": get_networks(self),
"optimizer": {
"name": self.optimizer.__class__.__name__,
Expand All @@ -296,10 +322,11 @@ def save(self,

save_dict = {
"metrics": self.metrics_fn,
"criterion": self.criterion,
"loss_fn": self.loss_fn,
"conditions": self.conditions,
"global_epoch": self.global_epoch, # loss_history
"nets": self.nets,
"best_nets": self.best_nets,
"optimizer": self.optimizer,
"optimizer_state": self.optimizer.state_dict(),
"optimizer_class": optimizer_class,
Expand Down Expand Up @@ -418,6 +445,10 @@ def load(cls,
nets = load_dict['nets']
else:
nets = config.nets
if config.best_nets == None:
best_nets = load_dict['best_nets']
else:
best_nets = config.best_nets

# Loading user defined optimizer or optimizer from load file

Expand Down Expand Up @@ -460,7 +491,7 @@ def load(cls,

solver = cls(ode_system=de_system,
conditions=cond,
criterion=load_dict['criterion'],
loss_fn=load_dict['loss_fn'],
metrics=load_dict['metrics'],
nets=nets,
optimizer=optimizer,
Expand All @@ -482,8 +513,26 @@ def load(cls,
train_generator=train_generator,
valid_generator=valid_generator,
optimizer=optimizer,
criterion=load_dict['criterion'],
loss_fn=load_dict['loss_fn'],
metrics=load_dict['metrics'])
elif load_dict["type_name"] == "BundleSolver1D":
t_min = load_dict['solver'].r_min[0]
t_max = load_dict['solver'].r_max[0]

solver = cls(ode_system=de_system,
conditions=cond,
metrics=load_dict['metrics'],
nets=nets,
optimizer=optimizer,
train_generator=train_generator,
valid_generator=valid_generator,
t_min=t_min,
t_max=t_max,
theta_min=tuple(load_dict['solver'].r_min[1:]),
theta_max=tuple(load_dict['solver'].r_max[1:]))

if best_nets != None:
solver.best_nets = best_nets

solver.metrics_history['train_loss'] = train_loss
solver.metrics_history['valid_loss'] = valid_loss
Expand Down

0 comments on commit a3df9fd

Please sign in to comment.