-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathagent.py
62 lines (53 loc) · 1.87 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
from tdw_transport_challenge.challenge import Challenge
from tdw_transport_challenge.simple_agent import TestAgent
from tdw_transport_challenge.h_agent import H_agent
import logging
import os
def init_logs():
    """Configure and return the ``simple_example`` logger.

    DEBUG-and-above records are written both to ``results/output.log``
    (relative to the current working directory, created if missing) and to a
    ``StreamHandler`` (stderr by default), using a timestamped format.

    Calling this more than once is safe: if the logger already has handlers,
    it is returned as-is. Otherwise every repeated call would attach another
    pair of handlers and each record would be emitted multiple times.

    Returns:
        logging.Logger: the configured ``simple_example`` logger.
    """
    logger = logging.getLogger('simple_example')
    logger.setLevel(logging.DEBUG)
    if logger.handlers:
        # Already configured (e.g. by an earlier call); don't duplicate output.
        return logger

    os.makedirs("results", exist_ok=True)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    fh = logging.FileHandler(os.path.join("results", "output.log"), encoding="utf-8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)

    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger
def get_agent(agent_class, logger, ckpt_path=""):
    """Instantiate the agent selected by ``agent_class``.

    Args:
        agent_class: ``"Test"`` for ``TestAgent`` or ``"h_agent"`` for ``H_agent``.
        logger: logger passed to ``H_agent``; ignored for ``"Test"``.
        ckpt_path: checkpoint directory. Not used by the currently supported
            agents. It was consumed by a SAC agent branch that is disabled
            (along with Random/ForwardOnly agents), and is kept for interface
            compatibility.

    Returns:
        The constructed agent instance.

    Raises:
        ValueError: if ``agent_class`` is not a supported name. Previously
            this silently returned ``None``, which only failed later.
    """
    if agent_class == "Test":
        return TestAgent()
    if agent_class == "h_agent":
        return H_agent(logger=logger)
    raise ValueError(
        f"Unknown agent class {agent_class!r}; expected one of 'Test', 'h_agent'"
    )
def main():
    """Parse CLI arguments, build the requested agent and submit it.

    Command-line options:
        --agent-class  agent to run, ``"Test"`` or ``"h_agent"`` (default ``"Test"``)
        --ckpt-path    checkpoint path forwarded to ``get_agent`` (default ``""``)
        --port         port passed to ``Challenge`` (default 1071)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--agent-class", type=str, default="Test", choices=["Test", "h_agent"])
    parser.add_argument("--ckpt-path", default="", type=str)
    parser.add_argument("--port", default=1071, type=int)
    args = parser.parse_args()

    # exist_ok=True replaces the separate exists()/mkdir() check, which could
    # race with another process creating the directory in between.
    # NOTE(review): this creates an absolute /results, while init_logs() writes
    # to a relative ./results - confirm which location the harness expects.
    os.makedirs('/results', exist_ok=True)

    logger = init_logs()
    # Instantiate your agent here
    agent = get_agent(
        agent_class=args.agent_class,
        logger=logger,
        ckpt_path=args.ckpt_path,
    )
    challenge = Challenge(logger, args.port)
    try:
        challenge.submit(agent)
    finally:
        # Always call close(), even if submit() raises.
        challenge.close()


if __name__ == "__main__":
    main()