Showing 4 changed files with 231 additions and 0 deletions.
.gitignore
@@ -3,6 +3,8 @@ __pycache__/
 *.py[cod]
 *$py.class
 codecov
+demo/reinforce_my.py
+demo/test.py

 # C extensions
 *.so
@@ -0,0 +1,108 @@
import argparse
from itertools import count
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()


# gym >= 0.26 requires the render mode to be chosen at construction time.
env = gym.make('Acrobot-v1', render_mode='human' if args.render else None)
env.reset(seed=args.seed)
torch.manual_seed(args.seed)


class Policy(nn.Module):
    """Two-layer policy network for Acrobot-v1: 6-dim observation, 3 discrete actions."""

    def __init__(self):
        super().__init__()
        self.affine1 = nn.Linear(6, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 3)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()  # guards against division by zero when normalizing returns


def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()


def finish_episode():
    R = 0
    policy_loss = []
    returns = deque()
    # Accumulate discounted returns back-to-front.
    for r in policy.rewards[::-1]:
        R = r + args.gamma * R
        returns.appendleft(R)
    returns = torch.tensor(returns)
    # Normalize returns to reduce the variance of the gradient estimate.
    returns = (returns - returns.mean()) / (returns.std() + eps)
    for log_prob, R in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_log_probs[:]


def main():
    # Start the running average well below Acrobot-v1's reward_threshold (-100.0);
    # the CartPole example's value of 10 would trip the solve check immediately.
    running_reward = -500
    for i_episode in count(1):
        state, _ = env.reset()
        ep_reward = 0
        for t in range(1, 10000):  # don't infinite-loop while learning
            action = select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            if args.render:
                env.render()
            policy.rewards.append(reward)
            ep_reward += reward
            if terminated or truncated:
                break

        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        finish_episode()
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                i_episode, ep_reward, running_reward))
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            break


if __name__ == '__main__':
    main()
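For reference, the loss assembled in finish_episode is the standard REINFORCE (score-function) estimator, with the normalized discounted returns serving as a variance-reducing baseline:

$$\nabla_\theta J(\theta) \approx \sum_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\hat{G}_t, \qquad \hat{G}_t = \frac{G_t - \operatorname{mean}(G)}{\operatorname{std}(G) + \epsilon}, \qquad G_t = \sum_{k \ge t} \gamma^{k-t} r_k.$$

Minimizing $\sum_t -\log \pi_\theta(a_t \mid s_t)\,\hat{G}_t$, as the code does, ascends this gradient estimate.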
@@ -0,0 +1,119 @@
import torch
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions import Gumbel


class GumbelSoftmaxTopK(TorchDistribution):
    """
    Implementation of the Gumbel top-K trick from https://arxiv.org/pdf/1903.06059
    Parameters:
    - a (Tensor): nonnegative weights; if they do not already lie on the simplex,
      they are normalized onto it by dividing by their sum
    - K (int): how many samples without replacement to pick
    - support (Tensor): support of the discrete distribution. If None, it defaults
      to `torch.arange(a.numel())`. It must have the same shape as `a`.
    """

    arg_constraints = {'a': constraints.real}
    has_rsample = True

    def __init__(self, a: torch.Tensor, K: int,
                 support: torch.Tensor = None, validate_args: bool = None):
        """
        Initializes the GumbelSoftmaxTopK distribution.
        Args:
        - a (Tensor): nonnegative weights, normalized onto the simplex by dividing by their sum
        - K (int): how many samples without replacement to pick
        - support (Tensor): support of the discrete distribution. If None, it defaults
          to `torch.arange(a.numel())`. It must have the same shape as `a`.
        - validate_args (bool): whether to validate arguments.
        """
        self.a = a.float() / a.sum()  # normalize the weights onto the simplex
        self.gumbel = Gumbel(loc=0, scale=1, validate_args=validate_args)
        if support is None:
            self.supp = torch.arange(a.numel()).reshape(a.shape)
        else:
            if support.shape != a.shape:
                raise ValueError("support and a must have the same shape")
            self.supp = support
        self.K = int(K)  # ensure K is an int
        super().__init__(validate_args=validate_args)

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the batch shape of the distribution.
        The batch shape represents the shape of independent distributions.
        For example, if `a` is a vector of length 3,
        the batch shape will be `[3]`, indicating 3 independent distributions.
        """
        return self.a.shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the event shape of the distribution.
        The event shape represents the shape of each individual event.
        """
        return torch.Size()

    def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
        """
        Generates a reparameterized sample using the Gumbel top-K trick: the indices
        of the K largest entries of `log a + G`, with G i.i.d. standard Gumbel noise,
        form a sample of size K without replacement.
        Args:
        - sample_shape (torch.Size): The shape of the sample.
        Returns:
        - torch.Tensor: A sample from the distribution.
        """
        if sample_shape is None:
            sample_shape = torch.Size([self.K])

        G = self.gumbel.rsample(sample_shape=self.a.shape)  # one Gumbel draw per category
        _, idxs = torch.topk(G + torch.log(self.a), k=self.K)
        return self.supp.reshape(-1)[idxs].reshape(sample_shape)

    def sample(self, sample_shape: torch.Size = None) -> torch.Tensor:
        """
        Generates a sample from the distribution without tracking gradients.
        Args:
        - sample_shape (torch.Size): The shape of the sample.
        Returns:
        - torch.Tensor: A sample from the distribution.
        """
        with torch.no_grad():
            return self.rsample(sample_shape)

    def log_prob(self, value: torch.Tensor, shape: torch.Size = torch.Size([1])) -> torch.Tensor:
        """
        Computes the log probability of the given value: the log of that category's
        normalized weight (not the joint probability of a K-element subset).
        Args:
        - value (Tensor): The value for which to compute the log probability.
        - shape (torch.Size): The shape of the output.
        Returns:
        - torch.Tensor: The log probability of the given value.
        """
        if self._validate_args:
            self._validate_sample(value)

        idx = (self.supp.reshape(-1) == value).nonzero().squeeze()

        return torch.log(self.a.reshape(-1)[idx]).reshape(shape)

    def _validate_sample(self, value: torch.Tensor):
        """
        Validates the given sample value.
        Args:
        - value (Tensor): The sample value to validate.
        """
        if self._validate_args:
            if value not in self.supp:
                raise ValueError("Sample value must be in the support")
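The class rests on the Gumbel top-K trick (Kool et al., arXiv:1903.06059): adding i.i.d. standard Gumbel noise to the log-weights and taking the indices of the K largest values yields a size-K sample without replacement from the corresponding categorical distribution. A minimal usage sketch follows; the weights and K below are made up for illustration, and it assumes the class above is importable alongside pyro.

import torch

# Hypothetical unnormalized weights over five categories.
weights = torch.tensor([1.0, 4.0, 2.0, 8.0, 5.0])
dist = GumbelSoftmaxTopK(a=weights, K=3)

picked = dist.sample()           # e.g. tensor([3, 4, 1]): three distinct categories,
                                 # with larger weights more likely to be picked
lp = dist.log_prob(picked[0])    # log of that category's normalized weight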