Using optimistix to solve optimization problems in parallel? #106

Open
mjo22 opened this issue Jan 3, 2025 · 4 comments
Labels
question User queries

Comments

@mjo22

mjo22 commented Jan 3, 2025

Hello! I would like to integrate optimistix into my work, but I would appreciate some insight from others more experienced with the package, because I believe my problem will require use of a more advanced API. The defining feature of my problem is that I have many similar, independent optimization problems (~100) to solve in parallel. I currently address this with optax, vmapping over both my loss function evaluations and my parameter updates. Pseudocode for this is the following:

import optax
import equinox as eqx
import jax
from functools import partial

# Batched loss: vmap over the leading axis of the parameters, broadcast `args`.
@partial(eqx.filter_vmap, in_axes=(0, None))
@eqx.filter_value_and_grad
def compute_loss(pytree, args):
    ...

@eqx.filter_jit
def make_step(pytree, opt_state, args):
    # Per-batch-element losses and gradients.
    loss_value, grads = compute_loss(pytree, args)
    updates, opt_state = optim.update(grads, opt_state, pytree)
    # Apply each batch element's updates to its own parameters.
    pytree = jax.vmap(lambda d, u: eqx.apply_updates(d, u))(pytree, updates)
    return pytree, opt_state, loss_value

args = ...
initial_pytree = ...
optim = optax.adam(learning_rate=...)
opt_state = optim.init(initial_pytree)

pytree = initial_pytree
n_steps = 100
for _ in range(n_steps):
    pytree, opt_state, loss_value = make_step(pytree, opt_state, args)
    ...

...

I am wondering if a similar approach is possible in optimistix. From reading through the documentation, it seems to me that I could work directly with AbstractDescent and AbstractSearch objects to step through a similar loop, but perhaps it won't be possible to use the AbstractMinimiser API. I am also interested to hear any other recommendations you may have.

@patrick-kidger
Owner

I think you should be able to do this by using the existing Optimistix optimizers (no need to define a new Abstract{Search,Descent}!) and wrapping their calls in a jax.vmap.

If you want to do this in a step-by-step fashion like your Optax example, then see "interactively step through a solve" in the documentation. The more common Optimistix API is to use optx.minimise etc. directly. (And these can of course be vmap'd.)

In your example above, note that it might be simpler to apply a single vmap around the whole make_step function (rather than two separate vmaps around the loss and the apply_updates). Since optax.adam operates per-parameter, this will be equivalent and probably easier to reason about.
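As a minimal sketch of the vmapped optx.minimise approach (the quadratic loss, solver settings, and batch shapes below are made-up placeholders, not anything specific to this thread):

import jax
import jax.numpy as jnp
import optimistix as optx

# Toy per-problem loss; each batch element has its own target in `args`.
def loss(y, args):
    return jnp.sum((y - args) ** 2)

solver = optx.BFGS(rtol=1e-5, atol=1e-5)

def solve_one(y0, args):
    # One independent minimisation; `sol.value` is the argmin for this element.
    sol = optx.minimise(loss, solver, y0, args)
    return sol.value

y0_batch = jnp.zeros((100, 3))                              # 100 problems, 3 parameters each
args_batch = jnp.linspace(0.0, 1.0, 300).reshape(100, 3)    # per-problem targets

# vmap over the whole solve: each batch element runs its own optimisation.
solutions = jax.vmap(solve_one)(y0_batch, args_batch)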

@patrick-kidger patrick-kidger added the question User queries label Jan 3, 2025
@mjo22
Author

mjo22 commented Jan 4, 2025

Ah okay, thank you for the insight. I had been reading through this example of the advanced API, but I assumed there would be an issue with the fact that each optimization evaluates its own termination condition. Using the optimistix.minimise API would be ideal, but for that reason I did not realize it could be vmapped.

Imagining a modification of the "interactively step through a solve" example, I would think this requires a while loop that makes vmapped calls to something like "step" and "terminate", terminating once all problems have converged. Along this line of reasoning, some problems will already have converged in a given call to the vmapped "step" and "terminate". Out of curiosity, how will this behave under the hood?

I am also wondering what I should consider when evaluating the GPU performance of a vmap like this.

@patrick-kidger
Owner

The overall computational cost will be batch size × the greatest number of steps needed by any batch element. For example, if there is a batch of 2 elements, the first needs 15 steps and the second needs 18 steps, then the cost will be 2 × 18 = 36 steps.

The reason for this is that under the hood, batch-of-loop is transformed into loop-of-batch, with any already-terminated batch elements running an iteration whose output is then ignored. The loop terminates when all batch elements have terminated.

This design is partly driven by GPU concerns: it is a SIMD-friendly way of expressing batches of loops.

FWIW all of the above is just the normal behaviour of jax.vmap(jax.lax.while_loop), we're not really doing anything special here! :)
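As a toy illustration of that jax.vmap(jax.lax.while_loop) behaviour (the counting loop below is a made-up stand-in for an optimiser's step/terminate loop, not Optimistix code):

import jax
import jax.numpy as jnp
from jax import lax

def count_to(target):
    # Keep stepping until this element's own termination condition is met.
    return lax.while_loop(lambda i: i < target, lambda i: i + 1, jnp.array(0))

targets = jnp.array([15, 18])
# Under vmap, the loop runs until *every* element has terminated (18 iterations
# here). Already-finished elements still execute the body each iteration, but
# their results are discarded, so each element's final value remains correct.
print(jax.vmap(count_to)(targets))  # [15 18]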

@mjo22
Author

mjo22 commented Jan 6, 2025

Okay, this really clears things up! Thanks so much, this was puzzling me. I had not thought about it in the context of the general behavior of jax.vmap(jax.lax.while_loop).
