Optimization across multidimensional array #70
I came across a vectorized Levenberg-Marquardt implementation, which sounds like exactly what I need! Sadly, I have only a little experience programming with JAX. I will try to naively implement the paper in Python, and then try to port it to Optimistix. tl;dr: the paper introduces an operator […]
JAX has a great tool for this called `jax.vmap`.
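For context, a minimal sketch of what `jax.vmap` does (the function `f` here is just an illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Written for a single example of shape (3,).
    return jnp.sum(x ** 2)

# vmap lifts f over a leading batch axis without rewriting it.
batched_f = jax.vmap(f)
batched_f(jnp.ones((8, 3)))  # -> result of shape (8,)
```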
Thanks for the tip! I was reading about `jax.vmap` and tried to apply it to my function:
```python
def amsa(
    single_scattering_albedo: ArrayLike,  # [a, b]
    incidence_direction: ArrayLike,  # [a, b, 3]
    emission_direction: ArrayLike,  # [a, b, 3]
    surface_orientation: ArrayLike,  # [a, b, 3]
    phase_function: dict,
    b_n: ArrayLike,
    a_n: ArrayLike = np.nan,
    roughness: float = 0,
    hs: float = 0,
    bs0: float = 0,
    hc: float = 0,
    bc0: float = 0,
    refl_optimization: ArrayLike = 0.0,  # [a, b]
):
    ...
```

I used

```python
amsa1 = jax.vmap(
    amsa,
    (0, 0, 0, 0, None, None, None, None, None, None, None, None, 0),
    0,
)
amsa2 = jax.vmap(
    amsa1,
    (1, 1, 1, 1, None, None, None, None, None, None, None, None, 1),
    1,
)

def amsa_optx_wrapper(x, args):
    return amsa2(x, *args)
```

Is this correct? Or is there a better way? This question is somewhat related to the one above.
And btw, have I missed something in the documentation regarding […]?
Yes, exactly! Usually better to let vmap handle the rank-polymorphism for you.
This looks reasonable. FWIW you can simplify your handling of the vmaps a bit: pack your arguments into batched and non-batched groups, and then you can just do […]. You might also like […]
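A minimal sketch of that packing idea (assuming the `amsa` defined above; the names `amsa_packed`, `batched`, and `static` are illustrative, not from this thread):

```python
import jax

def amsa_packed(batched, static):
    # `batched` holds the per-pixel arrays, `static` the shared parameters.
    ssa, inc, emi, surf, refl = batched
    return amsa(ssa, inc, emi, surf, *static, refl_optimization=refl)

# in_axes=(0, None): map axis 0 of every array in `batched`, broadcast `static`.
amsa_vec = jax.vmap(amsa_packed, in_axes=(0, None))
# Nest once more for the second pixel axis, mirroring amsa1/amsa2 above.
amsa_grid = jax.vmap(amsa_vec, in_axes=(1, None), out_axes=1)
```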
If I had to guess, SciPy is probably doing its optimizer-internal operations on the CPU -- only evaluating your user-provided vector field on the GPU. So probably it's using less memory anyway.
Yup, common with JAX! […]
I just recently came across JAX and I am now trying to use it for my implementation of the Hapke Anisotropic Multiple Scattering Approximation model. I made a similar issue on jaxopt, but since that repository isn't going to be maintained much in the future, I gave Optimistix a try! It seems to be faster than jaxopt and a bit faster than SciPy.
I am using `optx.least_squares` with Levenberg-Marquardt in a for loop that iterates over all pixels and tries to find the best fit to the target function. The structure is similar to the code provided here.
I was wondering if the current implementation would somehow allow me to pass a multidimensional array, or even a matrix, and optimize along an axis. Is there maybe a trick to achieve what I want?
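For reference, a minimal sketch of what this can look like by `vmap`-ing the whole solve, as suggested above (the per-pixel model and the names `residuals`, `solve_one`, and `solve_all` are illustrative, not from this thread):

```python
import jax
import jax.numpy as jnp
import optimistix as optx

def residuals(y, args):
    # Made-up per-pixel model: fit an exponential decay to a measured curve.
    t, measured = args
    amplitude, decay = y
    return amplitude * jnp.exp(-decay * t) - measured

solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)

def solve_one(y0, measured, t):
    sol = optx.least_squares(residuals, solver, y0, args=(t, measured))
    return sol.value

# Map the whole least-squares solve over the leading pixel axis; t is shared.
# y0: [n_pixels, 2], measured: [n_pixels, n_samples], t: [n_samples]
solve_all = jax.vmap(solve_one, in_axes=(0, 0, None))
```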
Also, would it be possible to provide a derivative function for the target function? I am still impressed I got such runtimes without providing one, but why have it derived if I can provide it :)
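For context on those runtimes: JAX computes exact Jacobians by automatic differentiation, which is why no hand-written derivative is needed. A minimal sketch, with an illustrative residual function:

```python
import jax
import jax.numpy as jnp

def residuals(y, args):
    t, measured = args
    return y[0] * jnp.exp(-y[1] * t) - measured

# The exact Jacobian d(residuals)/dy via forward-mode autodiff.
jac = jax.jacfwd(residuals)
t = jnp.linspace(0.0, 1.0, 5)
jac(jnp.array([1.0, 2.0]), (t, jnp.zeros(5)))  # -> shape (5, 2)
```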