Optimization across multidimensional array #70

Open

arunoruto opened this issue Jul 30, 2024 · 4 comments

Labels
question User queries

Comments

@arunoruto

arunoruto commented Jul 30, 2024

I just recently came across JAX and I am now trying to use it for my implementation of the Hapke Anisotropic Multiple Scattering Approximation (AMSA) model. I opened a similar issue on jaxopt, but since that repository isn't going to be maintained much in the future, I gave Optimistix a try! It seems to be faster than jaxopt and a bit faster than scipy:

# optimistix:
Inverse AMSA: Mean +- std dev: 984 ms +- 30 ms
# scipy:
Inverse AMSA: Mean +- std dev: 1.17 sec +- 0.03 sec
# LM:
Inverse AMSA: Mean +- std dev: 42.7 sec +- 0.5 sec

I am using optx.least_squares with LM in a for-loop which iterates over all pixels and tries to find the best fit to the target function. The structure is similar to the code provided here.

I was wondering if the current implementation would somehow allow me to pass a multidimensional array, or even a matrix, and optimize along an axis. Is there a trick maybe to achieve what I want?

Also, would it be possible to provide a derivative function of the target function? I am still impressed I got such runtimes without providing it, but why derive it automatically if I can supply it :)

@arunoruto
Author

I came across a vectorized Levenberg-Marquardt implementation, which sounds like exactly what I need!

Sadly, I only have a little experience programming with JAX. I will try to naively implement the paper in Python first, and then port it to Optimistix.

tl;dr: the paper introduces an operator G and a factor mu, with the final LM problem becoming

    arg min_p ||f(p) - y||^2 + mu ||Gp||^2

Gp returns a concatenation of the spatial gradients of each parameter map.
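
In case it helps to make that concrete, here is a rough sketch of how I imagine the augmented residual might look for a least-squares solver (f stands in for the forward model and spatial_gradient is a placeholder for the paper's G operator, e.g. finite differences over each parameter map):

    import jax.numpy as jnp

    def regularized_residuals(p, args):
        # f: forward model, y: measured data, mu: regularization strength,
        # spatial_gradient: placeholder for the paper's G operator.
        f, y, mu, spatial_gradient = args
        data_residual = (f(p) - y).ravel()
        smoothness_residual = jnp.sqrt(mu) * spatial_gradient(p).ravel()
        # The squared norm of this stacked vector is
        # ||f(p) - y||^2 + mu ||Gp||^2, i.e. the objective above.
        return jnp.concatenate([data_residual, smoothness_residual])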

@patrick-kidger
Owner

JAX has a great tool for this called jax.vmap. First describe a single optimization problem you want to solve, and then vmap it over as many extra dimensions as you like. JAX and Optimistix will autovectorize the whole operation.
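
Roughly, the pattern looks like this (model, inits, and targets below are just placeholders for your single-pixel AMSA function and its data):

    import jax
    import optimistix as optx

    def residuals(params, target):
        # Residuals for a single pixel; `model` is a placeholder for the
        # single-pixel AMSA forward function.
        return model(params) - target

    def solve_one(init, target):
        solver = optx.LevenbergMarquardt(rtol=1e-6, atol=1e-6)
        sol = optx.least_squares(residuals, solver, init, args=target)
        return sol.value

    # `inits` and `targets` carry a leading pixel dimension; vmap vectorizes
    # the whole solve across it.
    solve_all = jax.vmap(solve_one)
    # fitted = solve_all(inits, targets)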

patrick-kidger added the question label on Jul 31, 2024
@arunoruto
Author

Thanks for the tip! I was reading about vmaps this morning, so I was able to sketch a small PoC and it seems to work! I am not sure if it is the right way to do it, so I wanted to ask if I am on the right track:

  • My current amsa function works for scalars, vectors, and images/matrices. Would you say it would be better to tailor it to scalars only and then apply vmap to obtain the vector and matrix versions? What is the JAX way of doing it?
  • My function accepts a lot of parameters:
  single_scattering_albedo: ArrayLike, # [a, b]
  incidence_direction: ArrayLike, # [a, b, 3]
  emission_direction: ArrayLike, # [a, b, 3]
  surface_orientation: ArrayLike, # [a, b, 3]
  phase_function: dict,
  b_n: ArrayLike,
  a_n: ArrayLike = np.nan,
  roughness: float = 0,
  hs: float = 0,
  bs0: float = 0,
  hc: float = 0,
  bc0: float = 0,
  refl_optimization: ArrayLike = 0.0, # [a, b]

I used vmap twice:

    # Batch over the first image axis: the first four inputs and
    # refl_optimization are mapped, the shared model parameters are not.
    amsa1 = jax.vmap(
        amsa,
        (0, 0, 0, 0, None, None, None, None, None, None, None, None, 0),
        0,
    )
    # Batch over the second image axis; together the two vmaps cover the
    # full [a, b] image.
    amsa2 = jax.vmap(
        amsa1,
        (1, 1, 1, 1, None, None, None, None, None, None, None, None, 1),
        1,
    )

    def amsa_optx_wrapper(x, args):
        return amsa2(x, *args)

Is this correct? Or is there a better way? This question is kind of related to the one above.

  • I am getting an out of memory error when using the code on GPU, which doesn't happen with scipy. Is vmap not batching correctly?

And by the way, have I missed something in the documentation regarding vmap and Optimistix? I couldn't find anything about automatic vectorization. Or is it something fairly common in JAX and therefore implicitly known?

@patrick-kidger
Owner

My current amsa function works for scalars, vectors and images/matrices. Would you say it would be better to cater it to scalars and then apply vmap to obtain the vector and matrix versions? What is the JAX way of doing it?

Yes, exactly! Usually better to let vmap handle the rank-polymorphism for you.
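
I.e. something like this, with brightness as a toy stand-in for a scalar-only amsa:

    import jax
    import jax.numpy as jnp

    # Toy scalar-only function standing in for a scalar amsa.
    def brightness(albedo, angle):
        return albedo * jnp.cos(angle)

    brightness_vec = jax.vmap(brightness)      # vectors
    brightness_img = jax.vmap(brightness_vec)  # matrices / images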

I used vmap twice:

This looks reasonable. FWIW you can simplify your handling of the vmaps a bit: pack your arguments into batched and nonbatched groups and you can just do jax.vmap(fn, in_axes=(0, None))((to_batch1, to_batch2, to_batch3, ...), (nobatch1, nobatch2, ...)). All that's going on there is that in_axes should be a pytree-prefix of the arguments.
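
For example (toy placeholders, not the real AMSA signature):

    import jax
    import jax.numpy as jnp

    # The first argument is a tuple of batched arrays (leading pixel axis),
    # the second a tuple of shared parameters.
    def fn(batched, nonbatched):
        x, y = batched
        scale, offset = nonbatched
        return scale * (x - y) + offset

    # in_axes is a pytree-prefix of the arguments: 0 maps over every leaf of
    # the first tuple, None leaves every leaf of the second alone.
    fn_v = jax.vmap(fn, in_axes=(0, None))

    out = fn_v((jnp.ones((5, 3)), jnp.zeros((5, 3))), (2.0, 1.0))  # shape (5, 3)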

You might also like equinox.filter_vmap if that's ever helpful.

I am getting an out of memory error when using the code on GPU, which doesn't happen with scipy.

If I had to guess, SciPy is probably doing its optimizer-internal operations on the CPU -- only evaluating your user-provided vector field on the GPU. So it's probably using less GPU memory anyway.

Or is it something fairly common with JAX and is implicitly known?

Yup, common with JAX!
