Skip to content

Commit

Permalink
helper scripts for loading and testing stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
AGKhalil committed Jul 30, 2019
1 parent f6fc033 commit 17e8771
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
49 changes: 49 additions & 0 deletions plot_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines.results_plotter import load_results, ts2xy


def moving_average(values, window):
"""
Smooth values by doing a moving average
:param values: (numpy array)
:param window: (int)
:return: (numpy array)
"""
weights = np.repeat(1.0, window) / window
return np.convolve(values, weights, 'valid')


def plot_results(log_folder, model_name, plt_dir, title='Learning Curve'):
"""
plot the results
:param log_folder: (str) the save location of the results to plot
:param title: (str) the title of the task to plot
"""
# m_name_csv = model_name + ".csv"
# old_file_name = os.path.join(log_folder, "monitor.csv")
# new_file_name = os.path.join(log_folder, m_name_csv)
save_name = os.path.join(plt_dir, model_name)

x, y = ts2xy(load_results(log_folder), 'timesteps')
# shutil.copy(old_file_name, new_file_name)
y = moving_average(y, window=10)
# Truncate x
x = x[len(x) - len(y):]

fig = plt.figure(title)
plt.plot(x, y)
plt.xlabel('Number of Timesteps')
plt.ylabel('Rewards')
plt.title(title + " Smoothed")
print('Saving plot at:', save_name)
plt.savefig(save_name + ".png")
plt.savefig(save_name + ".eps")
print("plots saved...")
# plt.show()

if __name__ == '__main__':
plot_results(str(sys.argv[1]), str(sys.argv[2]), str(sys.argv[3]))
9 changes: 9 additions & 0 deletions stand_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import subprocess
import sys

if __name__ == "__main__":
env_name = str(sys.argv[1])
model_name = str(sys.argv[2])

subprocess.Popen(
'''export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so:/usr/lib/nvidia-410/libGL.so; python load_agent.py '%s' '%s' ''' % (env_name, model_name), shell=True)
26 changes: 26 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2

# multiprocess environment
n_cpu = 20
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo2_cartpole")

del model # remove to demonstrate saving and loading

model = PPO2.load("ppo2_cartpole")

# Enjoy trained agent
obs = env.reset()
for _ in range(1000):
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
# env.render()

env.close()

0 comments on commit 17e8771

Please sign in to comment.