-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun_cartpole.py
89 lines (78 loc) · 2.81 KB
/
run_cartpole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python
# coding: utf-8
"""Run PETS on classical continuous cartpole."""
import time
import acme
import jax
import numpy as np
from absl import app
from absl import flags
from acme import specs
from acme import wrappers
from environments.cartpole_continuous import CartPoleEnv
from gym import wrappers as gym_wrappers
from ml_collections import config_flags
from magi.agents.pets import builder
from magi.utils import loggers
FLAGS = flags.FLAGS

# Experiment hyper-parameters come from an ml_collections config file (--config).
config_flags.DEFINE_config_file("config")
flags.mark_flag_as_required("config")

# Run length and reproducibility.
flags.DEFINE_integer("num_episodes", default=100, help="Number of episodes.")
flags.DEFINE_integer("seed", default=0, help="Random seed.")

# Optional Weights & Biases logging.
flags.DEFINE_bool("wandb", default=False, help="whether to log result to wandb")
flags.DEFINE_string("wandb_project", default="", help="wandb project name")
flags.DEFINE_string("wandb_entity", default="", help="wandb project entity")
def make_environment(seed, task_horizon):
    """Build the continuous CartPole task wrapped for use with Acme.

    A fresh ``CartPoleEnv`` is capped at ``task_horizon`` steps with Gym's
    ``TimeLimit`` and seeded. It is then adapted with Acme's ``GymWrapper``,
    followed by ``SinglePrecisionWrapper``.

    Args:
      seed: Value passed to ``seed()`` on the time-limited environment.
      task_horizon: Maximum number of steps per episode (``TimeLimit`` cap).

    Returns:
      The Acme-wrapped environment (not a raw Gym environment).
    """
    time_limited = gym_wrappers.TimeLimit(CartPoleEnv(), task_horizon)
    # Seed before the Acme wrappers are applied, matching the original order.
    time_limited.seed(seed)
    return wrappers.SinglePrecisionWrapper(wrappers.GymWrapper(time_limited))
def main(unused_argv):
    """Run a PETS agent on continuous CartPole for ``--num_episodes`` episodes.

    Reads experiment settings from the ``--config`` file and seeds NumPy from
    ``--seed``. Derives separate seeds for the environment and the agent, then
    builds both. Finally runs them in an Acme ``EnvironmentLoop`` whose logger
    can optionally report to Weights & Biases (``--wandb``).

    Args:
      unused_argv: Leftover positional command-line arguments from absl; ignored.
    """
    del unused_argv
    config = FLAGS.config

    # Seed NumPy's global RNG. A separate Generator (with an offset seed)
    # derives the child seeds below; draw order is environment, then agent.
    np.random.seed(FLAGS.seed)
    rng = np.random.default_rng(FLAGS.seed + 1)
    # ``Generator.integers`` returns a NumPy int64 scalar. Convert both seeds to
    # plain Python ints so the environment and the agent receive the same kind
    # of value. The upper bound is exclusive, so seeds lie in [0, 2**32).
    env_seed = int(rng.integers(0, 2**32))
    agent_seed = int(rng.integers(0, 2**32))

    environment = make_environment(env_seed, config.task_horizon)
    environment_spec = specs.make_environment_spec(environment)

    agent = builder.make_agent(
        environment_spec,
        config.reward_fn,
        config.termination_fn,
        config.obs_preproc,
        config.obs_postproc,
        config.targ_proc,
        hidden_sizes=config.hidden_sizes,
        population_size=config.population_size,
        activation=jax.nn.silu,
        planning_horizon=config.planning_horizon,
        cem_alpha=config.cem_alpha,
        cem_elite_frac=config.cem_elite_frac,
        cem_return_mean_elites=config.cem_return_mean_elites,
        weight_decay=config.weight_decay,
        lr=config.lr,
        min_delta=config.min_delta,
        num_ensembles=config.num_ensembles,
        num_particles=config.num_particles,
        num_epochs=config.num_epochs,
        seed=agent_seed,
        patience=config.patience,
    )

    # Run name embeds the seed and a wall-clock timestamp so repeated runs with
    # the same seed get distinct names.
    logger = loggers.make_logger(
        "environment_loop",
        use_wandb=FLAGS.wandb,
        wandb_kwargs={
            "project": FLAGS.wandb_project,
            "entity": FLAGS.wandb_entity,
            "name": f"pets_cartpole_{FLAGS.seed}_{int(time.time())}",
            "config": FLAGS,
        },
    )

    env_loop = acme.EnvironmentLoop(environment, agent, logger=logger)
    env_loop.run(num_episodes=FLAGS.num_episodes)
if __name__ == "__main__":
    # Register JAX's configuration options as absl flags; this has to happen
    # before app.run() parses the command line.
    jax.config.config_with_absl()
    app.run(main)