diff --git a/.gitignore b/.gitignore
index 60b7f2f..f72a873 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,8 @@ __pycache__/
 *.py[cod]
 *$py.class
 codecov
+demo/reinforce_my.py
+demo/test.py
 
 # C extensions
 *.so
diff --git a/demo/reinforce.py b/demo/reinforce.py
new file mode 100644
index 0000000..2729da3
--- /dev/null
+++ b/demo/reinforce.py
@@ -0,0 +1,108 @@
+import argparse, os, sys
+import gym
+import numpy as np
+from itertools import count
+from collections import deque
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributions import Categorical
+
+
+parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
+parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
+                    help='discount factor (default: 0.99)')
+parser.add_argument('--seed', type=int, default=543, metavar='N',
+                    help='random seed (default: 543)')
+parser.add_argument('--render', action='store_true',
+                    help='render the environment')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                    help='interval between training status logs (default: 10)')
+args = parser.parse_args()
+
+
+env = gym.make('Acrobot-v1')
+env.reset(seed=args.seed)
+torch.manual_seed(args.seed)
+
+
+class Policy(nn.Module):
+    def __init__(self):
+        super(Policy, self).__init__()
+        self.affine1 = nn.Linear(6, 128)
+        self.dropout = nn.Dropout(p=0.6)
+        self.affine2 = nn.Linear(128, 3)
+
+        self.saved_log_probs = []
+        self.rewards = []
+
+    def forward(self, x):
+        x = self.affine1(x)
+        x = self.dropout(x)
+        x = F.relu(x)
+        action_scores = self.affine2(x)
+        return F.softmax(action_scores, dim=1)
+
+
+policy = Policy()
+optimizer = optim.Adam(policy.parameters(), lr=1e-2)
+eps = np.finfo(np.float32).eps.item()
+
+
+def select_action(state):
+    state = torch.from_numpy(state).float().unsqueeze(0)
+    probs = policy(state)
+    m = Categorical(probs)
+    action = m.sample()
+    policy.saved_log_probs.append(m.log_prob(action))
+    return action.item()
+
+
+def finish_episode():
+    R = 0
+    policy_loss = []
+    returns = deque()
+    for r in policy.rewards[::-1]:
+        R = r + args.gamma * R
+        returns.appendleft(R)
+    returns = torch.tensor(returns)
+    returns = (returns - returns.mean()) / (returns.std() + eps)
+    for log_prob, R in zip(policy.saved_log_probs, returns):
+        policy_loss.append(-log_prob * R)
+    optimizer.zero_grad()
+    policy_loss = torch.cat(policy_loss).sum()
+    policy_loss.backward()
+    optimizer.step()
+    del policy.rewards[:]
+    del policy.saved_log_probs[:]
+
+
+def main():
+    running_reward = 10
+    for i_episode in count(1):
+        state, _ = env.reset()
+        ep_reward = 0
+        for t in range(1, 10000):  # Don't infinite loop while learning
+            action = select_action(state)
+            state, reward, done, _, _ = env.step(action)
+            if args.render:
+                env.render()
+            policy.rewards.append(reward)
+            ep_reward += reward
+            if done:
+                break
+
+        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
+        finish_episode()
+        if i_episode % args.log_interval == 0:
+            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
+                  i_episode, ep_reward, running_reward))
+        if running_reward > env.spec.reward_threshold:
+            print("Solved! Running reward is now {} and "
+                  "the last episode runs to {} time steps!".format(running_reward, t))
+            break
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/relaxit/distributions/GumbelSoftmaxTopK.py b/src/relaxit/distributions/GumbelSoftmaxTopK.py
new file mode 100644
index 0000000..7d5a73c
--- /dev/null
+++ b/src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -0,0 +1,119 @@
+import torch
+from torch.distributions import constraints
+from pyro.distributions.torch_distribution import TorchDistribution
+from pyro.distributions import Gumbel
+from torch.distributions import constraints
+
+class GumbelSoftmaxTopK(TorchDistribution):
+    """
+    Implementation of the Gumbel-top-k trick from https://arxiv.org/pdf/1903.06059
+
+    Parameters:
+    - a (Tensor): non-negative weights over the categories; if they do not lie on the simplex, they are normalized onto it
+    - K (int): how many samples to draw without replacement
+    - support (Tensor): support of the discrete distribution. If None, it defaults to `torch.arange(a.numel()).reshape(a.shape)`. It must have the same shape as `a`.
+    """
+
+    arg_constraints = {'a': constraints.real}
+    has_rsample = True
+
+    def __init__(self, a: torch.Tensor, K: int,
+                 support: torch.Tensor = None, validate_args: bool = None):
+        """
+        Initializes the GumbelSoftmaxTopK distribution.
+
+        Args:
+        - a (Tensor): non-negative weights over the categories; if they do not lie on the simplex, they are normalized onto it
+        - K (int): how many samples to draw without replacement
+        - support (Tensor): support of the discrete distribution. If None, it defaults to `torch.arange(a.numel()).reshape(a.shape)`. It must have the same shape as `a`.
+        - validate_args (bool): Whether to validate arguments.
+        """
+        self.a = a.float() / a.sum()  # Normalize to a float tensor on the simplex
+        self.gumbel = Gumbel(loc=0, scale=1, validate_args=validate_args)
+        if support is None:
+            self.supp = torch.arange(a.numel()).reshape(a.shape)
+        else:
+            if support.shape != a.shape:
+                raise ValueError("support and a must have the same shape")
+            self.supp = support
+        self.K = int(K)  # Ensure K is an int
+        super().__init__(validate_args=validate_args)
+
+    @property
+    def batch_shape(self) -> torch.Size:
+        """
+        Returns the batch shape of the distribution.
+
+        The batch shape represents the shape of independent distributions.
+        For example, if `a` is a vector of length 3,
+        the batch shape will be `[3]`, indicating 3 independent distributions.
+        """
+        return self.a.shape
+
+    @property
+    def event_shape(self) -> torch.Size:
+        """
+        Returns the event shape of the distribution.
+
+        The event shape represents the shape of each individual event.
+        """
+        return torch.Size()
+
+    def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
+        """
+        Generates a sample from the distribution using the Gumbel-top-k trick.
+
+        Args:
+        - sample_shape (torch.Size): The shape of the sample.
+
+        Returns:
+        - torch.Tensor: A sample from the distribution.
+        """
+        if sample_shape is None:
+            sample_shape = torch.Size([self.K])
+
+        G = self.gumbel.rsample(sample_shape=self.a.shape)
+        _, idxs = torch.topk(G + torch.log(self.a), k=self.K)
+        return self.supp.reshape(-1)[idxs].reshape(sample_shape)
+
+    def sample(self, sample_shape: torch.Size = None) -> torch.Tensor:
+        """
+        Generates a sample from the distribution.
+
+        Args:
+        - sample_shape (torch.Size): The shape of the sample.
+
+        Returns:
+        - torch.Tensor: A sample from the distribution.
+ """ + with torch.no_grad(): + return self.rsample(sample_shape) + + def log_prob(self, value: torch.Tensor, shape: torch.Size = torch.Size([1])) -> torch.Tensor: + """ + Computes the log probability of the given value. + + Args: + - value (Tensor): The value for which to compute the log probability. + - shape(torch.Size): The shape of the output + + Returns: + - torch.Tensor: The log probability of the given value. + """ + if self._validate_args: + self._validate_sample(value) + + idx = (self.supp.reshape(-1) == value).nonzero().squeeze() + + return torch.log(self.a.reshape(-1)[idx]).reshape(shape=shape) + + def _validate_sample(self, value: torch.Tensor): + """ + Validates the given sample value. + + Args: + - value (Tensor): The sample value to validate. + """ + if self._validate_args: + if value not in self.supp: + raise ValueError("Sample value must be in the support") \ No newline at end of file diff --git a/src/relaxit/distributions/__init__.py b/src/relaxit/distributions/__init__.py index 863a5e1..132246a 100644 --- a/src/relaxit/distributions/__init__.py +++ b/src/relaxit/distributions/__init__.py @@ -4,6 +4,7 @@ from .HardConcrete import HardConcrete from .InvertibleGaussian import InvertibleGaussian from .LogisticNormalSoftmax import LogisticNormalSoftmax +from .GumbelSoftmaxTopK import GumbelSoftmaxTopK __all__ = [ "GaussianRelaxedBernoulli", @@ -12,4 +13,5 @@ "HardConcrete", "InvertibleGaussian", "LogisticNormalSoftmax", + "GumbelSoftmaxTopK", ] \ No newline at end of file