
minimize and forward AD #112

Open
vadmbertr opened this issue Jan 24, 2025 · 19 comments
Labels: question (User queries)

Comments

@vadmbertr

Hi!

I'm facing a use case similar to the one described in #50, but I would like to optimize using minimise rather than least_squares.
Is there any plan/option for supporting a solution similar to options={"jac": "fwd"}?

Thanks a lot!
Vadim

@johannahaffner
Contributor

johannahaffner commented Jan 24, 2025

Hi Vadim,

selection of forward- vs. reverse-mode autodiff is currently implemented at the solver level. Which solvers are you interested in using?

If you're using diffrax underneath: diffrax now has efficient forward-mode autodiff as well, with diffrax.ForwardMode.

@vadmbertr
Author

Hi @johannahaffner,

Thanks for your reply!

I was interested in BFGS, but in forward mode. I was able to successfully solve a "least_squares" problem using forward mode and GaussNewton.
Indeed I'm using diffrax with diffrax.ForwardMode!

@johannahaffner
Contributor

You're welcome!

Could you share the error you are getting? BFGS already uses jax.linearize, which implements forward-mode automatic differentiation. So this should actually work out of the box.
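(As a quick illustration of that point - jax.linearize performs a forward-mode linearisation, returning the primal output plus a linear function that evaluates JVPs at that point. The toy function below is just for illustration, not the optimistix internals:)

import jax
import jax.numpy as jnp

f = lambda y: jnp.sum(jnp.sin(y))   # toy scalar-valued function
y = jnp.arange(3.0)
out, lin_fn = jax.linearize(f, y)   # forward-mode linearisation of f at y
jvp = lin_fn(jnp.ones_like(y))      # directional derivative of f at y along (1, 1, 1)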

@patrick-kidger added the question label Jan 25, 2025
@vadmbertr
Author

vadmbertr commented Jan 27, 2025

Hi!

Here is an MWE to reproduce:

from diffrax import diffeqsolve, Euler, ForwardMode, ODETerm, SaveAt
import jax.numpy as jnp
import optimistix as optx


def fn(y0, _=None):
    # Solve a simple linear ODE, saving the solution at a few time points.
    vector_field = lambda t, y, args: -y
    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=[0., 1., 2., 3.])

    # ForwardMode() makes the solve forward-mode autodifferentiable.
    sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat, adjoint=ForwardMode())

    return sol.ys


def least_square(y0, _=None):
    # Scalar loss: sum of squared model outputs.
    ys = fn(y0)
    return jnp.sum(ys**2)


ls_sol = optx.least_squares(fn, optx.GaussNewton(rtol=1e-8, atol=1e-8), jnp.asarray(1.))
print(ls_sol.value)  # 0.0
min_sol = optx.minimise(least_square, optx.BFGS(rtol=1e-8, atol=1e-8), jnp.asarray(1.))  # error raised here
print(min_sol.value)

and I get the following:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

I'm using optimistix 0.0.10, diffrax 0.6.2 and equinox 0.11.11 if that matters.

EDIT: note that if I comment out adjoint=ForwardMode() and the two ls_sol lines, the min_sol ones evaluate as expected (but as I said I would like to use forward mode).

@johannahaffner
Contributor

Ok I think I see where this is coming from - it looks like this is raised by the optimistix implicit adjoint. We've likely not run into this before because ForwardMode is pretty new.

The quickest fix I could think of didn't work, so I'll have to look into something else outside of working hours. Thanks for the MWE, I will keep you posted!

(In the meantime, you could try diffrax.DirectAdjoint - not a recommended long-term solution because it's not super memory-efficient, but it will allow for reverse-mode differentiation in the one place it is required, and forward-mode everywhere else.)

@johannahaffner
Contributor

johannahaffner commented Jan 27, 2025

Ok, this would require a bit of a change. I had not appreciated that jax.linear_transpose will actually use reverse-mode AD machinery under the hood, but it does. We're using that in the minimisers to get a gradient from the linearised function.

return jax.linear_transpose(lin_fn, *primals)(1.0)
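(For context, a minimal sketch of what that transposition does for a scalar-valued function - the toy function here is illustrative, not the optimistix code itself:)

import jax
import jax.numpy as jnp

f = lambda y: jnp.sum(y**2)
y = jnp.arange(3.0)
_, lin_fn = jax.linearize(f, y)                 # forward-mode linearisation
(grad,) = jax.linear_transpose(lin_fn, y)(1.0)  # transposing it with cotangent 1.0 recovers the gradient, i.e. 2 * y
# It is this transposition step that fails when the linearised function contains a
# non-transposable lax.while_loop, producing the error above.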

Here is what I think we can do:

  1. Create options fwd and bwd, with branches for each.
  • In the forward branch, compute the gradient with jacfwd, which is equivalent to the gradient for a scalar function.
  • In the reverse-mode branch, keep doing what we are doing, for the performance benefits.
  2. Also create options as above, but get the gradient out of lin_fn a different way - essentially a custom jacfwd that uses lin_fn, constructs unit pytrees of shape y, then stacks the output of vmap(lin_fn)(unit_pytrees) (see the sketch below). The motivation here would be that if we already have lin_fn around, we might as well use it.
    Not sure what the performance would be. As mentioned in the documentation of jax.linearize, storing the linearised function has a memory overhead that might not outweigh the compilation-time benefits in this case.
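A rough sketch of what option 2 could look like (gradient_from_lin_fn is a hypothetical helper name, and this assumes lin_fn comes from jax.linearize of a scalar-valued function):

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def gradient_from_lin_fn(lin_fn, y):
    # Build one-hot tangents matching the (flattened) structure of y.
    flat, unravel = ravel_pytree(y)
    basis = jnp.eye(flat.size, dtype=flat.dtype)
    # Push each unit tangent through the linearised function: one JVP per input dimension.
    jvps = jax.vmap(lambda e: lin_fn(unravel(e)))(basis)
    # For a scalar-valued function, the stacked JVPs are exactly the gradient.
    return unravel(jvps)

# Usage, mirroring how lin_fn is obtained above:
f = lambda y: jnp.sum(y**2)
y = jnp.arange(3.0)
_, lin_fn = jax.linearize(f, y)
grad = gradient_from_lin_fn(lin_fn, y)  # == 2 * y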

@patrick-kidger WDYT? Try both and see what sticks?

@patrick-kidger
Owner

> Ok, this would require a bit of a change. I had not appreciated that jax.linear_transpose will actually use reverse-mode AD machinery under the hood, but it does.

FWIW it's actually the other way around! Reverse-mode AD uses linear transposition. And indeed, when it comes to things like jax.lax.while_loop, then to be precise the issue is that it isn't transposable, rather than it not being reverse-mode autodifferentiable. (Since transposition is an infrequently used feature relative to reverse-mode AD, though, the JAX error message refers to the latter.)

As for what to do here: if we want to support an alternative forward mode here then I think evaluating vmap(lin_fn)(unit_pytrees) would be the computationally optimal way to do it. (This is actually what jacfwd already does under the hood.)

But FWIW @vadmbertr, the reason this doesn't come up super frequently is that for BFGS it is usually more efficient to use reverse mode to compute the gradient. Typically one would pair the Diffrax adjoint=... with the Optimistix solver=... as appropriate (see the sketch below). (I will grant you that it's kind of annoying that this is needed; alas, JAX does not currently support jvp-of-custom_vjp, and this abstraction leak is what results.) Would doing this work for your real-world use case?
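(To make that pairing concrete for the MWE above, a sketch that threads the adjoint through as an argument instead of hard-coding it; it reuses the imports from the MWE and diffrax's default RecursiveCheckpointAdjoint for the reverse-mode case:)

from diffrax import RecursiveCheckpointAdjoint

def fn_with_adjoint(y0, adjoint):
    vector_field = lambda t, y, args: -y
    term = ODETerm(vector_field)
    saveat = SaveAt(ts=[0., 1., 2., 3.])
    sol = diffeqsolve(term, Euler(), t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat, adjoint=adjoint)
    return sol.ys

# Forward-mode Jacobian paired with GaussNewton on the residuals:
ls_sol = optx.least_squares(
    lambda y0, _=None: fn_with_adjoint(y0, ForwardMode()),
    optx.GaussNewton(rtol=1e-8, atol=1e-8),
    jnp.asarray(1.),
)

# Reverse-mode gradient (diffrax's default adjoint) paired with BFGS on the scalar loss:
min_sol = optx.minimise(
    lambda y0, _=None: jnp.sum(fn_with_adjoint(y0, RecursiveCheckpointAdjoint())**2),
    optx.BFGS(rtol=1e-8, atol=1e-8),
    jnp.asarray(1.),
)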

@vadmbertr
Author

Hi!
Thank you for digging into this.

@johannahaffner, indeed DirectAdjoint works as a workaround but might not be ideal for the reason you pointed out.

@patrick-kidger, what would be the reason for BFGS being more efficient in reverse-mode? (in the context where evaluating the gradient is significantly faster in forward mode) A colleague of mine wanted to compare a NumPy implementation of a differential equation calibration problem with a JAX one, and he initially used BFGS for the solve. Which solver would you recommend here (using optx.minimise)?

Again, thanks for the replies!
Vadim

@johannahaffner
Contributor

> @patrick-kidger, what would be the reason for BFGS being more efficient in reverse-mode? (in the context where evaluating the gradient is significantly faster in forward mode)

In general, reverse-mode automatic differentiation is more efficient for functions that map high-dimensional inputs to low-dimensional outputs (e.g. a neural network with many parameters and a scalar loss function). BFGS operates on such a scalar loss.
Forward-mode automatic differentiation is more efficient in the opposite setting, where few inputs get mapped to many outputs, such as a mechanistic model (like an ODE) with few parameters fitted to a long time series. The residual Jacobian is then going to be tall and narrow - with many rows for the residual, each of which is considered a model output.

This means that forward- vs. reverse-mode being more efficient is not really a model property! It really depends on the optimiser you use and whether it operates on the residuals or on their squared sum.
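(A toy illustration of the shapes involved - the functions here are made up purely to show the Jacobian dimensions:)

import jax
import jax.numpy as jnp

params = jnp.ones(3)                             # few parameters
residuals = lambda p: jnp.tile(jnp.sin(p), 100)  # many residuals: 300 outputs
J = jax.jacfwd(residuals)(params)                # tall and narrow, shape (300, 3): cheap in forward mode
loss = lambda p: jnp.sum(residuals(p) ** 2)      # scalar loss on top of the residuals
g = jax.grad(loss)(params)                       # shape (3,): the setting reverse mode is built for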

@vadmbertr
Author

> @patrick-kidger, what would be the reason for BFGS being more efficient in reverse-mode? (in the context where evaluating the gradient is significantly faster in forward mode)
>
> In general, reverse-mode automatic differentiation is more efficient for functions that map high-dimensional inputs to low-dimensional outputs (e.g. a neural network with many parameters and a scalar loss function). BFGS operates on such a scalar loss. Forward-mode automatic differentiation is more efficient in the opposite setting, where few inputs get mapped to many outputs, such as a mechanistic model (like an ODE) with few parameters fitted to a long time series. The residual Jacobian is then going to be tall and narrow - with many rows for the residual, each of which is considered a model output.
>
> This means that forward- vs. reverse-mode being more efficient is not really a model property! It really depends on the optimiser you use and whether it operates on the residuals or on their squared sum.

The use case is a bit in between what you described, as the (scalar) loss "aggregates" the outputs of a model with (very) few parameters (it is not the squared sum of the individual residuals). So I believe the Jacobian is wide but low-dimensional, and (experimentally) differentiating the model is (much) faster in forward mode than in reverse mode.

@johannahaffner
Contributor

> The use case is a bit in between what you described, as the (scalar) loss "aggregates" the outputs of a model with (very) few parameters (it is not the squared sum of the individual residuals). So I believe the Jacobian is wide but low-dimensional, and (experimentally) differentiating the model is (much) faster in forward mode than in reverse mode.

Is jacfwd on the scalar loss faster than grad?

@vadmbertr
Author

> Is jacfwd on the scalar loss faster than grad?

Yes, consider the following for example:

from diffrax import diffeqsolve, Euler, ForwardMode, ODETerm, RecursiveCheckpointAdjoint, SaveAt
import jax
import jax.numpy as jnp
import optimistix as optx


def fn(w, y0, adjoint):
    # Integrate a simple ODE over a week of hourly time points, with a configurable adjoint.
    ts = jnp.arange(7*24)
    vector_field = lambda t, y, args: y * args * jnp.exp(-t)
    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=ts)

    sol = diffeqsolve(term, solver, t0=ts[0], t1=ts[-1], dt0=1, y0=y0, args=w, saveat=saveat, adjoint=adjoint)

    return sol.ys


def loss(w, y0, adjoint):
    # Scalar loss aggregating the trajectory (not a sum of squared residuals).
    ys = fn(w, y0, adjoint)
    return jnp.sum(jnp.max(ys, axis=(1, 2)))


@jax.jit
def fwd(w, y0):
    # Forward-mode gradient of the scalar loss.
    return jax.jacfwd(loss, argnums=0)(w, y0, ForwardMode())


@jax.jit
def bwd(w, y0):
    # Reverse-mode gradient of the scalar loss.
    return jax.grad(loss, argnums=0)(w, y0, RecursiveCheckpointAdjoint())


y0 = jnp.ones((100, 100))
w = 1.

%time print(fwd(w, y0))
# 536.20465
# CPU times: user 615 ms, sys: 58.3 ms, total: 673 ms
# Wall time: 492 ms
%time print(bwd(w, y0))
# 536.20483
# CPU times: user 1.51 s, sys: 60.6 ms, total: 1.57 s
# Wall time: 1.07 s

%timeit jax.block_until_ready(fwd(w, y0))
# 5.71 ms ± 260 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.block_until_ready(bwd(w, y0))
# 11.9 ms ± 224 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@johannahaffner
Contributor

A little counterintuitive, and therefore very interesting! Thanks for the demo. To support this, I'll include a forward option in BFGS, then. I can probably get to it over the weekend :)

(Your wall & CPU times include compilation, btw. Call a jitted function once before benchmarking - when you time them below, the compilation has already happened and the computation is 100x as fast.)

@vadmbertr
Author

> A little counterintuitive, and therefore very interesting! Thanks for the demo. To support this, I'll include a forward option in BFGS, then. I can probably get to it over the weekend :)

Wow thanks a lot!

> (Your wall & CPU times include compilation, btw. Call a jitted function once before benchmarking - when you time them below, the compilation has already happened and the computation is 100x as fast.)

I believe it is compiled in the %time call and it calls the jitted version in the %timeit one!

@johannahaffner
Contributor

> Wow thanks a lot!

You're welcome!

> I believe it is compiled in the %time call and it calls the jitted version in the %timeit one!

Yes, exactly.

@patrick-kidger
Owner

Bear in mind that in this example it's a scalar->scalar function, and for those it's totally expected for forward mode to be optimal. This is why I wrote 'usually more efficient to use reverse-mode' above, rather than always. The fact that this isn't a super frequent use case for BFGS -- there are many specifically scalar->scalar optimizers that are more common there, and arguably our best possible improvement here is actually to add more of them by default! -- is why this hasn't been super important before :)

@johannahaffner
Contributor

johannahaffner commented Feb 2, 2025

@vadmbertr can you try https://github.com/johannahaffner/optimistix/tree/forward-fix on your real problem? This should now work; all you need to do is pass options=dict(mode="fwd") to optx.minimise.
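(For the MWE above, that would look something like this, assuming the branch is installed:)

min_sol = optx.minimise(
    least_square,
    optx.BFGS(rtol=1e-8, atol=1e-8),
    jnp.asarray(1.),
    options=dict(mode="fwd"),
)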

If it does not really help on real problems, then no need to add it @patrick-kidger. And agreed for scalar->scalar solvers, in general.

@vadmbertr
Author

Hi @johannahaffner,

Thanks for implementing this swiftly!
I can confirm that I get a 50% to 1000% speed-up on real-world problems of varying complexity (in time step and state domain dimension), so I will be happy if it gets added.

@johannahaffner
Contributor

Good morning @vadmbertr, thanks for trying it out so quickly! I opened a PR to add this option for all minimisers.
