
Commit

Change test=False to explore=True in acting
Kaixhin committed Oct 1, 2019
1 parent 613f24b commit b1b5409
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions main.py
@@ -118,12 +118,12 @@
 free_nats = torch.full((1, ), args.free_nats, device=args.device) # Allowed deviation in KL divergence


-def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, test):
+def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, explore=False):
   # Infer belief over current state q(s_t|o≤t,a<t) from the history
   belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0)) # Action and observation need extra time dimension
   belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0) # Remove time dimension from belief/state
   action = planner(belief, posterior_state) # Get action from planner(q(s_t|o≤t,a<t), p)
-  if not test:
+  if explore:
     action = action + args.action_noise * torch.randn_like(action) # Add exploration noise ε ~ p(ε) to the action
   next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu()) # Perform environment step (action repeats handled internally)
   return belief, posterior_state, action, next_observation, reward, done
@@ -142,7 +142,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
 belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
 pbar = tqdm(range(args.max_episode_length // args.action_repeat))
 for t in pbar:
-  belief, posterior_state, action, observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=True)
+  belief, posterior_state, action, observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
   total_reward += reward
   if args.render:
     env.render()
@@ -218,7 +218,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
 belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
 pbar = tqdm(range(args.max_episode_length // args.action_repeat))
 for t in pbar:
-  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False)
+  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), explore=True)
   D.append(observation, action.cpu(), reward, done)
   total_reward += reward
   observation = next_observation
@@ -250,7 +250,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
 belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device)
 pbar = tqdm(range(args.max_episode_length // args.action_repeat))
 for t in pbar:
-  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=True)
+  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
   total_rewards += reward.numpy()
   if not args.symbolic_env: # Collect real vs. predicted frames for video
     video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre
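
Net effect of the change: exploration noise is now opt-in via a keyword argument, so the training data-collection loop passes explore=True, while both test rollouts simply drop the argument and rely on the new explore=False default. Below is a minimal, self-contained sketch of that behaviour; it is not code from the commit, and add_noise_if_exploring and action_noise are hypothetical stand-ins for the changed branch and args.action_noise.

import torch

def add_noise_if_exploring(action, explore=False, action_noise=0.3):
  # Gaussian exploration noise (scaled by action_noise) is added only when explore=True
  if explore:
    action = action + action_noise * torch.randn_like(action)
  return action

action = torch.zeros(1, 2)
noisy_action = add_noise_if_exploring(action, explore=True)  # training data collection
greedy_action = add_noise_if_exploring(action)  # testing/evaluation: default leaves the planner's action unchanged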
