From a2d32c7d3f9ad20e1054d2568e096edd8bf2ba9a Mon Sep 17 00:00:00 2001 From: vepricov Date: Sat, 23 Nov 2024 13:21:17 +0300 Subject: [PATCH] add more tests --- .../distributions/GumbelSoftmaxTopK.py | 5 ++++ .../test_CorrelatedRelaxedBernoulli.py | 28 +++++++++++++++++++ .../test_GaussianRelaxedBernoulli.py | 9 ++---- tests/distributions/test_HardConcrete.py | 6 +--- 4 files changed, 36 insertions(+), 12 deletions(-) create mode 100644 tests/distributions/test_CorrelatedRelaxedBernoulli.py diff --git a/src/relaxit/distributions/GumbelSoftmaxTopK.py b/src/relaxit/distributions/GumbelSoftmaxTopK.py index 7d5a73c..8b3d2ef 100644 --- a/src/relaxit/distributions/GumbelSoftmaxTopK.py +++ b/src/relaxit/distributions/GumbelSoftmaxTopK.py @@ -8,6 +8,11 @@ class GumbelSoftmaxTopK(TorchDistribution): """ Implimentation of the Gaussian-soft max topK trick from https://arxiv.org/pdf/1903.06059 + :param a: logits, if not from Simples, we project a into it. + :type a: torch.Tensor + :param K: how many samples without replacement to pick. + :type K: int + :param support: support of the discrete distribution. If None, it will be `torch.arange(a.numel()).reshape(a.shape)`. It must be the same `len` as `a`. Parameters: - a (Tensor): logits, if not from Simples, we project a into it - K (int): how many samples without replacement to pick diff --git a/tests/distributions/test_CorrelatedRelaxedBernoulli.py b/tests/distributions/test_CorrelatedRelaxedBernoulli.py new file mode 100644 index 0000000..0e99449 --- /dev/null +++ b/tests/distributions/test_CorrelatedRelaxedBernoulli.py @@ -0,0 +1,28 @@ +import torch +import sys, os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'src'))) +from relaxit.distributions.CorrelatedRelaxedBernoulli import CorrelatedRelaxedBernoulli + +# Testing reparameterized sampling from the GaussianRelaxedBernoulli distribution + +def test_sample_shape(): + pi = torch.tensor([0.1, 0.2, 0.3]) + R = torch.tensor([[1.]]) + tau = torch.tensor([2.]) + + distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) + samples = distr.rsample() + assert samples.shape == torch.Size([3]) + +def test_sample_grad(): + pi = torch.tensor([0.1, 0.2, 0.3], requires_grad=True) + R = torch.tensor([[1.]]) + tau = torch.tensor([2.]) + + distr = CorrelatedRelaxedBernoulli(pi=pi, R=R, tau=tau) + samples = distr.rsample() + assert samples.requires_grad == True + +if __name__ == "__main__": + test_sample_shape() + test_sample_grad() \ No newline at end of file diff --git a/tests/distributions/test_GaussianRelaxedBernoulli.py b/tests/distributions/test_GaussianRelaxedBernoulli.py index 2dbb6d9..2be9eea 100644 --- a/tests/distributions/test_GaussianRelaxedBernoulli.py +++ b/tests/distributions/test_GaussianRelaxedBernoulli.py @@ -17,10 +17,5 @@ def test_sample_grad(): loc = torch.tensor([0.], requires_grad=True) scale = torch.tensor([1.], requires_grad=True) distr = GaussianRelaxedBernoulli(loc = loc, scale=scale) - samples = distr.rsample(sample_shape = torch.Size([3])) - - assert samples.requires_grad == True - -if __name__ == "__main__": - test_sample_shape() - test_sample_grad() \ No newline at end of file + samples = distr.rsample() + assert samples.requires_grad == True \ No newline at end of file diff --git a/tests/distributions/test_HardConcrete.py b/tests/distributions/test_HardConcrete.py index 381d221..6aaa1ee 100644 --- a/tests/distributions/test_HardConcrete.py +++ b/tests/distributions/test_HardConcrete.py @@ -23,8 +23,4 @@ def test_sample_grad(): distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi) samples = distr.rsample(sample_shape = torch.Size([3])) - assert samples.requires_grad == True - -if __name__ == "__main__": - test_sample_shape() - test_sample_grad() \ No newline at end of file + assert samples.requires_grad == True \ No newline at end of file