train_pool.py
#!/usr/bin/env python
from os import path
from multiprocessing import Pool
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
from plugins.snake_env import SnakeEnv
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
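# train() builds a 4x4 SnakeEnv and trains a PPO2 model for 1M timesteps,
# resuming from an existing checkpoint under models/<name> when one is found.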
def train(model_payload):
    env = SnakeEnv(grid_size=4)
    model_name = model_payload['name']
    model_path = path.join('models', model_name)
    learning_rate = model_payload['learning_rate']
    if path.isfile(model_path):
        # Resume from the saved model; pass the path variable, not the string 'model_path'.
        env = DummyVecEnv([lambda: env])
        model = PPO2.load(model_path, env=env, learning_rate=learning_rate,
                          tensorboard_log=f'./logs/{model_name}')
    else:
        model = PPO2(MlpPolicy, env=env, learning_rate=learning_rate,
                     tensorboard_log=f'./logs/{model_name}')
    model.learn(total_timesteps=1000000)
    model.save(model_path)
    env.close()
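# Hyperparameter sweep: train one PPO2 model per learning rate
# (5e-06, 5e-05, 5e-04), in three parallel worker processes.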
if __name__ == '__main__':
    learning_rates = [10 ** x * 0.0005 for x in range(-2, 1)]
    model_data = [{'name': f'ppo2_model2_{i}', 'learning_rate': lr}
                  for i, lr in enumerate(learning_rates)]
    # The __main__ guard keeps worker processes from re-running this block
    # on platforms that spawn rather than fork.
    with Pool(3) as p:
        p.map(train, model_data)