Merge pull request #121 from danielward27/laplace
add Laplace
danielward27 authored Dec 6, 2023
2 parents a09c36d + ef60f11 commit e99c928
Showing 3 changed files with 53 additions and 7 deletions.
54 changes: 48 additions & 6 deletions flowjax/distributions.py
@@ -272,7 +272,7 @@ def _sample(self, key, condition=None):
     def _sample_and_log_prob(
         self,
         key: Array,
-        condition=None,
+        condition: Array | None = None,
     ):  # TODO add override decorator when python>=3.12 is common
         # We override to avoid computing the inverse transformation.
         base_sample, log_prob_base = self.base_dist._sample_and_log_prob(key, condition)
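For context on the comment above (an editorial note, not part of the diff): the override works because sampling runs the bijection forward, and the change of variables for $x = f(u)$ with base sample $u$ gives

$$\log p_X(x) = \log p_U(u) - \log\left|\det J_f(u)\right|,$$

so the log density of the transformed sample can be accumulated during the forward pass without ever inverting $f$.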
@@ -408,7 +408,7 @@ class LogNormal(AbstractTransformed):
     Args:
         loc: Location parameter. Defaults to 0.
-        scale: Scale parameter. Defaults to 1.0.
+        scale: Scale parameter. Defaults to 1.
     """

     base_dist: StandardNormal
@@ -527,7 +527,7 @@ class Gumbel(AbstractTransformed):
     Args:
         loc: Location parameter.
-        scale: Scale parameter. Defaults to 1.0.
+        scale: Scale parameter. Defaults to 1.
     """

     base_dist: _StandardGumbel
@@ -573,7 +573,7 @@ class Cauchy(AbstractTransformed):
     Args:
         loc: Location parameter.
-        scale: Scale parameter. Defaults to 1.0.
+        scale: Scale parameter. Defaults to 1.
     """

     base_dist: _StandardCauchy
@@ -628,8 +628,8 @@ class StudentT(AbstractTransformed):
     Args:
         df: The degrees of freedom.
-        loc: Location parameter. Defaults to 0.0.
-        scale: Scale parameter. Defaults to 1.0.
+        loc: Location parameter. Defaults to 0.
+        scale: Scale parameter. Defaults to 1.
     """

     base_dist: _StandardStudentT
@@ -656,6 +656,48 @@ def df(self):
         return self.base_dist.df


+class _StandardLaplace(AbstractDistribution):
+    """Implements the standard Laplace distribution (loc=0, scale=1)."""
+
+    shape: tuple[int, ...] = ()
+    cond_shape: ClassVar[None] = None
+
+    def _log_prob(self, x, condition=None):
+        return jstats.laplace.logpdf(x).sum()
+
+    def _sample(self, key, condition=None):
+        return jr.laplace(key, shape=self.shape)
+
+
+class Laplace(AbstractTransformed):
+    """Laplace distribution.
+
+    ``loc`` and ``scale`` should broadcast to the dimension of the distribution.
+
+    Args:
+        loc: Location parameter. Defaults to 0.
+        scale: Scale parameter. Defaults to 1.
+    """
+
+    base_dist: _StandardLaplace
+    bijection: Affine
+
+    def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
+        shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
+        self.base_dist = _StandardLaplace(shape)
+        self.bijection = Affine(loc, scale)
+
+    @property
+    def loc(self):
+        """Location of the distribution."""
+        return self.bijection.loc
+
+    @property
+    def scale(self):
+        """Scale of the distribution."""
+        return self.bijection.scale
+
+
 class SpecializeCondition(AbstractDistribution):  # TODO check tested
     """Specialise a distribution to a particular conditioning variable instance.
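A minimal usage sketch of the new distribution (not part of the diff; it assumes flowjax's public `sample`/`log_prob` interface and the `Laplace` export shown above):

```python
# Minimal usage sketch, assuming flowjax's public sample/log_prob interface.
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Laplace

dist = Laplace(loc=jnp.zeros(2), scale=2 * jnp.ones(2))  # event shape (2,)
x = dist.sample(jr.PRNGKey(0))  # one draw, shape (2,)
lp = dist.log_prob(x)           # scalar: log density summed over the event axis

# Under the hood this is _StandardLaplace transformed by Affine(loc, scale), so
#   log_prob(x) = standard_log_prob((x - loc) / scale) - sum(log(scale)).
```

The `loc` and `scale` properties added in the diff simply read the parameters back from the `Affine` bijection.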
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "11.0.0"
+version = "11.1.0"

 [project.urls]
 repository = "https://github.com/danielward27/flowjax"
4 changes: 4 additions & 0 deletions tests/test_distributions.py
@@ -13,6 +13,7 @@
     AbstractTransformed,
     Cauchy,
     Gumbel,
+    Laplace,
     LogNormal,
     MultivariateNormal,
     Normal,
@@ -22,6 +23,7 @@
     Uniform,
     _StandardCauchy,
     _StandardGumbel,
+    _StandardLaplace,
     _StandardStudentT,
     _StandardUniform,
 )
@@ -44,6 +46,8 @@
     "_StandardStudentT": lambda shape: _StandardStudentT(jnp.ones(shape)),
     "StudentT": lambda shape: StudentT(jnp.ones(shape)),
     "LogNormal": lambda shape: LogNormal(jnp.ones(shape), 2),
+    "_StandardLaplace": _StandardLaplace,
+    "Laplace": lambda shape: Laplace(jnp.ones(shape)),
 }
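For illustration, a standalone check in the spirit of these registry-driven tests might look as follows (a hedged sketch, not code from the flowjax test suite; the test name is hypothetical):

```python
# Hedged sketch: check the transformed Laplace against jax.scipy.stats,
# which already provides the location-scale Laplace log density.
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as jstats
from flowjax.distributions import Laplace

def test_laplace_matches_jax_scipy():  # hypothetical test name
    dist = Laplace(loc=2.0, scale=3.0)  # scalar event shape
    x = dist.sample(jr.PRNGKey(0))
    expected = jstats.laplace.logpdf(x, loc=2.0, scale=3.0)
    assert jnp.allclose(dist.log_prob(x), expected)
```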
