minimize and forward AD #112

Hi!

I'm facing a similar use-case as the one described here (#50), but I would like to optimize using `minimize` rather than `least_squares`. Is there any plan/option for supporting a solution similar to `options={"jac": "fwd"}`?

Thanks a lot!
Vadim

Comments
Hi Vadim, selection of forward- vs. reverse-mode autodiff is currently implemented at the solver level. Which solvers are you interested in using? If you're using diffrax underneath: diffrax now has efficient forward-mode autodiff as well, with `ForwardMode`.
Hi @johannahaffner, thanks for your reply! I was interested in BFGS, but in forward mode. I was able to successfully solve a `least_squares` problem using forward mode and `GaussNewton`.
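For context (this is not code from the thread), here is a minimal sketch of a forward-mode `least_squares` setup of that kind, using the `options={"jac": "fwd"}` switch that the original question refers to; the residual function and tolerances are placeholders:

```python
import jax.numpy as jnp
import optimistix as optx

def residuals(params, args=None):
    # Placeholder residual function; in this thread it wraps a diffrax solve.
    return params - jnp.arange(3.0)

solver = optx.GaussNewton(rtol=1e-8, atol=1e-8)
sol = optx.least_squares(
    residuals,
    solver,
    jnp.zeros(3),
    options={"jac": "fwd"},  # request a forward-mode Jacobian (cf. issue #50)
)
print(sol.value)  # ~[0., 1., 2.]
```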
You're welcome! Could you share the error you are getting?
Hi! Here is an MWE to reproduce:

```python
from diffrax import diffeqsolve, Euler, ForwardMode, ODETerm, SaveAt
import jax.numpy as jnp
import optimistix as optx

def fn(y0, _=None):
    vector_field = lambda t, y, args: -y
    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=[0., 1., 2., 3.])
    sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat, adjoint=ForwardMode())
    return sol.ys

def least_square(y0, _=None):
    ys = fn(y0)
    return jnp.sum(ys**2)

ls_sol = optx.least_squares(fn, optx.GaussNewton(rtol=1e-8, atol=1e-8), jnp.asarray(1.))
print(ls_sol.value)  # 0.0

min_sol = optx.minimise(least_square, optx.BFGS(rtol=1e-8, atol=1e-8), jnp.asarray(1.))  # error raised here
print(min_sol.value)
```

and I get the following: …

I'm using optimistix 0.0.10, diffrax 0.6.2 and equinox 0.11.11, if that matters.

EDIT: note that if I comment out …
Ok, I think I see where this is coming from - it looks like this is raised by the optimistix implicit adjoint. We've likely not run into this before because …

The quickest fix I could think of didn't work, so I'll have to look into something else outside of working hours. Thanks for the MWE, I will keep you posted!

(In the meantime, you could try ….)
Ok, this would require a bit of a change. I had not appreciated what is done at `optimistix/optimistix/_misc.py`, line 125 (at fb55786).

Here is what I think we can do: …

@patrick-kidger WDYT? Try both and see what sticks?
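The discussion suggests the referenced line is where a linearisation gets turned into a gradient by transposing it, which is a reverse-mode operation. As a rough standalone illustration of that general JAX pattern (my own sketch, not optimistix's implementation):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)

# Linearise once (a forward pass), then transpose the linearisation to pull the
# scalar cotangent back to the inputs: this transpose is the reverse-mode step.
out, jvp_fn = jax.linearize(f, x)
(grad_via_transpose,) = jax.linear_transpose(jvp_fn, x)(jnp.ones_like(out))

print(jnp.allclose(grad_via_transpose, jax.grad(f)(x)))  # True
```

A forward-only adjoint has nothing to transpose, which would explain the error hit in the MWE above.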
FWIW it's actually the other way around! Reverse-mode AD uses linear transpose. And indeed when it comes to stuff like …

As for what to do here: if we want to support an alternative forward mode here then I think evaluating …

But FWIW @vadmbertr, the reason this doesn't come up super frequently is because for BFGS it is usually more efficient to use reverse-mode here to compute a gradient. Typically one would pair the Diffrax `RecursiveCheckpointAdjoint` with reverse-mode differentiation.
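To make that typical pairing concrete, here is a hedged sketch (my own illustration, not code from the thread): the toy ODE from the MWE, solved with diffrax's default `RecursiveCheckpointAdjoint` and minimised with `optx.BFGS`, whose gradient is computed with reverse-mode AD by default.

```python
from diffrax import diffeqsolve, Euler, ODETerm, RecursiveCheckpointAdjoint, SaveAt
import jax.numpy as jnp
import optimistix as optx

def scalar_loss(y0, _=None):
    # Same toy ODE as in the MWE above, but with a reverse-mode-capable adjoint.
    term = ODETerm(lambda t, y, args: -y)
    sol = diffeqsolve(
        term, Euler(), t0=0, t1=3, dt0=0.1, y0=y0,
        saveat=SaveAt(ts=[0., 1., 2., 3.]),
        adjoint=RecursiveCheckpointAdjoint(),  # diffrax's default adjoint
    )
    return jnp.sum(sol.ys**2)

# BFGS differentiates the loss in reverse mode internally, which this adjoint supports.
min_sol = optx.minimise(scalar_loss, optx.BFGS(rtol=1e-8, atol=1e-8), jnp.asarray(1.))
print(min_sol.value)
```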
Hi!

@johannahaffner, indeed ….

@patrick-kidger, what would be the reason for BFGS being more efficient in reverse mode (in the context where evaluating the gradient is significantly faster in forward mode)? A colleague of mine wanted to compare a ….

Again, thanks for the replies!
In general, reverse-mode automatic differentiation is more efficient for functions that map high-dimensional inputs to low-dimensional outputs (e.g. a neural network with many parameters and a scalar loss function). This means that forward- vs. reverse-mode being more efficient is not really a model property! It really depends on the optimiser you use and whether it operates on the residuals or on their squared sum.
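A rough way to see that rule of thumb in JAX itself (my own sketch, not part of the comment above): for a many-inputs-to-one-output function, `jax.jacrev` needs a single backward pass, while `jax.jacfwd` needs roughly one forward (JVP) pass per input.

```python
import jax
import jax.numpy as jnp

def scalar_loss(params):
    # Many inputs -> one output: the case where reverse mode shines.
    return jnp.sum(jnp.tanh(params) ** 2)

params = jnp.ones(2_000)

grad_rev = jax.jacrev(scalar_loss)(params)  # one reverse pass for the full gradient
grad_fwd = jax.jacfwd(scalar_loss)(params)  # ~2000 forward (JVP) passes under the hood

print(jnp.allclose(grad_rev, grad_fwd))  # True: same values, very different cost
```

For a residual-style function with few inputs and many outputs the situation reverses, which is why a forward-mode Jacobian suits `least_squares` solvers well.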
The use case is a bit in between what you described, as the (scalar) loss "aggregates" the outputs of a model with (very) few parameters (it is not the squared sum of the individual residuals). So I believe the Jacobian is wide but low-dimensional, and (experimentally) computing the adjoint of the model is (much) faster in forward mode than in reverse mode.
Is forward-mode differentiation actually faster than reverse mode for your model?
Yes, consider the following for example:

```python
from diffrax import diffeqsolve, Euler, ForwardMode, ODETerm, RecursiveCheckpointAdjoint, SaveAt
import jax
import jax.numpy as jnp
import optimistix as optx

def fn(w, y0, adjoint):
    ts = jnp.arange(7*24)
    vector_field = lambda t, y, args: y * args * jnp.exp(-t)
    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=ts)
    sol = diffeqsolve(term, solver, t0=ts[0], t1=ts[-1], dt0=1, y0=y0, args=w, saveat=saveat, adjoint=adjoint)
    return sol.ys

def loss(w, y0, adjoint):
    ys = fn(w, y0, adjoint)
    return jnp.sum(jnp.max(ys, axis=(1, 2)))

@jax.jit
def fwd(w, y0):
    return jax.jacfwd(loss, argnums=0)(w, y0, ForwardMode())

@jax.jit
def bwd(w, y0):
    return jax.grad(loss, argnums=0)(w, y0, RecursiveCheckpointAdjoint())

y0 = jnp.ones((100, 100))
w = 1.

%time print(fwd(w, y0))
# 536.20465
# CPU times: user 615 ms, sys: 58.3 ms, total: 673 ms
# Wall time: 492 ms

%time print(bwd(w, y0))
# 536.20483
# CPU times: user 1.51 s, sys: 60.6 ms, total: 1.57 s
# Wall time: 1.07 s

%timeit jax.block_until_ready(fwd(w, y0))
# 5.71 ms ± 260 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit jax.block_until_ready(bwd(w, y0))
# 11.9 ms ± 224 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
A little counterintuitive, and therefore very interesting! Thanks for the demo. To support this, I'll include a forward option in the minimisers.

(Your wall & CPU times include compilation, btw. Call a jitted function once before benchmarking - when you time them with %timeit below, the compilation has already happened and the computation is 100x as fast.)
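For concreteness, the warm-up pattern looks something like this (my own sketch, reusing `fwd`, `bwd`, `w` and `y0` from the example above):

```python
# First calls trigger compilation; block_until_ready waits for the result.
jax.block_until_ready(fwd(w, y0))
jax.block_until_ready(bwd(w, y0))

# Subsequent timings measure execution only, not compilation.
%timeit jax.block_until_ready(fwd(w, y0))
%timeit jax.block_until_ready(bwd(w, y0))
```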
Wow, thanks a lot!

I believe it is compiled in the first `%time` call then?
You're welcome!
Yes, exactly.
Bear in mind that in this example it's a scalar->scalar function, and for these it's totally expected for forward mode to be optimal. This is why I wrote 'usually more efficient to use reverse-mode' above, rather than always. The fact that this isn't a super frequent use case for BFGS -- there are many specifically scalar->scalar optimizers that are more common in that setting, and arguably our best possible improvement here is actually to add more of them by default! -- is why this hasn't been super important before :)
@vadmbertr, can you try https://github.com/johannahaffner/optimistix/tree/forward-fix on your real problem? This should now work; all you need to do is pass `options={"jac": "fwd"}`.

If it does not really help on real problems, then there is no need to add it, @patrick-kidger. And agreed for scalar->scalar solvers, in general.
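For reference, here is a hedged sketch of how that would plug into the earlier MWE, assuming the branch uses the same `options={"jac": "fwd"}` spelling as `least_squares` (the exact option name is an assumption on my part):

```python
# Assumes `least_square` (wrapping the diffrax solve with adjoint=ForwardMode())
# and the imports from the MWE earlier in the thread.
min_sol = optx.minimise(
    least_square,
    optx.BFGS(rtol=1e-8, atol=1e-8),
    jnp.asarray(1.),
    options={"jac": "fwd"},  # assumed switch asking BFGS for a forward-mode gradient
)
print(min_sol.value)
```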
Hi @johannahaffner, thanks for implementing this swiftly!
Good morning @vadmbertr, thanks for trying it out so quickly! I opened a PR to add this option for all minimisers.