Skip to content

Commit

Permalink
feat: Levy distribution (#1943)
Browse files Browse the repository at this point in the history
* feat: Levy distribution

* fix: correct log scale calculation and update entropy method in Levy distribution

* doc: add documentation and methods for Lévy distribution
  • Loading branch information
Qazalbash authored Jan 20, 2025
1 parent 4704656 commit 5aca6cb
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ Laplace
:show-inheritance:
:member-order: bysource

Levy
^^^^
.. autoclass:: numpyro.distributions.continuous.Levy
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

LKJ
^^^
.. autoclass:: numpyro.distributions.continuous.LKJ
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
InverseGamma,
Kumaraswamy,
Laplace,
Levy,
LKJCholesky,
Logistic,
LogNormal,
Expand Down Expand Up @@ -160,6 +161,7 @@
"Kumaraswamy",
"Laplace",
"LeftTruncatedDistribution",
"Levy",
"LKJ",
"LKJCholesky",
"Logistic",
Expand Down
99 changes: 99 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
xlogy,
)
from jax.scipy.stats import norm as jax_norm
from jax.typing import ArrayLike

from numpyro.distributions import constraints
from numpyro.distributions.discrete import _to_logits_bernoulli
Expand Down Expand Up @@ -2966,3 +2967,101 @@ def infer_shapes(
batch_shape = lax.broadcast_shapes(concentration, matrix[:-2])
event_shape = matrix[-2:]
return batch_shape, event_shape


class Levy(Distribution):
r"""Lévy distribution is a special case of Lévy alpha-stable distribution.
Its probability density function is given by,
.. math::
f(x\mid \mu, c) = \sqrt{\frac{c}{2\pi(x-\mu)^{3}}} \exp\left(-\frac{c}{2(x-\mu)}\right), \qquad x > \mu
where :math:`\mu` is the location parameter and :math:`c` is the scale parameter.
:param loc: Location parameter.
:param scale: Scale parameter.
"""

arg_constraints = {
"loc": constraints.positive,
"scale": constraints.positive,
}

def __init__(self, loc, scale, *, validate_args=None):
self.loc, self.scale = promote_shapes(loc, scale)
batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
self._support = constraints.greater_than(loc)
super(Levy, self).__init__(batch_shape, validate_args=validate_args)

@constraints.dependent_property(is_discrete=False)
def support(self):
return self._support

@validate_sample
def log_prob(self, value):
r"""Compute the log probability density function of the Lévy distribution.
.. math::
\log f(x\mid \mu, c) = \frac{1}{2}\log\left(\frac{c}{2\pi}\right) - \frac{c}{2(x-\mu)}
- \frac{3}{2}\log(x-\mu), \qquad x > \mu
:param value: A batch of samples from the distribution.
:return: an array with shape `value.shape[:-self.event_shape]`
:rtype: numpy.ndarray
"""
shifted_value = value - self.loc
return -0.5 * (
jnp.log(2.0 * jnp.pi) - jnp.log(self.scale) + self.scale / shifted_value
) - 1.5 * jnp.log(shifted_value)

def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
assert is_prng_key(key)
u = random.uniform(key, shape=sample_shape + self.batch_shape)
return self.icdf(u)

def icdf(self, q: ArrayLike) -> ArrayLike:
r"""
The inverse cumulative distribution function of Lévy distribution is given by,
.. math::
F^{-1}(q\mid \mu, c) = \mu + c\left(\Phi^{-1}(1-q/2)\right)^{-2}
where :math:`\Phi^{-1}` is the inverse of the standard normal cumulative distribution function.
:param q: quantile values, should belong to [0, 1].
:return: the samples whose cdf values equals to `q`.
"""
return self.loc + self.scale * jnp.power(ndtri(1 - 0.5 * q), -2)

def cdf(self, value: ArrayLike) -> ArrayLike:
r"""The cumulative distribution function of Lévy distribution is given by,
.. math::
F(x\mid \mu, c) = 2 - 2\Phi\left(\sqrt{\frac{c}{x-\mu}}\right)
where :math:`\Phi` is the standard normal cumulative distribution function.
:param value: samples from Lévy distribution.
:return: output of the cumulative distribution function evaluated at `value`.
"""
inv_standardized = self.scale / (value - self.loc)
return 2.0 - 2.0 * ndtr(jnp.sqrt(inv_standardized))

@property
def mean(self) -> ArrayLike:
return jnp.broadcast_to(jnp.inf, self.batch_shape)

@property
def variance(self) -> ArrayLike:
return jnp.broadcast_to(jnp.inf, self.batch_shape)

def entropy(self) -> ArrayLike:
r"""If :math:`X \sim \text{Levy}(\mu, c)`, then the entropy of :math:`X` is given by,
.. math::
H(X) = \frac{1}{2}+\frac{3}{2}\gamma+\frac{1}{2}\ln{\left(16\pi c^2\right)}
"""
return jnp.broadcast_to(
0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi), self.batch_shape
) + jnp.log(self.scale)
4 changes: 4 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def __init__(
),
dist.Wishart: _wishart_to_scipy,
_TruncatedNormal: _truncnorm_to_scipy,
dist.Levy: lambda loc, scale: osp.levy(loc=loc, scale=scale),
}


Expand Down Expand Up @@ -933,6 +934,9 @@ def get_sp_dist(jax_dist):
T(dist.DoublyTruncatedPowerLaw, np.pi, 5.0, 50.0),
T(dist.DoublyTruncatedPowerLaw, -1.0, 5.0, 50.0),
T(dist.DoublyTruncatedPowerLaw, np.pi, 1.0, 2.0),
T(dist.Levy, 0.0, 1.0),
T(dist.Levy, 0.0, np.array([1.0, 2.0, 10.0])),
T(dist.Levy, np.array([1.0, 2.0, 10.0]), np.pi),
]

DIRECTIONAL = [
Expand Down

0 comments on commit 5aca6cb

Please sign in to comment.