-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #248 from cpnota/release/0.7.1
Release/0.7.1
- Loading branch information
Showing
46 changed files
with
707 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class VectorEnvironment(ABC):
    """
    Abstract base class for a batched ("vectorized") reinforcement learning environment.

    A vector environment runs several sub-environments in lockstep: observations,
    rewards, and done flags are stacked across the batch dimension, and ``step``
    expects one action per sub-environment. Sub-environments that finish an
    episode are reset automatically by the vector environment, so callers never
    need to reset individual members by hand.
    """

    @property
    @abstractmethod
    def name(self):
        """str: Human-readable name of the environment."""

    @abstractmethod
    def reset(self):
        """
        Begin a new episode in every sub-environment.

        Returns:
            State: The batched initial state for the next episodes.
        """

    @abstractmethod
    def step(self, action):
        """
        Advance every sub-environment by one timestep.

        Args:
            action (Action): A batch of actions, one per sub-environment.

        Returns:
            all.environments.State: The batched state after the actions are
                applied. Each member state carries its own done flag and any
                extra "info", along with the float reward earned by the
                previous action.
        """

    @abstractmethod
    def close(self):
        """Release any resources held by the environment."""

    @property
    @abstractmethod
    def state_array(self):
        """StateArray: The states of all sub-environments at the current timestep."""

    @property
    @abstractmethod
    def state_space(self):
        """
        The space of observable states for each sub-environment.

        Returns:
            Space: An object describing the possible states the agent may observe.
        """

    @property
    def observation_space(self):
        """
        Alias for ``state_space``.

        Returns:
            Space: An object describing the possible states the agent may observe.
        """
        return self.state_space

    @property
    @abstractmethod
    def action_space(self):
        """
        The space of legal actions for each sub-environment.

        Returns:
            Space: An object describing the possible actions the agent may take.
        """

    @property
    @abstractmethod
    def device(self):
        """The torch device on which the environment's tensors live."""

    @property
    @abstractmethod
    def num_envs(self):
        """
        int: The number of sub-environments in the batch. ``step()`` expects this
        many actions, and observations, dones, etc. are returned with this
        leading dimension.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import gym | ||
import torch | ||
from all.core import State | ||
from ._vector_environment import VectorEnvironment | ||
import numpy as np | ||
|
||
|
||
class DuplicateEnvironment(VectorEnvironment):
    '''
    Turns a list of ALL Environment objects into a VectorEnvironment object.

    This wrapper takes the list of States the wrapped environments generate and
    outputs a StateArray object containing all of the environment states. Like
    all vector environments, the sub-environments are automatically reset when
    they are done.

    Args:
        envs: A non-empty list of ALL environments.
        device (optional): the torch device on which tensors will be stored
    '''

    def __init__(self, envs, device=torch.device('cpu')):
        # The vector environment is named after its (identical) sub-environments.
        self._name = envs[0].name
        self._envs = envs
        self._state = None
        self._action = None
        self._reward = None
        self._done = True
        self._info = None
        self._device = device

    @property
    def name(self):
        return self._name

    def reset(self):
        """
        Reset every sub-environment.

        Returns:
            State: The batched initial state for the next episodes.
        """
        self._state = State.array([sub_env.reset() for sub_env in self._envs])
        return self._state

    def step(self, actions):
        """
        Apply one action per sub-environment and return the batched next state.

        Sub-environments whose previous state was done are reset instead of
        stepped, so callers never reset members manually.

        Args:
            actions (torch.Tensor): A batch of actions, one per sub-environment.

        Returns:
            State: The batched state after the actions are applied.
        """
        states = []
        # Actions arrive as a torch tensor; convert once for all sub-environments.
        actions = actions.cpu().detach().numpy()
        for sub_env, action in zip(self._envs, actions):
            state = sub_env.reset() if sub_env.state.done else sub_env.step(action)
            states.append(state)
        self._state = State.array(states)
        return self._state

    def close(self):
        """
        Close every wrapped sub-environment.

        Fixes a bug where ``self._env`` (an attribute that was never set) was
        referenced, causing ``close()`` to raise AttributeError.
        """
        for sub_env in self._envs:
            sub_env.close()

    def seed(self, seed):
        """Seed each sub-environment with an offset so their rollouts differ."""
        for i, env in enumerate(self._envs):
            env.seed(seed + i)

    @property
    def state_space(self):
        return self._envs[0].observation_space

    @property
    def action_space(self):
        return self._envs[0].action_space

    @property
    def state_array(self):
        return self._state

    @property
    def device(self):
        return self._device

    @property
    def num_envs(self):
        return len(self._envs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import unittest | ||
import gym | ||
import torch | ||
from all.environments import DuplicateEnvironment, GymEnvironment | ||
|
||
|
||
def make_vec_env(num_envs=3):
    """Build a list of ``num_envs`` identical CartPole environments."""
    return [GymEnvironment('CartPole-v0') for _ in range(num_envs)]
|
||
|
||
class DuplicateEnvironmentTest(unittest.TestCase):
    """Tests for DuplicateEnvironment wrapping several CartPole environments."""

    def test_env_name(self):
        # The vector env takes its name from the duplicated sub-environments.
        env = DuplicateEnvironment(make_vec_env())
        self.assertEqual(env.name, 'CartPole-v0')

    def test_num_envs(self):
        n = 5
        env = DuplicateEnvironment(make_vec_env(n))
        self.assertEqual(env.num_envs, n)
        self.assertEqual((n,), env.reset().shape)

    def test_reset(self):
        n = 5
        env = DuplicateEnvironment(make_vec_env(n))
        batch = env.reset()
        # CartPole observations are 4-dimensional.
        self.assertEqual(batch.observation.shape, (n, 4))
        zeros = torch.zeros(n)
        ones = torch.ones(n)
        self.assertTrue((batch.reward == zeros).all())
        self.assertTrue((batch.done == zeros).all())
        self.assertTrue((batch.mask == ones).all())

    def test_step(self):
        n = 5
        env = DuplicateEnvironment(make_vec_env(n))
        env.reset()
        batch = env.step(torch.ones(n, dtype=torch.int32))
        self.assertEqual(batch.observation.shape, (n, 4))
        zeros = torch.zeros(n)
        ones = torch.ones(n)
        self.assertTrue((batch.reward == ones).all())
        self.assertTrue((batch.done == zeros).all())
        self.assertTrue((batch.mask == ones).all())

    def test_step_until_done(self):
        n = 3
        env = DuplicateEnvironment(make_vec_env(n))
        env.seed(5)
        env.reset()
        # Step until the first sub-environment terminates (capped at 100 steps).
        for _ in range(100):
            batch = env.step(torch.ones(n, dtype=torch.int32))
            if batch.done[0]:
                break
        first = batch[0]
        self.assertEqual(first.observation.shape, (4,))
        self.assertEqual(first.reward, 1.)
        self.assertTrue(first.done)
        self.assertEqual(first.mask, 0)
Oops, something went wrong.