run.py
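
"""Run rollout experiments with an RS, CEM, PID, or RL policy in the profile environment.

Illustrative invocations (assumed, not from the original source; argument values and
model paths depend on your local setup):

    python run.py my_experiment --policy CEM --num_trials 10 --horizon 5
    python run.py pid_baseline --policy PID -P 0.2 -I 0.5 -D 0.0
    python run.py rl_eval --policy RL --rl_model_path /path/to/policy
"""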
import argparse
import pickle

from tqdm import trange, tqdm

from profile_env import (ProfileEnv, TearingProfileEnv, SCENARIO_PATH,
                         TEARING_PATH, NN_TEARING_PATH)
from policy import PIDPolicy, PINJRLPolicy
from mpc import CEM, RS
from utils import make_output_dir


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("name", help="The name of the experiment and output directory")
    parser.add_argument("--policy", default="RS", choices=["RS", "CEM", "PID", "RL"])
    parser.add_argument("--num_trials", type=int, default=100, help="The number of rollouts to conduct")
    parser.add_argument("--num_samples", type=int, default=1000, help="The number of samples in an RS run")
    parser.add_argument("--popsize", type=int, default=100, help="The population size for CEM")
    parser.add_argument("--num_elites", type=int, default=10, help="The number of elites in a CEM run")
    parser.add_argument("--num_iters", type=int, default=10, help="The number of iterations of CEM to run")
    parser.add_argument("--discount_rate", type=float, default=1., help="The discount rate for optimization purposes")
    parser.add_argument("--horizon", type=int, default=5, help="The horizon for optimization")
    parser.add_argument("--alpha_cem", type=float, default=0.25, help="The alpha for CEM")
    parser.add_argument("--epsilon_cem", type=float, default=0.01, help="The epsilon for CEM")
    parser.add_argument("--env", default="full", choices=["full", "betan"])
    parser.add_argument("-ow", dest="overwrite", action="store_true")
    parser.add_argument("-P", type=float, default=0.2, help="Proportional gain")
    parser.add_argument("-I", type=float, default=0.5, help="Integral gain")
    parser.add_argument("-D", type=float, default=0.0, help="Derivative gain")
    parser.add_argument('--rl_model_path', help='Path to policy.')
    parser.add_argument('--use_nn_tearing', action='store_true')
    parser.add_argument('--cuda_device', default='')
    parser.add_argument('--pudb', action='store_true')
    return parser.parse_args()


def run_trial(policy, env):
    """Roll out a single episode of `policy` in `env` and return the trajectory."""
    state = env.reset()
    policy.reset()
    states = [state]
    actions = []
    rewards = []
    infos = []
    done = False
    while not done:
        action = policy(state)
        # Stop early if the policy fails to produce an action.
        if action is None:
            break
        state, reward, done, info = env.step(action)
        states.append(state)
        actions.append(action)
        tqdm.write(f"Action Reward: {reward}")
        rewards.append(reward)
        infos.append(info)
    tqdm.write(f"Total Reward: {sum(rewards)}")
    return states, actions, rewards, infos


def create_env(args):
    """Build the simulation environment selected by --env."""
    if args.env == "betan":
        env = ProfileEnv(scenario_path=SCENARIO_PATH)
    elif args.env == "full":
        # Coefficients passed through to TearingProfileEnv's reward computation.
        rew_coefs = (9, 10)
        tpath = NN_TEARING_PATH if args.use_nn_tearing else TEARING_PATH
        env = TearingProfileEnv(scenario_path=SCENARIO_PATH,
                                tearing_path=tpath,
                                rew_coefs=rew_coefs,
                                nn_tearing=args.use_nn_tearing)
    return env


def main(args):
    """Build the environment and policy from the parsed arguments, then run rollouts."""
    if args.pudb:
        import pudb; pudb.set_trace()
    output_dir = make_output_dir(args.name, args.overwrite, args)
    env = create_env(args)
    if args.policy == "RS":
        policy = RS(env=env,
                    horizon=args.horizon,
                    shots=args.num_samples)
    elif args.policy == "CEM":
        policy = CEM(env,
                     horizon=args.horizon,
                     popsize=args.popsize,
                     n_elites=args.num_elites,
                     n_iters=args.num_iters,
                     alpha=args.alpha_cem,
                     epsilon=args.epsilon_cem)
    elif args.policy == "PID":
        policy = PIDPolicy(env=env,
                           P=args.P,
                           I=args.I,
                           D=args.D,
                           tau=env.tau)
    elif args.policy == 'RL':
        policy = PINJRLPolicy(
            model_path=args.rl_model_path,
            env=env,
            cuda_device=args.cuda_device,
        )
    else:
        raise ValueError(f'Unknown policy: {args.policy}')
    episodes = []
    episode_path = output_dir / 'episodes.pk'
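    # Re-save the full episode list after every trial so partial results survive an interrupted run.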
    for i in trange(args.num_trials):
        states, actions, rewards, infos = run_trial(policy, env)
        episodes.append((states, actions, rewards, infos))
        with episode_path.open('wb') as f:
            pickle.dump(episodes, f)


if __name__ == '__main__':
    args = parse_arguments()
    main(args)