pets.py
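"""Train and evaluate a PETS agent with mbrl-lib: learn an ensemble dynamics model from
replay data and plan actions with a trajectory-optimization agent over that model.
Configured via Hydra (config 'pets' in the local 'cfg' directory)."""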
import hydra
import logging
import gym
import torch
import numpy as np
import barl.envs # NOQA
from tqdm import trange
from barl.util.misc_util import Dumper
from pets_reward_functions import reward_functions
import mbrl.models as models
import mbrl.planning as planning
import mbrl.util.common as common_util
from mbrl.env.termination_fns import no_termination

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
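

# The Hydra config ('pets' under the local 'cfg/' directory) is assumed to define at
# least: seed, name, num_trials, num_eval_trials, env.{name, max_path_length},
# dynamics_model.model.{device, ensemble_size}, agent.{planning_horizon, optimizer_cfg},
# and overrides.{model_batch_size, validation_ratio} -- the keys read below.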
@hydra.main(config_path='cfg', config_name='pets')
def main(config):
    # set seeds
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    dumper = Dumper(config.name)
    config.dynamics_model.model.device = device
    config.agent.optimizer_cfg.device = device
    env = gym.make(config.env.name)
    env.seed(config.seed)
    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    # Create a 1-D dynamics model for this environment
    dynamics_model = common_util.create_one_dim_tr_model(config, obs_shape, act_shape)
    # TODO
    reward_fn = reward_functions[config.env.name]
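    # pets_reward_functions is assumed to map env names to batched reward functions
    # with the mbrl-lib signature reward_fn(action, next_obs) -> reward tensor, which
    # is what models.ModelEnv expects.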
    # Create a gym-like environment to encapsulate the model
    model_env = models.ModelEnv(env, dynamics_model, no_termination, reward_fn)

    replay_buffer = common_util.create_replay_buffer(config, obs_shape, act_shape)
    # Collect initial data with a random policy before any model training
    common_util.rollout_agent_trajectories(
        env,
        config.env.max_path_length,  # initial exploration steps
        planning.RandomAgent(env),
        {},  # keyword arguments to pass to agent.act()
        replay_buffer=replay_buffer,
        trial_length=config.env.max_path_length,
    )
    print("# samples stored", replay_buffer.num_stored)
    agent_cfg = config.agent
    action_lb = env.action_space.low
    action_ub = env.action_space.high
    agent_cfg.optimizer_cfg.lower_bound = np.tile(action_lb, (agent_cfg.planning_horizon, 1)).tolist()
    agent_cfg.optimizer_cfg.upper_bound = np.tile(action_ub, (agent_cfg.planning_horizon, 1)).tolist()
    agent = planning.create_trajectory_optim_agent_for_model(
        model_env,
        agent_cfg,
        num_particles=20
    )
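    # With num_particles=20, the agent scores each candidate action sequence by
    # propagating 20 sampled particles through the learned model (PETS-style
    # trajectory sampling).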

    def train_callback(_model, _total_calls, _epoch, tr_loss, val_score, _best_val):
        pass
        # dumper.add('Train Loss', tr_loss)
        # dumper.add('Val Score', val_score.mean().item())

    # Create a trainer for the model
    model_trainer = models.ModelTrainer(dynamics_model, optim_lr=1e-3, weight_decay=5e-5)
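
    # Per trial, the model is retrained once at the start of the first eval episode
    # (etrial == 0) on everything in the buffer, and that episode's transitions are
    # added to the buffer; the remaining eval episodes run the planner on the real env
    # without collecting data or retraining.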
    # Main PETS loop
    for trial in range(config.num_trials):
        eval_returns = []
        pbar = trange(config.num_eval_trials)
        for etrial in pbar:
            obs = env.reset()
            agent.reset()
            done = False
            total_reward = 0.0
            steps_trial = 0
            while not done:
                # --------------- Model Training -----------------
                if etrial == 0:
                    if steps_trial == 0:
                        dynamics_model.update_normalizer(replay_buffer.get_all())  # update normalizer stats
                        dataset_train, dataset_val = common_util.get_basic_buffer_iterators(
                            replay_buffer,
                            batch_size=config.overrides.model_batch_size,
                            val_ratio=config.overrides.validation_ratio,
                            ensemble_size=config.dynamics_model.model.ensemble_size,
                            shuffle_each_epoch=True,
                            bootstrap_permutes=False,  # build bootstrap dataset using sampling with replacement
                        )
                        model_trainer.train(
                            dataset_train,
                            dataset_val=dataset_val,
                            num_epochs=50,
                            patience=50,
                            callback=train_callback)
                    # --- Step the env with the agent and add the transition to the buffer ---
                    next_obs, reward, done, _ = common_util.step_env_and_add_to_buffer(
                        env, obs, agent, {}, replay_buffer)
                else:
                    # Pure evaluation: act with the planner, without adding to the buffer
                    action = agent.act(obs)
                    next_obs, reward, done, _ = env.step(action)
                obs = next_obs
                total_reward += reward
                steps_trial += 1
                if steps_trial == config.env.max_path_length:
                    break
            eval_returns.append(total_reward)
            stats = {"Mean Return": np.mean(eval_returns), "Std Return": np.std(eval_returns)}
            pbar.set_postfix(stats)
        logging.info(f"Trial {trial}, returns_mean={np.mean(eval_returns):.2f}, returns_std={np.std(eval_returns):.2f}, ndata={len(replay_buffer)}")  # NOQA
        dumper.add('Eval Returns', eval_returns)
        dumper.add('Eval ndata', len(replay_buffer))
        dumper.save()


if __name__ == '__main__':
    main()