-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathppo.py
208 lines (185 loc) · 7.67 KB
/
ppo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# This script will run the PPO experiments listed in the paper
import popgym # noqa: F401
import popgym.envs
from popgym import wrappers
# Feel free to set the environment variables to whatever envs/models
# you would like to test
import os
from typing import Any, List
import ray
import torch
from ray.tune.registry import register_env # noqa: F401
from ray.rllib.algorithms.callbacks import DefaultCallbacks
import popgym # noqa: F401
from popgym import wrappers
from popgym.baselines.ray_models.ray_diffnc import DiffNC # noqa: F401
from popgym.baselines.ray_models.ray_elman import Elman
from popgym.baselines.ray_models.ray_frameconv import Frameconv
from popgym.baselines.ray_models.ray_framestack import Framestack
from popgym.baselines.ray_models.ray_fwp import FastWeightProgrammer
from popgym.baselines.ray_models.ray_gru import GRU
from popgym.baselines.ray_models.ray_indrnn import IndRNN
from popgym.baselines.ray_models.ray_linear_attention import LinearAttention
from popgym.baselines.ray_models.ray_lmu import LMU
from popgym.baselines.ray_models.ray_lstm import LSTM
from popgym.baselines.ray_models.ray_mlp import MLP, BasicMLP
from popgym.baselines.ray_models.ray_s4d import S4D
from popgym.core.env import POPGymEnv
from models.ray_ffm import RayFFM, RayFFMNoOscillate, RayFFMNoLearnOscillate, RayFFMNoDecay, RayFFMNoLearnDecay, RayFFMNoInGate, RayFFMNoOutGate, LoggingCallback
def main():
    """Launch the PPO experiments with Ray Tune.

    All settings are read from environment variables (default in parentheses):

    - ``POPGYM_EXPERIMENT`` ("ALL"): comma-separated attribute names of
      ``popgym.envs``; each resolves either to a single env class or to a
      dict keyed by env class.
    - ``POPGYM_MODELS`` ("ALL"): comma-separated model class names to keep.
    - ``POPGYM_NUM_SPLITS`` (1) / ``POPGYM_SPLIT_ID`` (0): run only every
      ``num_splits``-th selected env, starting at index ``split_id``.
    - ``POPGYM_PROJECT`` ("FFM_debug"): wandb project; an empty value
      disables wandb logging.
    - ``POPGYM_GPU`` (0.25): value passed to RLlib as ``num_gpus``.
    - ``POPGYM_STEPS`` (15e6): stop after this many total timesteps; both
      plain integers and scientific notation ("15e6") are accepted.
    - ``POPGYM_STORAGE`` ("/tmp/ray_results"): Ray Tune ``local_dir``.
    - ``POPGYM_SAMPLES`` (1): Ray Tune ``num_samples``.
    - ``POPGYM_BPTT_CUTOFF`` (1024): maximum episode length and
      backprop-through-time truncation length.
    - ``POPGYM_WORKERS`` (4), ``POPGYM_MINIBATCH`` (8),
      ``POPGYM_ENVS_PER_WORKER`` (16): rollout and SGD batch sizing.
    """
    env_types = os.environ.get("POPGYM_EXPERIMENT", "ALL")
    desired_models = os.environ.get("POPGYM_MODELS", "ALL")
    num_splits = int(os.environ.get("POPGYM_NUM_SPLITS", 1))
    split_id = int(os.environ.get("POPGYM_SPLIT_ID", 0))
    project_id = os.environ.get("POPGYM_PROJECT", "FFM_debug")
    gpu_per_worker = float(os.environ.get("POPGYM_GPU", 0.25))
    # Go through float() first so that "15e6"-style values (the notation the
    # default itself uses) are accepted; int("15e6") would raise ValueError.
    max_steps = int(float(os.environ.get("POPGYM_STEPS", 15e6)))
    storage_path = os.environ.get("POPGYM_STORAGE", "/tmp/ray_results")
    num_samples = int(os.environ.get("POPGYM_SAMPLES", 1))

    # Maximum episode length and backprop thru time truncation length
    bptt_cutoff = int(os.environ.get("POPGYM_BPTT_CUTOFF", 1024))
    num_workers = int(os.environ.get("POPGYM_WORKERS", 4))
    num_minibatch = int(os.environ.get("POPGYM_MINIBATCH", 8))
    num_envs_per_worker = int(os.environ.get("POPGYM_ENVS_PER_WORKER", 16))

    # Hidden size of linear layers
    h = 128
    # Hidden size of memory
    h_memory = 256
    # max(..., 1) keeps the batch size non-zero when num_workers == 0
    train_batch_size = bptt_cutoff * max(num_workers, 1) * num_envs_per_worker

    def wrap(env: POPGymEnv) -> POPGymEnv:
        """Apply the PreviousAction and then the Antialias wrapper to env."""
        return wrappers.Antialias(wrappers.PreviousAction(env))

    def make_env_creator(env_cls):
        """Return an env creator bound to env_cls.

        Binding the class through a function argument (instead of closing
        over the loop variable) guarantees that each registered creator
        builds its own env class rather than whichever class the loop
        visited last.
        """

        def create(_env_config):
            return wrap(env_cls())

        return create

    # Register all envs with ray
    envs = popgym.envs.ALL
    for cls, info in envs.items():
        register_env(info["id"], make_env_creator(cls))

    # Of the registered envs, pick out the ones we actually want to run
    env_names: List[str] = []
    for group in env_types.split(","):
        # getattr returns either a dict of {class: info} or a single class
        res = getattr(popgym.envs, group.strip())
        desired_envs = list(res.keys()) if isinstance(res, dict) else [res]
        env_names.extend(envs[d]["id"] for d in desired_envs)
    # Keep only this job's share of the envs
    env_names = env_names[split_id::num_splits]

    # Setup the models we want to train.
    # Only the FFM variants are run by default; the baseline lists below are
    # kept so they can be appended to `models` when a comparison is needed.
    attn_models = [
        LinearAttention,
        FastWeightProgrammer,
    ]
    rnn_models = [LSTM, GRU, Elman, LMU, IndRNN, DiffNC]
    conv_models = [S4D]
    basic_models = [
        BasicMLP,
        MLP,
        Framestack,
        Frameconv,
    ]
    ffm_models = [
        RayFFM,
        RayFFMNoOscillate,
        RayFFMNoLearnOscillate,
        RayFFMNoDecay,
        RayFFMNoLearnDecay,
        RayFFMNoInGate,
        RayFFMNoOutGate,
    ]
    models = ffm_models  # + rnn_models + attn_models + conv_models + basic_models

    # Filter models by env variable
    if desired_models != "ALL":
        wanted = {name.strip() for name in desired_models.split(",")}
        models = [m for m in models if m.__name__ in wanted]

    def trial_name(trial):
        """Build a trial name from env, model and optional embedding settings."""
        model_cfg = trial.config["model"]
        custom_cfg = model_cfg["custom_model_config"]
        env = trial.config["env"].replace("popgym-", "").replace("-v0", "")
        model = model_cfg["custom_model"].__name__
        emb = custom_cfg.get("embedding", None)
        emb_size = custom_cfg.get("embedding_size", None)
        parts = [env, model, emb, emb_size]
        return "-".join(str(p) for p in parts if p is not None)

    config = {
        # Environments or env names
        "env": ray.tune.grid_search(env_names),
        # Should always be torch
        "framework": "torch",
        # Number of rollout workers
        "num_workers": num_workers,
        # Number of envs per rollout worker
        "num_envs_per_worker": num_envs_per_worker,
        # Num gpus used for the train worker
        "num_gpus": gpu_per_worker,
        # Loss coeff for the ppo value function
        "vf_loss_coeff": 1.0,
        # Num transitions in each training epoch
        "train_batch_size": train_batch_size,
        # Chunk size of transitions sent from rollout workers to trainer
        "rollout_fragment_length": bptt_cutoff,
        # Size of minibatches within epoch
        "sgd_minibatch_size": num_minibatch * bptt_cutoff,
        # decay gamma
        "gamma": 0.99,
        # Required due to RLlib PPO bugs
        "horizon": bptt_cutoff,
        # RLlib bug with truncate_episodes:
        # each batch the temporal dim shrinks by one
        # for now, just use complete_episodes
        "batch_mode": "complete_episodes",
        "min_sample_timesteps_per_iteration": train_batch_size,
        "callbacks": LoggingCallback,
        # Describe your RL model here
        "model": {
            # Truncate sequences into no more than this many timesteps
            "max_seq_len": bptt_cutoff,
            # Custom model class
            "custom_model": ray.tune.grid_search(models),
            # Config passed to custom model constructor
            # see base_model.py to see how these are used
            "custom_model_config": {
                "preprocessor_input_size": h,
                "preprocessor": torch.nn.Sequential(
                    torch.nn.Linear(h, h),
                    torch.nn.LeakyReLU(inplace=True),
                ),
                "preprocessor_output_size": h,
                "hidden_size": h_memory,
                "postprocessor": torch.nn.Identity(),
                "actor": torch.nn.Sequential(
                    torch.nn.Linear(h_memory, h),
                    torch.nn.LeakyReLU(inplace=True),
                    torch.nn.Linear(h, h),
                    torch.nn.LeakyReLU(inplace=True),
                ),
                "critic": torch.nn.Sequential(
                    torch.nn.Linear(h_memory, h),
                    torch.nn.LeakyReLU(inplace=True),
                    torch.nn.Linear(h, h),
                    torch.nn.LeakyReLU(inplace=True),
                ),
                "postprocessor_output_size": h,
            },
        },
    }

    # When to stop training
    stop = {"timesteps_total": max_steps}

    # Write your own wandb entity here
    if project_id:
        # Imported lazily so the wandb integration is only needed when
        # logging is actually enabled.
        from ray.air.callbacks.wandb import WandbLoggerCallback

        logging_callbacks = [
            WandbLoggerCallback(
                project=project_id, entity="prorok-lab", log_config=True
            )
        ]
    else:
        logging_callbacks = []

    ray.init()
    ray.tune.run(
        "PPO",
        config=config,
        stop=stop,
        callbacks=logging_callbacks,
        trial_name_creator=trial_name,
        verbose=1,
        num_samples=num_samples,
        local_dir=storage_path,
    )
# Script entry point: start the experiments only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()