Skip to content

Commit

Permalink
minor edits
Browse files Browse the repository at this point in the history
  • Loading branch information
AGKhalil committed Aug 9, 2019
1 parent c9422d9 commit 86acc50
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 18 deletions.
21 changes: 3 additions & 18 deletions manual_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@


def alter_env(exp_type, variant):
print('LOGGER: alter_env_start')
xml_path = os.path.join(gym_real.__path__[0], "envs/assets/real.xml")

tree = ET.parse(xml_path)
Expand All @@ -38,21 +37,17 @@ def alter_env(exp_type, variant):
pos.set("pos", str(variant) + " 0 " + str(abs(-0.1) + 0.7))

tree.write(xml_path)
print('LOGGER: alter_env_end')

def load_checkpoint(checkpoint, run_path):
print('LOGGER: load_checkpoint_start')
checkpoint_log = run_path + '/checkpoint'
checkpoint_keeper = []
with open(checkpoint_log) as csv_file:
csv_reader = csv.reader(csv_file)
for row in csv_reader:
checkpoint_keeper.append(row[0])
print('LOGGER: load_checkpoint_end')
return run_path + '/' + checkpoint_keeper[checkpoint] + '.pkl'

def log_experiments(exp_num, exp_type, variants, model_names, exp_log, log_dict):
print('LOGGER: log_experiments_start')
file = shelve.open(exp_log)
file['exp' + exp_num] = [exp_type, variants, model_names]
file.close()
Expand All @@ -62,38 +57,31 @@ def log_experiments(exp_num, exp_type, variants, model_names, exp_log, log_dict)
# if os.stat(exp_log).st_size == 0:
# csv_writer.writerow(['exp', 'type', 'variants', 'models'])
# csv_writer.writerow(['exp' + str(exp_num), exp_type, variants, model_names])
print('LOGGER: log_experiments_end')

def run_experiment(exp_num, exp_type, variants, n_cpu, step_total, exp_log, log_dict):
print('LOGGER: run_experiments_start')
model_names = []
run_path = ''
for order, variant in enumerate(variants):
print('LOGGER: run_experiments_mid')
alter_env(exp_type, variant)
env = gym.make("Real-v0")
env = Monitor(env, 'tf_save', allow_early_resets=True)
env = SubprocVecEnv([lambda: env for i in range(n_cpu)])
print('LOGGER: run_experiments_mid2')
if order == 0:
model = PPO2(MlpPolicy, env, verbose=1, tensorboard_log="./tensorboard_log/")
model = PPO2(MlpPolicy, env, verbose=0, tensorboard_log="./tensorboard_log/")
else:
load_name = load_checkpoint(-1, run_path)
model = PPO2.load(load_name, env=env)
print('LOGGER: run_experiments_mid3')
model_names.append(model.model_name)
run_path = model.graph_dir
model.learn(total_timesteps=step_total)
env.close()
del model, env
print('LOGGER: run_experiments_mid4')
log_experiments(exp_num, exp_type, variants, model_names, exp_log, log_dict)
print('LOGGER: run_experiments_end')

if __name__ == "__main__":
n_cpu = 8
n_cpu = 20
n_step = 128
desired_log_pts = 1500
desired_log_pts = 1000
step_total = desired_log_pts * n_cpu * n_step
leg_lengths = [i * -0.1 for i in range(1, 5)]
goal_diss = [i * -2 for i in range(2, 6)]
Expand All @@ -105,7 +93,6 @@ def run_experiment(exp_num, exp_type, variants, n_cpu, step_total, exp_log, log_
dis_type = 'GOAL_DIS'
exp_types = [leg_type, dis_type]
for i in range(5):
print('LOGGER: main_loop_start')
for j, exp_type in enumerate(exp_types):
if exp_type == leg_type:
variant = leg_lengths
Expand All @@ -114,8 +101,6 @@ def run_experiment(exp_num, exp_type, variants, n_cpu, step_total, exp_log, log_
perm_2 = list(permutations(variant, 2))
perm_4 = list(permutations(variant))
perms = perm_2 + perm_4
print('LOGGER: main_loop_mid')
for k, perm in enumerate(perms):
print('LOGGER: main_loop_midstart')
run_experiment(str(i) + '_' + str(j) + '_' + str(k), exp_type, perm, n_cpu, step_total, exp_log, log_dict)

1 change: 1 addition & 0 deletions plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def plot_graph(plot_name, save_name, length, reward):
plt.savefig(save_name + ".png")
plt.savefig(save_name + ".eps")
# plt.show()
plt.close()

plot_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "plot_saves/")
os.makedirs(plot_dir, exist_ok=True)
Expand Down

0 comments on commit 86acc50

Please sign in to comment.