-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmain.py
57 lines (44 loc) · 1.76 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from absl import app
from absl import flags
import sys
import torch
from utils import arglist
from runs.minigame import MiniGame
from utils.preprocess import Preprocess
# Global torch configuration: float32 CPU tensors by default, and a fixed
# seed (taken from the project's arglist) for reproducible runs.
torch.set_default_tensor_type('torch.FloatTensor')
torch.manual_seed(arglist.SEED)

FLAGS = flags.FLAGS
# Register flags *before* parsing argv. Parsing first (as this script used
# to) makes `--render` on the command line fail as an unrecognized flag,
# because the flag does not exist yet at the time of the early parse.
flags.DEFINE_bool("render", False, "Whether to render with pygame.")
# Early parse so flag values are available at import time; app.run() parses
# again before calling main().
FLAGS(sys.argv)

# Mini-game map names; main() trains one run per map, in this order.
env_names = ["DefeatZerglingsAndBanelings", "DefeatRoaches",
             "CollectMineralShards", "MoveToBeacon", "FindAndDefeatZerglings",
             "BuildMarines", "CollectMineralsAndGas"]

# RL algorithm selector used by main(): 'ddpg' or 'ppo'; any other value
# makes main() raise NotImplementedError.
rl_algo = 'ddpg'
def _build_learner():
    """Create a new learner for the algorithm named by module-level ``rl_algo``.

    Agent, network and memory classes are imported lazily, so only the
    modules of the selected algorithm are loaded. Objects are constructed
    in the order actor net, critic net, memory, agent.

    Raises:
        NotImplementedError: if ``rl_algo`` is neither 'ddpg' nor 'ppo'.
    """
    if rl_algo == 'ddpg':
        from agent.ddpg import DDPGAgent
        from networks.acnetwork_q_seperated import ActorNet, CriticNet
        from utils.memory import SequentialMemory

        # Arguments are evaluated left to right: actor, critic, then memory.
        return DDPGAgent(
            ActorNet(),
            CriticNet(),
            SequentialMemory(limit=arglist.DDPG.memory_limit),
        )

    if rl_algo == 'ppo':
        from agent.ppo import PPOAgent
        from networks.acnetwork_v_seperated import ActorNet, CriticNet
        from utils.memory import EpisodeMemory

        return PPOAgent(
            ActorNet(),
            CriticNet(),
            EpisodeMemory(limit=arglist.PPO.memory_limit,
                          action_shape=arglist.action_shape,
                          observation_shape=arglist.observation_shape),
        )

    raise NotImplementedError()


def main(_):
    """absl entry point: run training on every map listed in ``env_names``.

    For each map a brand-new learner and preprocessor are built, wrapped in
    a ``MiniGame`` with 10000 episodes, and run. The argv argument passed by
    ``app.run`` is ignored.

    Returns:
        0 after all maps have been processed.

    Raises:
        NotImplementedError: if ``rl_algo`` names an unsupported algorithm.
    """
    for map_name in env_names:
        game = MiniGame(map_name, _build_learner(), Preprocess(),
                        nb_episodes=10000)
        # NOTE(review): run_ddpg() is used even when rl_algo == 'ppo';
        # confirm whether MiniGame offers a PPO-specific runner instead.
        game.run_ddpg()
    return 0
if __name__ == "__main__":
    # Hand control to absl, which parses command-line flags and then calls main(argv).
    app.run(main)