
Commit

add more tests
Vepricov committed Nov 23, 2024
1 parent f910e9a commit a2d32c7
Showing 4 changed files with 36 additions and 12 deletions.
5 changes: 5 additions & 0 deletions src/relaxit/distributions/GumbelSoftmaxTopK.py
@@ -8,6 +8,11 @@ class GumbelSoftmaxTopK(TorchDistribution):
"""
Implementation of the Gumbel-Softmax top-K trick from https://arxiv.org/pdf/1903.06059
:param a: logits; if not on the simplex, `a` is projected onto it.
:type a: torch.Tensor
:param K: how many samples without replacement to pick.
:type K: int
:param support: support of the discrete distribution. If None, it defaults to `torch.arange(a.numel()).reshape(a.shape)`. It must have the same `len` as `a`.
Parameters:
- a (Tensor): logits; if not on the simplex, a is projected onto it
- K (int): how many samples without replacement to pick
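For context, a minimal usage sketch of the documented interface: the constructor arguments follow the :param list above, while the rsample() call is an assumption mirroring the other relaxit distributions and does not appear in this diff.

import torch
from relaxit.distributions.GumbelSoftmaxTopK import GumbelSoftmaxTopK

a = torch.tensor([1.0, 2.0, 3.0, 4.0])  # logits; projected onto the simplex if needed
distr = GumbelSoftmaxTopK(a=a, K=2)     # pick 2 samples without replacement
samples = distr.rsample()               # assumed to be available, as in the other distributions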
28 changes: 28 additions & 0 deletions tests/distributions/test_CorrelatedRelaxedBernoulli.py
@@ -0,0 +1,28 @@
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src')))
from relaxit.distributions.CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli

# Testing reparameterized sampling from the CorrelatedRelaxedBernoulli distribution

def test_sample_shape():
    pi = torch.tensor([0.1, 0.2, 0.3])
    R = torch.tensor([[1.]])
    tau = torch.tensor([2.])

    distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
    samples = distr.rsample()
    assert samples.shape == torch.Size([3])

def test_sample_grad():
    pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
    R = torch.tensor([[1.]])
    tau = torch.tensor([2.])

    distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau)
    samples = distr.rsample()
    assert samples.requires_grad == True

if __name__ == "__main__":
    test_sample_shape()
    test_sample_grad()
9 changes: 2 additions & 7 deletions tests/distributions/test_GaussianRelaxedBernoulli.py
@@ -17,10 +17,5 @@ def test_sample_grad():
     loc = torch.tensor([0.], requires_grad=True)
     scale = torch.tensor([1.], requires_grad=True)
     distr = GaussianRelaxedBernoulli(loc = loc, scale=scale)
-    samples = distr.rsample(sample_shape = torch.Size([3]))
-
-    assert samples.requires_grad == True
-
-if __name__ == "__main__":
-    test_sample_shape()
-    test_sample_grad()
+    samples = distr.rsample()
+    assert samples.requires_grad == True
6 changes: 1 addition & 5 deletions tests/distributions/test_HardConcrete.py
@@ -23,8 +23,4 @@ def test_sample_grad():
     distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
     samples = distr.rsample(sample_shape = torch.Size([3]))
 
-    assert samples.requires_grad == True
-
-if __name__ == "__main__":
-    test_sample_shape()
-    test_sample_grad()
+    assert samples.requires_grad == True
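The removal of the if __name__ == "__main__" blocks here and in test_GaussianRelaxedBernoulli.py suggests the tests are now meant to be collected by a test runner rather than executed as scripts; assuming pytest is used (not confirmed by this diff), they could be run with:

python -m pytest tests/distributions/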
