diff --git a/TruncatedNormal.py b/TruncatedNormal.py index edf186e..10c5f67 100644 --- a/TruncatedNormal.py +++ b/TruncatedNormal.py @@ -73,11 +73,19 @@ def auc(self): @staticmethod def _little_phi(x): - return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI + if x.isinf(): + return torch.zeros(x.size()).to(x) + else: + return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI @staticmethod def _big_phi(x): - return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + if x.isposinf(): + return torch.ones(x.size()).to(x) + elif x.isneginf(): + return torch.zeros(x.size()).to(x) + else: + return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) @staticmethod def _inv_big_phi(x): diff --git a/tests/test.py b/tests/test.py index e634e42..8239eca 100644 --- a/tests/test.py +++ b/tests/test.py @@ -69,7 +69,12 @@ def _test_numerical(self, loc, scale, a, b, do_icdf=True): N = 10 for i in range(N): p = i / (N - 1) - x = a + (b - a) * p + if torch.isinf(torch.tensor(a)): + x = b - scale * i / N + elif torch.isinf(torch.tensor(b)): + x = a + scale * i / N + else: + x = a + (b - a) * p cdf_sc = sc.cdf(x) cdf_pt = float(pt.cdf(torch.tensor(x))) @@ -84,6 +89,18 @@ def _test_numerical(self, loc, scale, a, b, do_icdf=True): icdf_pt = float(pt.icdf(torch.tensor(p))) self.assertRelativelyEqual(icdf_sc, icdf_pt, tol=1e-4, error=1e-3) + def _test_grad(self, loc, scale, a, b, grad_point): + loc = torch.nn.parameter.Parameter(torch.tensor(loc)) + scale = torch.nn.parameter.Parameter(torch.tensor(scale)) + pt = TruncatedNormalPT(loc, scale, a, b) + grads = torch.autograd.grad(pt.log_prob(grad_point), [loc, scale]) + self.assertFalse(any([grad.isnan() for grad in grads])) + + def test_grad(self): + self._test_grad(0., 1., -2., 0., -1.) + self._test_grad(0., 1., -2., torch.inf, -1.) + self._test_grad(0., 1., -torch.inf, 0., -1.) + def test_simple(self): self._test_numerical(0., 1., -2., 0.) self._test_numerical(0., 1., -2., 1.) @@ -95,6 +112,10 @@ def test_simple(self): self._test_numerical(0., 1., 0., 2.) self._test_numerical(1., 2., 1., 2.) self._test_numerical(1., 2., 2., 4.) + self._test_numerical(0., 1., -2., torch.inf) + self._test_numerical(0., 1., -torch.inf, 0.) + self._test_numerical(1., 2., 2., torch.inf) + self._test_numerical(1., 2., -torch.inf, 4.) def test_precision(self): self._test_numerical(0., 1., 2., 3.) @@ -112,7 +133,9 @@ def test_support(self): pt = TruncatedNormalPT(0., 1., -1., 2., validate_args=None) with self.assertRaises(ValueError) as e: pt.log_prob(torch.tensor(-10)) - self.assertEqual(str(e.exception), 'The value argument must be within the support') + + self.assertFalse(str(e.exception) != 'The value argument must be within the support' and + str(e.exception) != 'Expected value argument (Tensor of shape ()) to be within the support (Interval(lower_bound=-1.0, upper_bound=2.0)) of the distribution TruncatedNormal(a: -1.0, b: 2.0), but found invalid values:\n-10.0') def test_cuda(self): if not torch.cuda.is_available():