Merge pull request #133 from cpnota/release/0.5.0
Release/0.5.0
cpnota authored Apr 18, 2020
2 parents 76bc1b3 + 4919e87 commit 57536b2
Showing 69 changed files with 1,139 additions and 603 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -269,7 +269,7 @@ ignored-classes=optparse.Values,thread._local,_thread._local
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
ignored-modules=numpy

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -8,7 +8,7 @@ branches:
install:
- pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
- pip install torchvision
- pip install -q -e .
- pip install -q -e .["dev"]
script:
- make lint
- make test
4 changes: 1 addition & 3 deletions Makefile
@@ -1,7 +1,5 @@
install:
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
pip install tensorboard
pip install -e .
pip install -e .[dev]

lint:
pylint all --rcfile=.pylintrc
10 changes: 7 additions & 3 deletions README.md
@@ -68,10 +68,10 @@ pip install autonomous-learning-library[pytorch]

## Running the Presets

If you just want to test out some cool agents, the `scripts` directory contains the basic code for doing so.
If you just want to test out some cool agents, the library includes several scripts for doing so:

```
python scripts/atari.py Breakout a2c
all-atari Breakout a2c
```

You can watch the training progress using:
@@ -84,12 +84,16 @@ and opening your browser to http://localhost:6006.
Once the model is trained to your satisfaction, you can watch the trained model play using:

```
python scripts/watch_atari.py Breakout "runs/_a2c [id]"
all-watch-atari Breakout "runs/_a2c [id]"
```

where `id` is the ID of your particular run. You should be able to find it using tab completion or by looking in the `runs` directory.
The `autonomous-learning-library` also contains presets and scripts for classic control and PyBullet environments.

If you want to test out your own agents, you will need to define your own scripts.
Some examples can be found in the `examples` folder.
See the [docs](https://autonomous-learning-library.readthedocs.io) for information on building your own agents!

## Note

This library was built in the [Autonomous Learning Laboratory](http://all.cs.umass.edu) (ALL) at the [University of Massachusetts, Amherst](https://www.umass.edu).
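
The README change above notes that testing your own agents means writing your own script (the `examples` folder contains the real ones). As a rough, hypothetical sketch, a driver script boils down to plain `gym` plus anything implementing the `act`/`eval` interface defined in `all/agents/_agent.py` below; none of the names here are the library's actual experiment API.

```python
# Hypothetical driver script, not the library's experiment utilities; it only shows
# the shape of the loop that the shipped scripts wrap for you.
import gym


class RandomAgent:
    """Stand-in for any agent implementing the act/eval interface changed in this release."""

    def __init__(self, action_space):
        self.action_space = action_space

    def act(self, state, reward):
        # A real agent would also update its parameters here.
        return self.action_space.sample()

    def eval(self, state, reward):
        # Evaluation mode: select an action without updating parameters.
        return self.action_space.sample()


def run_episode(env, agent, train=True):
    state, reward, done, returns = env.reset(), 0.0, False, 0.0
    while not done:
        action = agent.act(state, reward) if train else agent.eval(state, reward)
        state, reward, done, _ = env.step(action)
        returns += reward
    return returns


if __name__ == "__main__":
    env = gym.make("CartPole-v0")
    agent = RandomAgent(env.action_space)
    print([run_episode(env, agent) for _ in range(5)])
```
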
19 changes: 18 additions & 1 deletion all/agents/_agent.py
@@ -28,7 +28,24 @@ def act(self, state, reward):
Args:
state (all.environment.State): The environment state at the current timestep.
reward (torch.Tensor): The reward from the previous timestep.
info (:obj:, optional): The info object from the environment.
Returns:
torch.Tensor: The action to take at the current timestep.
"""

@abstractmethod
def eval(self, state, reward):
"""
Select an action for the current timestep in evaluation mode.
Unlike act, this method should NOT update the internal parameters of the agent.
Most of the time, this method should return the greedy action according to the current policy.
This method is useful when using evaluation methodologies that distinguish between the performance
of the agent during training and the performance of the resulting policy.
Args:
state (all.environment.State): The environment state at the current timestep.
reward (torch.Tensor): The reward from the previous timestep.
Returns:
torch.Tensor: The action to take at the current timestep.
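
The new abstract `eval` method above splits training-time action selection (`act`, which may update parameters) from evaluation-time action selection (`eval`, which must not and is usually greedy). A toy sketch of an agent written against that contract, a minimal tabular epsilon-greedy Q-learner rather than anything shipped with the library:

```python
# Toy illustration of the act/eval contract; states and actions are small integer indices,
# and terminal-state handling is omitted for brevity.
import torch


class ToyQAgent:
    def __init__(self, n_states, n_actions, lr=0.1, discount=0.99, epsilon=0.1):
        self.q = torch.zeros(n_states, n_actions)
        self.lr, self.discount, self.epsilon = lr, discount, epsilon
        self._state, self._action = None, None

    def act(self, state, reward):
        # Training mode: update parameters from the previous transition, then explore.
        if self._state is not None:
            td_target = reward + self.discount * self.q[state].max()
            td_error = td_target - self.q[self._state, self._action]
            self.q[self._state, self._action] += self.lr * td_error
        if torch.rand(1).item() < self.epsilon:
            action = torch.randint(self.q.shape[1], (1,)).item()
        else:
            action = self.q[state].argmax().item()
        self._state, self._action = state, action
        return action

    def eval(self, state, reward):
        # Evaluation mode: greedy action, no parameter updates.
        return self.q[state].argmax().item()
```
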
5 changes: 4 additions & 1 deletion all/agents/a2c.py
@@ -57,9 +57,12 @@ def act(self, states, rewards):
self._buffer.store(self._states, self._actions, rewards)
self._train(states)
self._states = states
self._actions = self.policy.eval(self.features.eval(states)).sample()
self._actions = self.policy.no_grad(self.features.no_grad(states)).sample()
return self._actions

def eval(self, states, _):
return self.policy.eval(self.features.eval(states))

def _train(self, next_states):
if len(self._buffer) >= self._batch_size:
# load trajectories from buffer
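
The A2C hunk above shows the pattern repeated throughout this release: action selection during training now goes through `no_grad`, while the new `eval` methods still call `eval`. One plausible reading of the split (the library's `Approximation` classes are the authoritative definition, so treat this as an assumption) is that `no_grad` only disables gradient tracking, whereas `eval` also puts the network into evaluation mode. In plain PyTorch the two would look roughly like this:

```python
# Sketch of the presumed distinction; names mirror the diff, not the library's code.
import torch
import torch.nn as nn


class ApproximationSketch:
    def __init__(self, model):
        self.model = model

    def no_grad(self, x):
        # Forward pass without gradient tracking; the model stays in training mode.
        with torch.no_grad():
            return self.model(x)

    def eval(self, x):
        # Forward pass with the model temporarily switched to evaluation mode.
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            out = self.model(x)
        self.model.train(was_training)
        return out


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 2))
    approx = ApproximationSketch(model)
    x = torch.randn(8, 4)
    print(approx.no_grad(x).shape, approx.eval(x).shape)
```
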
10 changes: 6 additions & 4 deletions all/agents/c51.py
@@ -60,21 +60,23 @@ def act(self, state, reward):
self._action = self._choose_action(state)
return self._action

def eval(self, state, _):
return self._best_actions(self.q_dist.eval(state))

def _choose_action(self, state):
if self._should_explore():
return torch.randint(
self.q_dist.n_actions, (len(state),), device=self.q_dist.device
)
return self._best_actions(state)
return self._best_actions(self.q_dist.no_grad(state))

def _should_explore(self):
return (
len(self.replay_buffer) < self.replay_start_size
or np.random.rand() < self.exploration
)

def _best_actions(self, states):
probs = self.q_dist.eval(states)
def _best_actions(self, probs):
q_values = (probs * self.q_dist.atoms).sum(dim=2)
return torch.argmax(q_values, dim=1)

@@ -103,7 +105,7 @@ def _should_train(self):
return self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0

def _compute_target_dist(self, states, rewards):
actions = self._best_actions(states)
actions = self._best_actions(self.q_dist.no_grad(states))
dist = self.q_dist.target(states, actions)
shifted_atoms = (
rewards.view((-1, 1)) + self.discount_factor * self.q_dist.atoms
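
`_best_actions` in the C51 hunk above now receives the predicted distribution directly and reduces it to expected action values by weighting each atom of the support by its probability. A standalone sketch of that reduction, with shapes assumed to be batch × actions × atoms:

```python
# Standalone sketch of C51-style greedy action selection from a categorical distribution.
import torch

batch_size, n_actions, n_atoms = 4, 6, 51
atoms = torch.linspace(-10.0, 10.0, n_atoms)           # value support of the distribution
logits = torch.randn(batch_size, n_actions, n_atoms)
probs = torch.softmax(logits, dim=2)                   # per-action distribution over atoms

q_values = (probs * atoms).sum(dim=2)                  # expected value of each action
best_actions = torch.argmax(q_values, dim=1)
print(q_values.shape, best_actions)
```
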
5 changes: 4 additions & 1 deletion all/agents/ddpg.py
@@ -61,8 +61,11 @@ def act(self, state, reward):
self._action = self._choose_action(state)
return self._action

def eval(self, state, _):
return self.policy.eval(state)

def _choose_action(self, state):
action = self.policy.eval(state)
action = self.policy.no_grad(state)
action = action + self._noise.sample()
action = torch.min(action, self._high)
action = torch.max(action, self._low)
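
The DDPG change keeps the same exploration recipe while switching to `no_grad`: take the deterministic policy's action, add noise, and clip to the action bounds with paired `min`/`max`. A self-contained sketch, with Gaussian noise standing in for the library's noise process:

```python
# Sketch of DDPG-style exploration: deterministic action plus noise, clipped to bounds.
import torch

low = torch.tensor([-1.0, -1.0])
high = torch.tensor([1.0, 1.0])

def choose_action(policy_action, noise_std=0.1):
    noise = torch.randn_like(policy_action) * noise_std   # Gaussian stand-in for the noise process
    action = policy_action + noise
    return torch.max(torch.min(action, high), low)         # per-dimension clipping, as in the diff

print(choose_action(torch.tensor([0.95, -0.2])))
```
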
7 changes: 5 additions & 2 deletions all/agents/ddqn.py
@@ -53,17 +53,20 @@ def act(self, state, reward):
self.replay_buffer.store(self._state, self._action, reward, state)
self._train()
self._state = state
self._action = self.policy(state)
self._action = self.policy.no_grad(state)
return self._action

def eval(self, state, _):
return self.policy.eval(state)

def _train(self):
if self._should_train():
# sample transitions from buffer
(states, actions, rewards, next_states, weights) = self.replay_buffer.sample(self.minibatch_size)
# forward pass
values = self.q(states, actions)
# compute targets
next_actions = torch.argmax(self.q.eval(next_states), dim=1)
next_actions = torch.argmax(self.q.no_grad(next_states), dim=1)
targets = rewards + self.discount_factor * self.q.target(next_states, next_actions)
# compute loss
loss = self.loss(values, targets, weights)
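
The DDQN hunk shows the double Q-learning decoupling: the online network (now called through `no_grad`) selects the next action and the target network evaluates it. For contrast, here is a sketch of the plain DQN target next to the double DQN target, with `q_online` and `q_target` as stand-ins for the library's online and target copies:

```python
# Sketch contrasting DQN and double-DQN targets; q_online/q_target are toy stand-ins.
import torch
import torch.nn as nn

q_online = nn.Linear(4, 3)    # toy Q-network: 4 state features, 3 actions
q_target = nn.Linear(4, 3)
discount = 0.99

next_states = torch.randn(8, 4)
rewards = torch.randn(8)

with torch.no_grad():
    # DQN: the target network both selects and evaluates the next action.
    dqn_targets = rewards + discount * q_target(next_states).max(dim=1).values
    # Double DQN: the online network selects, the target network evaluates.
    next_actions = q_online(next_states).argmax(dim=1)
    ddqn_targets = rewards + discount * q_target(next_states).gather(
        1, next_actions.unsqueeze(1)).squeeze(1)

print(dqn_targets.shape, ddqn_targets.shape)
```
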
13 changes: 9 additions & 4 deletions all/agents/dqn.py
@@ -18,8 +18,10 @@ class DQN(Agent):
policy (GreedyPolicy): A policy derived from the Q-function.
replay_buffer (ReplayBuffer): The experience replay buffer.
discount_factor (float): Discount factor for future rewards.
exploration (float): The probability of choosing a random action.
loss (function): The weighted loss function to use.
minibatch_size (int): The number of experiences to sample in each training update.
n_actions (int): The number of available actions.
replay_start_size (int): Number of experiences in replay buffer when training begins.
update_frequency (int): Number of timesteps per training update.
'''
@@ -31,18 +33,18 @@ def __init__(self,
loss=mse_loss,
minibatch_size=32,
replay_start_size=5000,
update_frequency=1
update_frequency=1,
):
# objects
self.q = q
self.policy = policy
self.replay_buffer = replay_buffer
self.loss = staticmethod(loss)
# hyperparameters
self.discount_factor = discount_factor
self.minibatch_size = minibatch_size
self.replay_start_size = replay_start_size
self.update_frequency = update_frequency
self.minibatch_size = minibatch_size
self.discount_factor = discount_factor
# private
self._state = None
self._action = None
@@ -52,9 +54,12 @@ def act(self, state, reward):
self.replay_buffer.store(self._state, self._action, reward, state)
self._train()
self._state = state
self._action = self.policy(state)
self._action = self.policy.no_grad(state)
return self._action

def eval(self, state, _):
return self.policy.eval(state)

def _train(self):
if self._should_train():
# sample transitions from buffer
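
The docstring entries added above describe `replay_start_size` and `update_frequency`. The gating they control (spelled out in the C51 hunk's `_should_train`) is simply: skip training until the replay buffer has warmed up, then train once every `update_frequency` frames. A compact sketch:

```python
# Compact sketch of the warm-up / frequency gating used by the replay-based agents.
replay_start_size = 10   # frames to collect before any training
update_frequency = 4     # train once every this many frames

def should_train(frames_seen):
    return frames_seen > replay_start_size and frames_seen % update_frequency == 0

# With the toy settings above, training happens on frames 12, 16, 20, 24, ...
print([frame for frame in range(1, 25) if should_train(frame)])
```
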
11 changes: 7 additions & 4 deletions all/agents/ppo.py
@@ -67,18 +67,21 @@ def act(self, states, rewards):
self._buffer.store(self._states, self._actions, rewards)
self._train(states)
self._states = states
self._actions = self.policy.eval(self.features.eval(states)).sample()
self._actions = self.policy.no_grad(self.features.no_grad(states)).sample()
return self._actions

def eval(self, states, _):
return self.policy.eval(self.features.eval(states))

def _train(self, next_states):
if len(self._buffer) >= self._batch_size:
# load trajectories from buffer
states, actions, advantages = self._buffer.advantages(next_states)

# compute target values
features = self.features.eval(states)
pi_0 = self.policy.eval(features).log_prob(actions)
targets = self.v.eval(features) + advantages
features = self.features.no_grad(states)
pi_0 = self.policy.no_grad(features).log_prob(actions)
targets = self.v.no_grad(features) + advantages

# train for several epochs
for _ in range(self.epochs):
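
In the PPO hunk, the reference quantities for the update (the old log-probabilities `pi_0` and the value targets) are now computed under `no_grad`, since they stay fixed while the training epochs below the visible diff run. As a reminder of what those epochs typically optimize, here is a generic clipped-surrogate sketch; it is standard PPO, not code read off the library:

```python
# Generic PPO clipped-surrogate loss; log_pi_0 plays the role of pi_0 in the diff.
import torch

def clipped_surrogate(log_pi, log_pi_0, advantages, epsilon=0.2):
    ratios = torch.exp(log_pi - log_pi_0)              # pi_new / pi_old
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()       # negated for gradient descent

log_pi = torch.randn(16, requires_grad=True)
loss = clipped_surrogate(log_pi, log_pi.detach() + 0.1 * torch.randn(16), torch.randn(16))
loss.backward()
print(loss.item())
```
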
7 changes: 5 additions & 2 deletions all/agents/sac.py
@@ -67,16 +67,19 @@ def act(self, state, reward):
self.replay_buffer.store(self._state, self._action, reward, state)
self._train()
self._state = state
self._action = self.policy.eval(state)[0]
self._action = self.policy.no_grad(state)[0]
return self._action

def eval(self, state, _):
return self.policy.eval(state)[0]

def _train(self):
if self._should_train():
# sample from replay buffer
(states, actions, rewards, next_states, _) = self.replay_buffer.sample(self.minibatch_size)

# compute targets for Q and V
_actions, _log_probs = self.policy.eval(states)
_actions, _log_probs = self.policy.no_grad(states)
q_targets = rewards + self.discount_factor * self.v.target(next_states)
v_targets = torch.min(
self.q_1.target(states, _actions),
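
The SAC targets above bootstrap from frozen copies: Q targets from the target value network, and V targets from the minimum of the two Q networks. The visible hunk is truncated mid-expression, so the sketch below completes it with the conventional entropy term weighted by a temperature; that term is an assumption, not something shown in the diff:

```python
# Sketch of SAC-style targets; the temperature/entropy term follows the standard
# formulation and is assumed, since the visible hunk cuts off before it.
import torch

def sac_targets(rewards, v_target_next, q1, q2, log_probs,
                discount=0.99, temperature=0.2):
    q_targets = rewards + discount * v_target_next
    v_targets = torch.min(q1, q2) - temperature * log_probs
    return q_targets, v_targets

rewards = torch.randn(8)
q_t, v_t = sac_targets(rewards, torch.randn(8), torch.randn(8), torch.randn(8), torch.randn(8))
print(q_t.shape, v_t.shape)
```
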
3 changes: 3 additions & 0 deletions all/agents/vac.py
@@ -35,6 +35,9 @@ def act(self, state, reward):
self._action = self._distribution.sample()
return self._action

def eval(self, state, _):
return self.policy.eval(self.features.eval(state))

def _train(self, state, reward):
if self._features:
# forward pass
5 changes: 4 additions & 1 deletion all/agents/vpg.py
@@ -50,6 +50,9 @@ def act(self, state, reward):
return self._act(state, reward)
return self._terminal(state, reward)

def eval(self, state, _):
return self.policy.eval(self.features.eval(state))

def _initial(self, state):
features = self.features(state)
distribution = self.policy(features)
@@ -82,7 +85,7 @@ def _terminal(self, state, reward):
self._train()

# have to return something
return self.policy.eval(self.features.eval(state)).sample()
return self.policy.no_grad(self.features.no_grad(state)).sample()

def _train(self):
# forward pass
5 changes: 4 additions & 1 deletion all/agents/vqn.py
@@ -27,11 +27,14 @@ def __init__(self, q, policy, discount_factor=0.99):

def act(self, state, reward):
self._train(reward, state)
action = self.policy(state)
action = self.policy.no_grad(state)
self._state = state
self._action = action
return action

def eval(self, state, _):
return self.policy.eval(state)

def _train(self, reward, next_state):
if self._state:
# forward pass
5 changes: 4 additions & 1 deletion all/agents/vsarsa.py
@@ -23,12 +23,15 @@ def __init__(self, q, policy, discount_factor=0.99):
self._action = None

def act(self, state, reward):
action = self.policy(state)
action = self.policy.no_grad(state)
self._train(reward, state, action)
self._state = state
self._action = action
return action

def eval(self, state, _):
return self.policy.eval(state)

def _train(self, reward, next_state, next_action):
if self._state:
# forward pass
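
VSarsa and VQN differ mainly in which next action feeds the bootstrap target, which is why VSarsa selects its action before calling `_train`. A side-by-side sketch of the two textbook one-step targets these agents are named after:

```python
# One-step targets for SARSA vs. Q-learning; q_next is the row of action values at s'.
import torch

def sarsa_target(reward, q_next, next_action, discount=0.99):
    # On-policy: bootstrap from the action actually selected at the next state.
    return reward + discount * q_next[next_action]

def q_learning_target(reward, q_next, discount=0.99):
    # Off-policy: bootstrap from the greedy action at the next state.
    return reward + discount * q_next.max()

q_next = torch.tensor([0.1, 0.5, -0.2])
print(sarsa_target(1.0, q_next, next_action=0), q_learning_target(1.0, q_next))
```
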