-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
38 lines (31 loc) · 1.29 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import gym
import json
from src.utils import *
from src.agents import *
class Tester:
    """Load a pretrained DDPG agent and evaluate it on the configured gym environment.

    The JSON config must provide the keys read by this class:
    ``env_name``, ``discount`` and ``tau`` (read at construction) and
    ``seed`` (read by :meth:`test`).
    """

    def __init__(self, config_file, model_file='./pretrained_models/bipedal_walker_v3/ddpg_bipedalwalker_v3_0',
                 device='cpu'):
        """Create the environment and agent, then load the checkpoint.

        Args:
            config_file: Path to the JSON configuration file.
            model_file: Checkpoint path passed to ``DDPGAgent.load_checkpoint``.
            device: Any value accepted by ``torch.device``. Defaults to
                ``'cpu'``, matching the previous hard-coded behaviour.
        """
        self.config = Tester.parse_config(config_file)

        # This env instance is used here only to read space sizes; test()
        # hands evaluate_policy the env *name*, not this object.
        self.env = gym.make(self.config['env_name'])
        self.state_dimension = self.env.observation_space.shape[0]
        self.action_dimension = self.env.action_space.shape[0]
        # Only high[0] is read: assumes a Box action space with the same bound
        # on every dimension -- TODO confirm when using other environments.
        self.max_action = float(self.env.action_space.high[0])
        self.device = torch.device(device)

        self.agent = DDPGAgent(
            state_dim=self.state_dimension,
            action_dim=self.action_dimension,
            max_action=self.max_action,
            device=self.device,
            discount=self.config['discount'],
            tau=self.config['tau'],
        )
        self.agent.load_checkpoint(model_file)

        # Populated by test(); None until an evaluation has been run.
        self.mean_rewards = None

    @staticmethod
    def parse_config(json_file):
        """Read and return the parsed contents of the JSON file at ``json_file``."""
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(json_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def test(self, eval_episodes=None, render=True):
        """Evaluate the loaded agent, print the result and return it.

        Args:
            eval_episodes: Number of evaluation episodes. When ``None`` the
                argument is not passed, so ``evaluate_policy``'s own default
                applies (assumes it defines one -- TODO confirm).
            render: Forwarded to ``evaluate_policy``.

        Returns:
            Whatever ``evaluate_policy`` returns (also stored in
            ``self.mean_rewards``).
        """
        eval_kwargs = {'render': render}
        if eval_episodes is not None:
            eval_kwargs['eval_episodes'] = eval_episodes

        self.mean_rewards = evaluate_policy(
            self.agent, self.config['env_name'], self.config['seed'], **eval_kwargs
        )
        print(self.mean_rewards)
        return self.mean_rewards
# tester = Tester('./configs/BipedalWalker-v3.json')
# tester.test()