add new distr
Vepricov committed Nov 23, 2024
1 parent dad8d6f commit dd93fb0
Showing 4 changed files with 231 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@ __pycache__/
*.py[cod]
*$py.class
codecov
demo/reinforce_my.py
demo/test.py

# C extensions
*.so
108 changes: 108 additions & 0 deletions demo/reinforce.py
@@ -0,0 +1,108 @@
import argparse, os, sys
import gym
import numpy as np
from itertools import count
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()


# gym >= 0.26 requires the render mode to be set at construction time
env = gym.make('Acrobot-v1', render_mode='human' if args.render else None)
env.reset(seed=args.seed)
torch.manual_seed(args.seed)


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        # Acrobot-v1 has a 6-dimensional observation and 3 discrete actions
        self.affine1 = nn.Linear(6, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 3)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()


def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()


def finish_episode():
    R = 0
    policy_loss = []
    returns = deque()
    # Accumulate discounted returns from the end of the episode backwards:
    # R_t = r_t + gamma * R_{t+1}
    for r in policy.rewards[::-1]:
        R = r + args.gamma * R
        returns.appendleft(R)
    returns = torch.tensor(returns)
    # Standardize returns to reduce the variance of the policy gradient
    returns = (returns - returns.mean()) / (returns.std() + eps)
    for log_prob, R in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_log_probs[:]


def main():
    # Acrobot rewards are negative (reward_threshold is -100), so start the
    # running average well below the threshold; starting at 10 would trigger
    # the "Solved!" check on the very first episode
    running_reward = -500
    for i_episode in count(1):
        state, _ = env.reset()
        ep_reward = 0
        for t in range(1, 10000):  # Don't infinite loop while learning
            action = select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            if args.render:
                env.render()
            policy.rewards.append(reward)
            ep_reward += reward
            if terminated or truncated:
                break

        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        finish_episode()
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                  i_episode, ep_reward, running_reward))
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            break


if __name__ == '__main__':
    main()
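
For reference, finish_episode accumulates discounted returns backwards, R_t = r_t + gamma * R_{t+1}. A minimal standalone sketch of that recursion, with made-up rewards (Acrobot gives -1 per step):

from collections import deque

gamma = 0.99
rewards = [-1.0, -1.0, -1.0]  # made-up per-step rewards; Acrobot gives -1 each step

# Same backward accumulation as in finish_episode()
R, returns = 0.0, deque()
for r in reversed(rewards):
    R = r + gamma * R
    returns.appendleft(R)

print(list(returns))  # [-2.9701, -1.99, -1.0]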
119 changes: 119 additions & 0 deletions src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -0,0 +1,119 @@
import torch
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions import Gumbel

class GumbelSoftmaxTopK(TorchDistribution):
    """
    Implementation of the Gumbel top-K trick from https://arxiv.org/pdf/1903.06059

    Parameters:
    - a (Tensor): unnormalized probabilities; if not already on the simplex, `a` is normalized onto it
    - K (int): how many samples to pick without replacement
    - support (Tensor): support of the discrete distribution. If None, it defaults to `torch.arange(a.numel()).reshape(a.shape)`. It must have the same shape as `a`.
    """

    arg_constraints = {'a': constraints.real}
    has_rsample = True

    def __init__(self, a: torch.Tensor, K: int,
                 support: torch.Tensor = None, validate_args: bool = None):
        """
        Initializes the GumbelSoftmaxTopK distribution.

        Args:
        - a (Tensor): unnormalized probabilities; if not already on the simplex, `a` is normalized onto it
        - K (int): how many samples to pick without replacement
        - support (Tensor): support of the discrete distribution. If None, it defaults to `torch.arange(a.numel()).reshape(a.shape)`. It must have the same shape as `a`.
        - validate_args (bool): whether to validate arguments
        """
        self.a = a.float() / a.sum()  # normalize so that `a` lies on the simplex
        self.gumbel = Gumbel(loc=0, scale=1, validate_args=validate_args)
        if support is None:
            self.supp = torch.arange(a.numel()).reshape(a.shape)
        else:
            if support.shape != a.shape:
                raise ValueError("support and a must have the same shape")
            self.supp = support
        self.K = int(K)  # ensure K is an integer
        super().__init__(validate_args=validate_args)

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the batch shape of the distribution.

        The batch shape represents the shape of independent distributions.
        For example, if `a` is a vector of length 3,
        the batch shape will be `[3]`, indicating 3 independent distributions.
        """
        return self.a.shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the event shape of the distribution.

        The event shape represents the shape of each individual event.
        """
        return torch.Size()

    def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
        """
        Generates a sample from the distribution using the Gumbel top-K trick.

        Args:
        - sample_shape (torch.Size): the shape of the sample

        Returns:
        - torch.Tensor: a sample from the distribution
        """
        if sample_shape is None:
            sample_shape = torch.Size([self.K])

        # Perturb the log-probabilities with Gumbel noise and take the top-K indices
        G = self.gumbel.rsample(sample_shape=self.a.shape)
        _, idxs = torch.topk(G + torch.log(self.a), k=self.K)
        return self.supp.reshape(-1)[idxs].reshape(sample_shape)

    def sample(self, sample_shape: torch.Size = None) -> torch.Tensor:
        """
        Generates a sample from the distribution without tracking gradients.

        Args:
        - sample_shape (torch.Size): the shape of the sample

        Returns:
        - torch.Tensor: a sample from the distribution
        """
        with torch.no_grad():
            return self.rsample(sample_shape)

    def log_prob(self, value: torch.Tensor, shape: torch.Size = torch.Size([1])) -> torch.Tensor:
        """
        Computes the log probability of the given value.

        Args:
        - value (Tensor): the value for which to compute the log probability
        - shape (torch.Size): the shape of the output

        Returns:
        - torch.Tensor: the log probability of the given value
        """
        if self._validate_args:
            self._validate_sample(value)

        # Look up the index of `value` in the support and return log a at that index
        idx = (self.supp.reshape(-1) == value).nonzero().squeeze()

        return torch.log(self.a.reshape(-1)[idx]).reshape(shape)

    def _validate_sample(self, value: torch.Tensor):
        """
        Validates the given sample value.

        Args:
        - value (Tensor): the sample value to validate
        """
        if self._validate_args:
            if value not in self.supp:
                raise ValueError("Sample value must be in the support")
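
A minimal usage sketch of the new distribution (the weights below are made up; this assumes the package is importable as relaxit, matching the __init__.py change below):

import torch
from relaxit.distributions import GumbelSoftmaxTopK

a = torch.tensor([1.0, 2.0, 3.0, 2.0, 1.0])  # made-up unnormalized weights
dist = GumbelSoftmaxTopK(a, K=2)

sample = dist.rsample()           # two distinct category indices via Gumbel top-K
log_p = dist.log_prob(sample[0])  # log-probability of the first picked category
print(sample, log_p)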
2 changes: 2 additions & 0 deletions src/relaxit/distributions/__init__.py
@@ -4,6 +4,7 @@
from .HardConcrete import HardConcrete
from .InvertibleGaussian import InvertibleGaussian
from .LogisticNormalSoftmax import LogisticNormalSoftmax
from .GumbelSoftmaxTopK import GumbelSoftmaxTopK

__all__ = [
"GaussianRelaxedBernoulli",
@@ -12,4 +13,5 @@
"HardConcrete",
"InvertibleGaussian",
"LogisticNormalSoftmax",
"GumbelSoftmaxTopK",
]
