Skip to content

Commit

Permalink
Merge pull request #199 from danielward27/uniform_non_trainable
Browse files Browse the repository at this point in the history
Make uniform non-trainable
  • Loading branch information
danielward27 authored Dec 18, 2024
2 parents be408e4 + 27d1d2c commit b90e6e9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions flowjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from jaxtyping import Array, ArrayLike, PRNGKeyArray, Shaped
from paramax import AbstractUnwrappable, Parameterize, unwrap
from paramax import AbstractUnwrappable, Parameterize, non_trainable, unwrap
from paramax.utils import inv_softplus

from flowjax.bijections import (
Expand Down Expand Up @@ -478,17 +478,18 @@ def __init__(self, minval: ArrayLike, maxval: ArrayLike):
(minval, maxval), maxval <= minval, "minval must be less than the maxval."
)
self.base_dist = _StandardUniform(shape)
self.bijection = Affine(loc=minval, scale=maxval - minval)
self.bijection = non_trainable(Affine(loc=minval, scale=maxval - minval))

@property
def minval(self):
"""Minimum value of the uniform distribution."""
return self.bijection.loc
return unwrap(self.bijection.loc)

@property
def maxval(self):
"""Maximum value of the uniform distribution."""
return self.bijection.loc + unwrap(self.bijection.scale)
unwrapped = unwrap(self)
return unwrapped.loc + unwrapped.scale


class _StandardGumbel(AbstractDistribution):
Expand Down

0 comments on commit b90e6e9

Please sign in to comment.