Add some more bijections #206
base: main
Conversation
Cheers! I have been meaning to add the Householder transform for a while! I had a little play with the discrete cosine transform too - it definitely gives some fun possibilities for modelling sequence/time series data.
I do want to be a bit careful to avoid overly expanding the number of bijections included in FlowJAX, so I think omitting Neg and AsymmetricAffine is probably sensible for now. I've added some comments, and if you are happy to incorporate the changes I'll merge it in. Thanks for the pull request!
flowjax/bijections/orthogonal.py
Outdated
from jax.scipy import fft


class Neg(AbstractBijection):
I think this is probably not worth including as a separate bijection. One workaround would be to use:
import equinox as eqx
import jax.numpy as jnp
from flowjax.bijections import Scale
from paramax import non_trainable

shape = ()
neg = Scale(jnp.ones(shape))
neg = eqx.tree_at(lambda neg: neg.scale, neg, replace=non_trainable(-jnp.ones(shape)))
done
flowjax/bijections/orthogonal.py
Outdated
        return (x, x_grad, y_logp - logdet)


class DCT(AbstractBijection):
Could we rename this to DiscreteCosine?
done
flowjax/bijections/orthogonal.py
Outdated
        self.axis = axis
        self.norm = "ortho"

    def _dct(self, x: Array, inverse: bool = False) -> Array:
Seems more readable to avoid the if/else here, and just call idct or dct in transform_and_log_det and inverse_and_log_det.
Also, maybe avoid the norm: str attribute, since it is fixed anyway?
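Taken together, the two suggestions could look roughly like this: a minimal sketch using standalone functions rather than the actual FlowJAX methods. Since the orthonormal ("ortho") DCT is an orthogonal map, the log-determinant of its Jacobian is zero in both directions.

```python
import jax.numpy as jnp
from jax.scipy import fft


def transform_and_log_det(x):
    # Orthonormal DCT is an orthogonal map, so log|det J| = 0.
    return fft.dct(x, norm="ortho"), jnp.zeros(())


def inverse_and_log_det(y):
    # Inverse DCT with the same normalization; also volume-preserving.
    return fft.idct(y, norm="ortho"), jnp.zeros(())
```

With this split, neither branch needs an `inverse` flag or a stored `norm` attribute.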
done
flowjax/bijections/softplus.py
Outdated
@@ -20,3 +24,120 @@ def transform_and_log_det(self, x, condition=None):
    def inverse_and_log_det(self, y, condition=None):
        x = jnp.log(-jnp.expm1(-y)) + y
        return x, softplus(-x).sum()


class AsymmetricAffine(AbstractBijection):
This seems relatively niche, is there a reference for this somewhere? I want to be cautious to avoid an endless expansion of bijections that will clutter the package.
done
@@ -300,3 +304,54 @@ def inverse_and_log_det(self, y, condition=None):
        x = self.inverter(self.bijection, y, condition)
        _, log_det = self.bijection.transform_and_log_det(x, condition)
        return x, log_det


class Sandwich(AbstractBijection):
I quite like this idea :)
flowjax/bijections/utils.py
Outdated
    def __init__(self, outer: AbstractBijection, inner: AbstractBijection):
        shape = inner.shape
        if outer.shape != shape:
I think we can remove this shape checking if we make use of Chain below. In this case the error would only be raised at tracing of the transform/inverse methods, but it keeps things more concise (and also supports chaining of unconditional and conditional transforms).
Chain for instance also checks in the init, and we have to check the cond_shapes anyway, right?
I changed it to use the same helpers as the Chain method.
But if you really don't want it there, I can of course also remove the shape check.
flowjax/bijections/utils.py
Outdated
        self.inner = inner

    def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]:
        z1, logdet1 = self.outer.transform_and_log_det(x, condition)
Chain([self.outer, self.inner, Invert(self.outer)]).transform_and_log_det(x, condition)
flowjax/bijections/utils.py
Outdated
        return y, logdet1 + logdet2 + logdet3

    def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]:
        z1, logdet1 = self.outer.transform_and_log_det(y, condition)
Chain([self.outer, self.inner, Invert(self.outer)]).inverse_and_log_det(y, condition)
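The composition being suggested is y = outer⁻¹(inner(outer(x))), with the three log-determinants summed. A plain-function sketch of that idea (hypothetical names, not the FlowJAX API):

```python
import jax.numpy as jnp


# Hypothetical sketch of the "sandwich" composition
# y = outer_inv(inner(outer(x))), accumulating log-determinants.
def sandwich_transform_and_log_det(x, outer, inner, outer_inv):
    z1, ld1 = outer(x)
    z2, ld2 = inner(z1)
    y, ld3 = outer_inv(z2)
    return y, ld1 + ld2 + ld3


# Example: outer scales by 2, inner shifts by 1; the sandwich is then
# x -> x + 0.5, and the outer log-determinants cancel to zero.
outer = lambda x: (2.0 * x, jnp.log(2.0))
outer_inv = lambda z: (z / 2.0, -jnp.log(2.0))
inner = lambda z: (z + 1.0, jnp.zeros(()))
```

Because the outer bijection and its inverse contribute opposite log-determinants, only the inner transform's log-determinant survives in total.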
flowjax/bijections/orthogonal.py
Outdated
    def inverse_and_log_det(self, y: Array, condition: Array | None = None):
        return self._householder(y, self.params), jnp.zeros(())

    def inverse_gradient_and_val(
Remove extra method for the time being
Wasn't meant to be in there, sorry.
flowjax/bijections/orthogonal.py
Outdated
        self.shape = (params.shape[-1],)
        self.params = params

    def _householder(self, x: Array, params: Array) -> Array:
Looks like we can simplify to
vec = params / jnp.linalg.norm(params)
Maybe a more readable option is to instead store:
self.unit_vec = paramax.Parameterize(lambda x: x/jnp.linalg.norm(x), params)
and use:
x - 2 * self.unit_vec * (x @ self.unit_vec)
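As a standalone sketch of the reflection being discussed (unit vector obtained by dividing by the norm; the function name is hypothetical, not FlowJAX's method):

```python
import jax.numpy as jnp


def householder(x, params):
    # Normalize the parameters to a unit vector, then reflect x across
    # the hyperplane orthogonal to it: H x = x - 2 v (v . x).
    vec = params / jnp.linalg.norm(params)
    return x - 2 * vec * (x @ vec)
```

The reflection is its own inverse and preserves norms, so as a bijection its log-determinant is zero (det H = -1).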
I'm actually not entirely happy with the parametrization of the vector. It will for instance blow up if x is close to zero.
But parameterizing unit vectors isn't trivial, as there is no bijection between any R^n and the space of unit vectors. Maybe we can find something that at least doesn't have any poles?
Maybe something like this would be a better default parametrization for the unit vector?:
def householder_transform(p, v):
    """Map a tangent vector v at e1 to the tangent space of the sphere at p."""
    n = p.shape[0]
    # Compute the Householder vector u reflecting e1 onto p
    e1 = jnp.zeros_like(p).at[0].set(1.0)
    u = p - e1
    norm_u_raw = jnp.linalg.norm(u)
    # Check if p is already aligned with e1
    is_aligned = norm_u_raw < 1e-6
    # Guard the division so neither the value nor the gradient produces NaNs
    # when p is close to e1 (jnp.where evaluates both branches under JAX).
    norm_u = jnp.where(is_aligned, 1.0, norm_u_raw)
    u_normalized = u / norm_u
    # Embed v into R^n as [0, v], a tangent vector at e1
    w = jnp.zeros(n).at[1:].set(v)
    # Reflect the full embedded vector into the tangent space at p; note the
    # reflection generally produces a nonzero first component. If p is aligned
    # with e1, the embedding itself is already tangent.
    reflected = w - 2.0 * jnp.dot(u_normalized, w) * u_normalized
    return jnp.where(is_aligned, w, reflected)

def exp_map_sphere(p, v):
    """Riemannian exponential map on the n-sphere S^n.

    Computes the point on the sphere reached by the exponential map
    from point p with tangent vector v.

    Parameters:
        p: Point on the sphere (shape: (n+1,))
        v: Tangent vector in R^n (shape: (n,))
    """
    # Project v into the correct tangent space using the Householder transformation
    v_proj = householder_transform(p, v)
    norm_v_raw = jnp.linalg.norm(v_proj)
    # Avoid NaNs in the gradient
    norm_v = jnp.where(norm_v_raw > 1e-6, norm_v_raw, 1.0)
    direction = v_proj / norm_v
    # General case: compute the exponential map
    exp_general = jnp.cos(norm_v) * p + jnp.sin(norm_v) * direction
    # Small-v case: use a first-order step and re-normalize to stay on the sphere
    exp_taylor = p + v_proj
    exp_taylor = exp_taylor / jnp.linalg.norm(exp_taylor)
    return jnp.where(norm_v_raw > 1e-6, exp_general, exp_taylor)
I would be hesitant to come up with a custom unit vector parameterization, although I understand the simple approach likely leads to a (probably unavoidable) bad loss landscape. I haven't checked carefully, but most of the implementations of related stuff seem to use the simple idea (dividing by the norm for the unit vector, or the squared norm when computing the Householder transform: e.g. Stan, the original Householder flow, nflows). Division by the norm is pretty common in other applications (e.g. weight normalisation). I think it makes sense to have the unit vector as a Parameterized attribute self.unit_vec, so it can be modified as needed. I'll admit I'm no expert here, so I am biased towards the standard approach.
I changed it to the Parameterize call with params / norm(params).
Thanks for the review, I'll fix those early next week. About AsymmetricAffine: That one doesn't have a reference, I made it up during experimentation. I wanted to have something between an affine and the rational quadratic spline as transformation in the coupling layer. For some posteriors that seems to work very well. Wasn't easy to come up with something where neither the forward nor inverse blows up, and that still has a closed form inverse. I don't know if it will be useful to other people, but I think it might be... |
Thanks! |
Force-pushed from 38fd4c8 to 1002c9a
This adds