-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathagent.py
42 lines (34 loc) · 1.58 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from utils import preprocess_obs
class Agent:
    """
    Agent class to get action with action model
    and maintain rnn_hidden for input of action model.

    Wraps an encoder, an RSSM and an action model so that a single raw
    observation is turned into an action, while the RSSM's recurrent
    hidden state (``self.rnn_hidden``) is carried over between calls.
    """

    def __init__(self, encoder, rssm, action_model):
        """
        Args:
            encoder: module applied to the batched, channel-first
                observation tensor to produce an embedding.
            rssm: model providing ``posterior(rnn_hidden, embedded_obs)``,
                ``prior(state, action, rnn_hidden)`` and ``rnn_hidden_dim``.
            action_model: module called as
                ``action_model(state, rnn_hidden, training=...)``. The device
                of its first parameter is used for every tensor this agent
                creates.
        """
        self.encoder = encoder
        self.rssm = rssm
        self.action_model = action_model
        self.device = next(self.action_model.parameters()).device
        # Reuse reset() so the initial hidden state and the reset hidden
        # state are always built the same way.
        self.reset()

    def __call__(self, obs, training=True):
        """
        Compute an action for one observation and advance ``rnn_hidden``.

        Args:
            obs: a single raw observation, passed through ``preprocess_obs``
                and then treated as a channel-last image (the permute below
                moves axis 2 to the front).
            training: forwarded to ``action_model``. If training=False, the
                returned action is the mean instead of a sample from
                action_model's distribution.

        Returns:
            The action as a NumPy array on CPU with the leading batch
            dimension removed (presumably shape ``(action_dim,)`` when
            ``action_model`` returns ``(1, action_dim)`` -- confirm against
            the action model).
        """
        # Preprocess, then go channel-first (H, W, C) -> (C, H, W) and add
        # a batch dimension of size 1.
        obs = preprocess_obs(obs)
        obs = torch.as_tensor(obs, device=self.device)
        obs = obs.permute(2, 0, 1).unsqueeze(0)

        with torch.no_grad():
            # Embed the observation, sample a state from the posterior, and
            # get an action from that state and the current rnn_hidden.
            embedded_obs = self.encoder(obs)
            state_posterior = self.rssm.posterior(self.rnn_hidden, embedded_obs)
            state = state_posterior.sample()
            action = self.action_model(state, self.rnn_hidden, training=training)

            # Advance rnn_hidden for the next step; the prior's first output
            # is not needed here.
            _, self.rnn_hidden = self.rssm.prior(state, action, self.rnn_hidden)

        return self._action_to_numpy(action)

    def reset(self):
        """Reset the recurrent hidden state to zeros of shape (1, rnn_hidden_dim)."""
        self.rnn_hidden = torch.zeros(1, self.rssm.rnn_hidden_dim, device=self.device)

    @staticmethod
    def _action_to_numpy(action):
        """Drop the leading batch dimension and return the action as a CPU NumPy array."""
        # squeeze(0) instead of squeeze(): a bare squeeze() would also drop
        # the action dimension when action_dim == 1, yielding a 0-d array
        # instead of shape (1,).
        return action.squeeze(0).cpu().numpy()