-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
helper scripts for loading and testing stuff
- Loading branch information
Showing
3 changed files
with
84 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
import sys | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from stable_baselines.results_plotter import load_results, ts2xy | ||
|
||
|
||
def moving_average(values, window):
    """
    Smooth values by doing a moving average

    Only windows that fully overlap the data are kept, so the result has
    ``len(values) - window + 1`` entries. When there are fewer values than
    ``window`` no full window exists and an empty array is returned.

    :param values: (numpy array) 1-D sequence of samples to smooth
    :param window: (int) number of consecutive samples averaged together
    :return: (numpy array) the smoothed values
    :raises ValueError: if ``window`` is smaller than 1
    """
    if window < 1:
        raise ValueError(
            "window must be a positive integer, got {!r}".format(window))

    values = np.asarray(values)
    # np.convolve(..., 'valid') silently swaps its operands when the kernel is
    # longer than the data and would return (window - len + 1) meaningless
    # points; report "nothing to average" explicitly instead.
    if values.size < window:
        return np.empty(0, dtype=float)

    weights = np.ones(window) / window
    return np.convolve(values, weights, 'valid')
|
||
|
||
def plot_results(log_folder, model_name, plt_dir, title='Learning Curve'):
    """
    Plot the smoothed learning curve of a run and save it as PNG and EPS.

    The curve is written to ``<plt_dir>/<model_name>.png`` and
    ``<plt_dir>/<model_name>.eps``. The figure is closed once it has been
    saved, so repeated calls do not draw on top of each other.

    :param log_folder: (str) the save location of the results to plot
    :param model_name: (str) base name (without extension) of the saved plots
    :param plt_dir: (str) directory the plots are written to; it is created
        when it does not exist yet
    :param title: (str) the title of the task to plot
    """
    save_name = os.path.join(plt_dir, model_name)
    # An empty plt_dir means "current directory"; os.makedirs('') would raise.
    if plt_dir:
        os.makedirs(plt_dir, exist_ok=True)

    x, y = ts2xy(load_results(log_folder), 'timesteps')
    y = moving_average(y, window=10)
    # Smoothing yields fewer points than it was given; keep only the trailing
    # x values so that both arrays have the same length.
    x = x[len(x) - len(y):]

    # plt.figure(title) returns the already existing figure when one with the
    # same label is still open, hence the explicit close in the finally block.
    fig = plt.figure(title)
    try:
        plt.plot(x, y)
        plt.xlabel('Number of Timesteps')
        plt.ylabel('Rewards')
        plt.title(title + " Smoothed")
        print('Saving plot at:', save_name)
        fig.savefig(save_name + ".png")
        fig.savefig(save_name + ".eps")
        print("plots saved...")
    finally:
        plt.close(fig)
|
||
if __name__ == '__main__':
    # Command line: <script> LOG_FOLDER MODEL_NAME PLOT_DIR
    # Extra trailing arguments are ignored, as before.
    if len(sys.argv) < 4:
        sys.exit(
            "usage: {} LOG_FOLDER MODEL_NAME PLOT_DIR".format(sys.argv[0]))
    plot_results(sys.argv[1], sys.argv[2], sys.argv[3])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import subprocess | ||
import sys | ||
|
||
if __name__ == "__main__":
    # Local import: this launcher only needs os for the child environment.
    import os

    # Command line: <script> ENV_NAME MODEL_NAME
    if len(sys.argv) < 3:
        sys.exit("usage: {} ENV_NAME MODEL_NAME".format(sys.argv[0]))
    env_name = sys.argv[1]
    model_name = sys.argv[2]

    # Preload the GLEW and NVIDIA GL libraries for the child process only,
    # replacing any inherited LD_PRELOAD exactly like the former `export` did.
    child_env = dict(os.environ)
    child_env["LD_PRELOAD"] = (
        "/usr/lib/x86_64-linux-gnu/libGLEW.so:/usr/lib/nvidia-410/libGL.so")

    # Pass the arguments as a list instead of formatting them into a shell
    # string: names containing quotes or shell metacharacters are forwarded
    # verbatim rather than being interpreted by the shell.
    # Wait for load_agent.py and propagate its exit status to the caller.
    sys.exit(subprocess.call(
        ["python", "load_agent.py", env_name, model_name], env=child_env))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import gym | ||
|
||
from stable_baselines.common.policies import MlpPolicy | ||
from stable_baselines.common.vec_env import SubprocVecEnv | ||
from stable_baselines import PPO2 | ||
|
||
# multiprocess environment
n_cpu = 20


def make_env():
    """Create one CartPole-v1 environment; called once per worker."""
    return gym.make('CartPole-v1')


def main():
    """
    Train PPO2 on CartPole-v1, save it, reload it and roll the policy out.

    The model is saved to and loaded from "ppo2_cartpole". The vectorized
    environment is always closed, even when training or the rollout fails,
    so that the worker processes are not left behind.
    """
    env = SubprocVecEnv([make_env for _ in range(n_cpu)])
    try:
        model = PPO2(MlpPolicy, env, verbose=1)
        model.learn(total_timesteps=25000)
        model.save("ppo2_cartpole")

        del model  # remove to demonstrate saving and loading

        model = PPO2.load("ppo2_cartpole")

        # Enjoy trained agent
        obs = env.reset()
        for _ in range(1000):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            # Uncomment to watch the agent: env.render()
    finally:
        env.close()


# SubprocVecEnv starts worker processes; with the 'spawn' and 'forkserver'
# start methods the workers re-import this module, so the work must not run
# at import time.
if __name__ == '__main__':
    main()