Add some more bijections #206

Open
wants to merge 3 commits into
base: main

Conversation

aseyboldt

This adds

  • DCT: discrete cosine transform
  • AsymmetricAffine: a smooth nonlinear transform that applies different scaling factors for positive and negative values
  • Householder: a Householder transformation
  • Neg: multiply by -1
  • Sandwich: apply a transform, then a second and then the inverse of the first

Owner

@danielward27 left a comment

Cheers! I have been meaning to add the Householder transform for a while! I had a little play with the discrete cosine transform too - it definitely gives some fun possibilities for modelling sequence/time series data.

I do want to be a bit careful to avoid overly expanding the number of bijections included in FlowJAX, so I think omitting Neg and AsymmetricAffine is probably sensible for now. I've added some comments, and if you are happy to incorporate the changes I'll merge it in. Thanks for the pull request!

from jax.scipy import fft


class Neg(AbstractBijection):
Owner

I think this is probably not worth including as a separate bijection. One workaround would be to use:

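import equinox as eqx
import jax.numpy as jnp
from flowjax.bijections import Scale
from paramax import non_trainable

shape = ()
# Build a Scale bijection, then swap its scale for a fixed, non-trainable -1
neg = Scale(jnp.ones(shape))
neg = eqx.tree_at(lambda neg: neg.scale, neg, replace=non_trainable(-jnp.ones(shape)))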

Author

done

return (x, x_grad, y_logp - logdet)


class DCT(AbstractBijection):
Owner

Could we rename this to DiscreteCosine?

Author

done

self.axis = axis
self.norm = "ortho"

def _dct(self, x: Array, inverse: bool = False) -> Array:
Owner

It seems more readable to avoid the if/else here and just call dct or idct directly in transform_and_log_det and inverse_and_log_det.

Owner

Also maybe avoid norm: str attribute, since it is fixed anyway?
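
A minimal sketch of what the two suggestions above could look like, written as plain functions rather than as the PR's actual class (the names dct_forward and dct_inverse are hypothetical). It assumes the orthonormal type-II DCT from jax.scipy.fft; because that transform is orthogonal, the log-determinant is zero in both directions, so no norm attribute or inverse flag is needed.

```python
import jax.numpy as jnp
from jax.scipy import fft


def dct_forward(x, axis=-1):
    # Orthonormal DCT: an orthogonal linear map, so log|det J| = 0
    return fft.dct(x, norm="ortho", axis=axis), jnp.zeros(())


def dct_inverse(y, axis=-1):
    # The inverse of the orthonormal DCT, again with zero log-determinant
    return fft.idct(y, norm="ortho", axis=axis), jnp.zeros(())


x = jnp.arange(8.0)
y, log_det = dct_forward(x)
x_back, _ = dct_inverse(y)
```

Each direction calls exactly one of dct or idct, which is the readability gain the comment asks for.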

Author

done

@@ -20,3 +24,120 @@ def transform_and_log_det(self, x, condition=None):
def inverse_and_log_det(self, y, condition=None):
x = jnp.log(-jnp.expm1(-y)) + y
return x, softplus(-x).sum()


class AsymmetricAffine(AbstractBijection):
Owner

This seems relatively niche. Is there a reference for this somewhere? I want to be cautious to avoid an endless expansion of bijections that will clutter the package.

Author

done

@@ -300,3 +304,54 @@ def inverse_and_log_det(self, y, condition=None):
x = self.inverter(self.bijection, y, condition)
_, log_det = self.bijection.transform_and_log_det(x, condition)
return x, log_det


class Sandwich(AbstractBijection):
Owner

I quite like this idea :)


def __init__(self, outer: AbstractBijection, inner: AbstractBijection):
shape = inner.shape
if outer.shape != shape:
Owner

I think we can remove this shape checking if we make use of Chain below. In that case the error would only be raised when the transform/inverse methods are traced, but it keeps things more concise (and also supports chaining unconditional and conditional transforms).

Author

Chain for instance also checks in the init, and we have to check the cond_shapes anyway, right?
I changed it to use the same helpers as the Chain method.
But if you really don't want it there, I can of course also remove the shape check.

self.inner = inner

def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]:
z1, logdet1 = self.outer.transform_and_log_det(x, condition)
Owner

Chain([self.outer, self.inner, Invert(self.outer)]).transform_and_log_det(x, condition)

return y, logdet1 + logdet2 + logdet3

def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]:
z1, logdet1 = self.outer.transform_and_log_det(y, condition)
Owner

Chain([self.outer, self.inner, Invert(self.outer)]).inverse_and_log_det(y, condition)
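
To illustrate the composition being discussed (outer, then inner, then the inverse of outer), here is a self-contained sketch in plain JAX. It does not use FlowJAX's Chain or Invert; all function names and the example scale vector are hypothetical. The outer map is assumed to be an orthonormal DCT and the inner map an elementwise scaling, so the sandwich rescales the input in the frequency domain.

```python
import jax
import jax.numpy as jnp
from jax.scipy import fft

scale = jnp.array([2.0, 0.5, 3.0, 1.5])  # illustrative inner scaling


def outer_forward(x):
    return fft.dct(x, norm="ortho"), jnp.zeros(())


def outer_inverse(y):
    return fft.idct(y, norm="ortho"), jnp.zeros(())


def inner_forward(x):
    return x * scale, jnp.log(jnp.abs(scale)).sum()


def inner_inverse(y):
    return y / scale, -jnp.log(jnp.abs(scale)).sum()


def sandwich_forward(x):
    # outer -> inner -> outer^{-1}, accumulating the log-determinants
    z, ld1 = outer_forward(x)
    z, ld2 = inner_forward(z)
    y, ld3 = outer_inverse(z)
    return y, ld1 + ld2 + ld3


def sandwich_inverse(y):
    # Inverting the composition: outer -> inner^{-1} -> outer^{-1}
    z, ld1 = outer_forward(y)
    z, ld2 = inner_inverse(z)
    x, ld3 = outer_inverse(z)
    return x, ld1 + ld2 + ld3


x = jnp.array([0.3, -1.2, 0.7, 2.0])
y, log_det = sandwich_forward(x)
x_back, inv_log_det = sandwich_inverse(y)
```

Note that the inverse also starts with the forward outer map and ends with its inverse; only the inner step is reversed.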

def inverse_and_log_det(self, y: Array, condition: Array | None = None):
return self._householder(y, self.params), jnp.zeros(())

def inverse_gradient_and_val(
Owner

Remove extra method for the time being

Author

Wasn't meant to be in there, sorry.

self.shape = (params.shape[-1],)
self.params = params

def _householder(self, x: Array, params: Array) -> Array:
Owner

Looks like we can simplify to

vec = params / jnp.linalg.norm(params)

Owner

Maybe a more readable option is to instead store:

self.unit_vec = paramax.Parameterize(lambda x: x/jnp.linalg.norm(x), params)

and use:

x - 2 * self.unit_vec * (x @ self.unit_vec)
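
As a standalone sketch of the expression above (a plain function, not the PR's class; the name householder and the sample values are illustrative), the reflection normalizes the parameters and reflects x across the hyperplane orthogonal to the resulting unit vector. The map is orthogonal and its own inverse, which is why the log-determinant can be returned as zero and the same function can serve both directions.

```python
import jax.numpy as jnp


def householder(x, params):
    # Normalize the raw parameters to a unit vector (undefined at params = 0)
    unit_vec = params / jnp.linalg.norm(params)
    # Reflect x across the hyperplane orthogonal to unit_vec
    return x - 2 * unit_vec * (x @ unit_vec)


params = jnp.array([1.0, -2.0, 0.5])
x = jnp.array([0.3, 0.7, -1.1])
y = householder(x, params)
x_back = householder(y, params)  # applying the reflection twice recovers x
```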

Author

I'm actually not entirely happy with the parametrization of the vector.
It will, for instance, blow up if x (the unnormalized parameter vector) is close to zero.
But parameterizing unit vectors isn't trivial, as there is no continuous bijection between any R^n and the space of unit vectors. Maybe we can find something that at least doesn't have any poles?

Author

Maybe something like this would be a better default parametrization for the unit vector?

def householder_transform(p, v):
    n = p.shape[0]

    # Compute Householder vector u of the reflection that maps e1 to p
    e1 = jnp.zeros_like(p).at[0].set(1.0)
    u = p - e1

    # Check if p is already aligned with e1
    is_aligned = jnp.sum(u**2) < 1e-12

    # Case 1: If p is approximately e1, embed v as [0, v]
    # This avoids dividing by small values if norm(u) is small
    v_full = jnp.zeros(n).at[1:].set(v)

    # Case 2: Apply the Householder transformation normally.
    # Substitute e1 for u in the aligned case before normalizing, so the
    # unused branch of the final where cannot produce NaN gradients.
    u_safe = jnp.where(is_aligned, e1, u)
    u_normalized = u_safe / jnp.linalg.norm(u_safe)
    dot_uv = jnp.dot(u_normalized[1:], v)
    # Reflect the full vector [0, v]; its first component is generally nonzero
    v_proj_full = v_full - 2.0 * dot_uv * u_normalized

    return jnp.where(is_aligned, v_full, v_proj_full)


def exp_map_sphere(p, v):
    """Riemannian exponential map on the n-sphere S^n

    Compute the point on the sphere reached by the exponential
    map from point p with tangent vector v.

    Parameters:
    p: Point on the sphere (shape: (n+1,))
    v: Tangent vector in R^n (shape: (n,))
    """
    # Map v into the tangent space at p using a Householder transformation
    v_proj = householder_transform(p, v)

    # Work with the squared norm and guard the square root, to avoid
    # NaNs in the gradient at v = 0
    sq_norm_v = jnp.sum(v_proj**2)
    is_small = sq_norm_v < 1e-12
    norm_v = jnp.sqrt(jnp.where(is_small, 1.0, sq_norm_v))

    direction = v_proj / norm_v

    # General case: Compute the exponential map
    exp_general = jnp.cos(norm_v) * p + jnp.sin(norm_v) * direction

    # Small v case: Use a Taylor expansion and re-normalize to stay on the sphere
    exp_taylor = p + v_proj
    exp_taylor = exp_taylor / jnp.linalg.norm(exp_taylor)
    return jnp.where(is_small, exp_taylor, exp_general)

Owner

I would be hesitant to come up with a custom unit vector parameterization, although I understand the simple approach likely leads to a (probably unavoidable) bad loss landscape. I haven't checked carefully, but most implementations of related methods seem to use the simple idea (dividing by the norm for the unit vector, or by the squared norm when computing the Householder transform: e.g. Stan, the original Householder flow, nflows). Division by the norm is pretty common in other applications (e.g. weight normalisation). I think it makes sense to have the unit vector as a Parameterized attribute self.unit_vec, so it can be modified as needed. I'll admit I'm no expert here, so I am biased towards the standard approach.

Author

I changed it to the Parameterize call with params / norm(params).

@aseyboldt
Author

Thanks for the review, I'll fix those early next week.

About AsymmetricAffine: that one doesn't have a reference; I made it up during experimentation. I wanted something between an affine transformation and the rational quadratic spline as the transformation in the coupling layer. For some posteriors that seems to work very well. It wasn't easy to come up with something where neither the forward nor the inverse blows up and that still has a closed-form inverse. I don't know if it will be useful to other people, but I think it might be...

@danielward27
Owner

Thanks! AsymmetricAffine is interesting, but I generally want to include only more established methods (to avoid the package becoming cluttered, to ensure that people can find full descriptions of the methods, to ensure that they show empirical advantages over other methods, etc.).
