
backtracking linesearch huge slowdown -- recompiling functions every run of update_fn? #1171

Open
bafflingbits opened this issue Jan 10, 2025 · 12 comments

@bafflingbits
Contributor

bafflingbits commented Jan 10, 2025

When I try to optimize a large function, scale_by_backtracking_linesearch slows to a crawl, as if it were recompiling my function every single time.

I eventually tracked it down to this:

_src/linesearch.py : scale_by_backtracking_linesearch : update_fn

    search_state = jax.lax.while_loop(cond_fn, body_fn, search_state)

If I change it to:

    while cond_fn(search_state):
        search_state = body_fn(search_state)

I go from minutes per step to many steps per second. (I am just running on a CPU, nothing fancy here.)
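For reference, `jax.lax.while_loop(cond_fn, body_fn, init_val)` is documented as meaning the same thing as the plain Python loop above; the difference is that the lax version traces `cond_fn`/`body_fn` into one compiled loop, while the Python loop dispatches each iteration eagerly. A minimal sketch with hypothetical toy functions (not optax's real `cond_fn`/`body_fn`) that halve a step size until it drops below a threshold:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the line search's cond_fn/body_fn.
def cond_fn(state):
    step, _ = state
    return step > 0.1


def body_fn(state):
    step, count = state
    return step * 0.5, count + 1


init = (jnp.asarray(1.0, dtype=jnp.float32), jnp.asarray(0, dtype=jnp.int32))

# Staged version: JAX traces cond_fn/body_fn and compiles a single loop.
staged = jax.lax.while_loop(cond_fn, body_fn, init)

# Eager version: ordinary Python control flow, one dispatch per iteration.
state = init
while cond_fn(state):
    state = body_fn(state)
eager = state

print(staged, eager)
```

Both forms reach the same state; only the compilation behavior differs, which is consistent with the speedup reported here.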

I'm not well versed in JAX, so I'm not actually sure what is going on. However, my value_fn takes a long time to jit (minutes), and the docs say jax.lax.while_loop compiles the functions it is given. So my guess is that compiling body_fn (which doesn't take value_fn as an argument, but gets it from the enclosing update_fn scope) somehow causes value_fn to be re-analyzed / recompiled each time update is called?

To be honest, this structure looks a bit strange in the first place: cond_fn and body_fn are defined locally in update_fn, so even if this weren't hitting value_fn, it still seems odd to redefine and recompile cond_fn and body_fn on every call. It looks like it was done this way mostly to capture values from the surrounding scope instead of passing them as function parameters. Again, I'm not well versed in JAX, so maybe I'm misunderstanding how things are scheduled and passed around, but rearranging the code may alleviate that as well.

@rdyro
Collaborator

rdyro commented Jan 10, 2025

while_loop should definitely not recompile the functions every time.

Can you try the context manager: https://jax.readthedocs.io/en/latest/_autosummary/jax.log_compiles.html to see if you're hitting recompilation many times?
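A minimal sketch of using that context manager around a toy jitted function (`f` is hypothetical). JAX emits a log line each time it compiles, so repeated compile lines for the same shapes and dtypes point to a cache miss:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sum(x ** 2)


# Inside this block, JAX logs a message every time it compiles something.
# The global equivalent is jax.config.update("jax_log_compiles", True).
with jax.log_compiles(True):
    a = f(jnp.ones(3))  # first call for shape (3,): compiles
    b = f(jnp.ones(3))  # same shape and dtype: cache hit, no compile log
    c = f(jnp.ones(4))  # new shape: compiles again
```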

@rdyro rdyro self-assigned this Jan 10, 2025
@bafflingbits
Contributor Author

Okay, I added jax.config.update('jax_log_compiles', True) after importing jax.

I see a lot of logs during startup, and then the following when I compile my function.

WARNING:2025-01-10 02:53:58,889:jax._src.interpreters.pxla:1906: Compiling _f with global shapes and types [ShapedArray(float64[120])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-01-10 02:55:28,462:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(_f) in 89.554991961 sec

The first time optimizer.update is called, I see:

...
WARNING:2025-01-10 03:02:03,774:jax._src.dispatch:182: Finished tracing + transforming while for pjit in 0.002302170 sec
WARNING:2025-01-10 03:02:03,952:jax._src.interpreters.pxla:1906: Compiling while with global shapes and types [ShapedArray(float64[120]), ShapedArray(float64[120]), ShapedArray(float64[]), ShapedArray(float64[]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[]), ShapedArray(float64[120]), ShapedArray(float64[], weak_type=True), ShapedArray(int64[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
WARNING:2025-01-10 03:03:31,794:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(while) in 87.841812372 sec
WARNING:2025-01-10 03:04:11,102:jax._src.dispatch:182: Finished XLA compilation of jit(while) in 39.301886797 sec

The second time optimizer.update is called, I see it compile some stuff again (the only slow one, though, is the while):

...
WARNING:2025-01-10 03:04:11,243:jax._src.dispatch:182: Finished tracing + transforming while for pjit in 0.000306845 sec
WARNING:2025-01-10 03:04:11,244:jax._src.interpreters.pxla:1906: Compiling while with global shapes and types [ShapedArray(float64[120]), ShapedArray(float64[120]), ShapedArray(float64[]), ShapedArray(float64[]), ShapedArray(float64[], weak_type=True), ShapedArray(float64[]), ShapedArray(float64[120]), ShapedArray(float64[], weak_type=True), ShapedArray(int64[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
WARNING:2025-01-10 03:05:39,298:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(while) in 88.052956343 sec
WARNING:2025-01-10 03:06:17,114:jax._src.dispatch:182: Finished XLA compilation of jit(while) in 37.810113668 sec

I let it run some more steps, and it keeps recompiling, taking ~90 sec per step.
I don't see anything about it recompiling my value function, but it must be related somehow, because the time drops to almost zero when I use a simple quadratic as my value function.
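One possible explanation (an assumption, not confirmed in this thread): when `cond_fn`/`body_fn` are traced, every operation of `value_fn` is embedded in the `while` loop's jaxpr, even if `value_fn` is itself jitted and never re-traced in Python. Lowering `while` to MLIR would then mean lowering the whole value function again, so its cost grows with the size of `value_fn`. This sketch uses a hypothetical `value_fn` and `linesearch` to show that the `sin` from `value_fn` appears inside the loop's jaxpr:

```python
import jax
import jax.numpy as jnp


@jax.jit
def value_fn(x):
    # Hypothetical stand-in for an expensive, already-jitted objective.
    return jnp.sum(jnp.sin(x) ** 2)


def linesearch(x, direction):
    # cond_fn closes over value_fn, similar to the local functions in update_fn.
    def cond_fn(state):
        step, _ = state
        return value_fn(x + step * direction) > value_fn(x)

    def body_fn(state):
        step, num_steps = state
        return step * 0.5, num_steps + 1

    init = (jnp.float32(1.0), jnp.int32(0))
    return jax.lax.while_loop(cond_fn, body_fn, init)


x, direction = jnp.ones(3), -jnp.ones(3)
jaxpr_text = str(jax.make_jaxpr(linesearch)(x, direction))
# The while loop's cond jaxpr contains value_fn's operations, including `sin`.
print(jaxpr_text)
step, num_steps = linesearch(x, direction)
```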

@rdyro
Copy link
Collaborator

rdyro commented Jan 10, 2025

Ok, it looks like you correctly identified that it's really recompiling a function the second time (jax.log_compiles is a good way to check)! Are you sure the function being recompiled isn't the value function? If you add a regular Python print statement to your value function it should only print if it's compiling it.
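A minimal sketch of that check, recording traces in a list as well as printing them so the count can be inspected (`value_fn` here is a hypothetical toy). Python side effects run only while JAX traces the function; cached calls skip them:

```python
import jax
import jax.numpy as jnp

traces = []  # one entry per trace of value_fn


@jax.jit
def value_fn(x):
    # This Python code runs only while JAX traces the function,
    # not when the cached compiled version is executed.
    traces.append((x.shape, x.dtype))
    print("tracing value_fn", x.shape, x.dtype)
    return jnp.sum(x ** 2)


value_fn(jnp.ones(3))   # first call: traces, prints, compiles
value_fn(jnp.zeros(3))  # same shape and dtype: cache hit, no print
after_same_shape = len(traces)
value_fn(jnp.ones(5))   # new shape: traces and prints again
after_new_shape = len(traces)
```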

There's something causing the compilation cache to be invalidated (e.g., a static argument), but from the log you attached the difference doesn't seem to be in the argument shapes and types.

Can you produce a minimal repro?

@bafflingbits
Contributor Author

I added the print like you suggested, and it shows up twice (apparently the function is rerun when analyzing it for the gradient, which makes sense), but the print doesn't show up when calling update.

It's not really minimal, but here is a stripped-down version (optimizing a small value function, which still shows the while being recompiled). Changing the "if 1" to "if 0" in the setup switches to my huge value function instead (isolated, but unminimized, in another file).
example_issue.zip


Side note: GitHub would not allow me to attach a .tgz, with the error "We don’t support that file type. Try again with GIF, ... , TGZ, ...". Weird.

@bafflingbits
Contributor Author

The 'while' being recompiled makes sense to me, since cond_fn and body_fn are defined locally in update_fn. I don't know Python internals well, but my mental model is that Python recreates those functions each time 'update' is called, so when running the while, JAX sees new functions and compiles them.

However, I have no idea why the details of the value function matter here.
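The plain-Python part of that mental model is right: a nested `def` builds a new function object on every call of the enclosing function, and distinct function objects never compare equal, so any cache keyed on the function object misses every time. `functools.partial` objects behave the same way. In this sketch, `functools.lru_cache` stands in for such a cache; `compile_once`, `update`, and `body_fn_global` are hypothetical names, and whether JAX keys its own caches exactly like this is an assumption here:

```python
import functools

cache_misses = []  # records every function object the "cache" has not seen


@functools.lru_cache(maxsize=None)
def compile_once(fn):
    # Hypothetical stand-in for a compilation cache keyed on the function object.
    cache_misses.append(fn)
    return fn


def update(step):
    # Like update_fn: body_fn is a brand-new function object on every call.
    def body_fn(x):
        return x * step
    return compile_once(body_fn)


def body_fn_global(x):
    # Defined once at module level: the same object every time.
    return x * 0.5


update(0.5)
update(0.5)
local_misses = len(cache_misses)  # both calls missed
compile_once(body_fn_global)
compile_once(body_fn_global)
global_misses = len(cache_misses) - local_misses  # only the first call missed

# functools.partial objects also compare by identity, not by contents.
partials_equal = (functools.partial(body_fn_global)
                  == functools.partial(body_fn_global))
```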

@bafflingbits
Contributor Author

bafflingbits commented Jan 10, 2025

Hmm, maybe this in body_fn somehow causes reanalysis of the value function?

value_fn_ = functools.partial(value_fn, **fn_kwargs)

where fn_kwargs is not obtained through any function parameter, but from body_fn's enclosing scope (update_fn's local scope).

Or maybe it is fn_kwargs itself?
I'll try running with that commented out and changing the calls to use value_fn directly (this should be okay, as my value function does not take kwargs).

@bafflingbits
Contributor Author

bafflingbits commented Jan 10, 2025

EDIT(3): ---- deleted ---- a bad test led to nonsense. Correct results are in the next post.

@bafflingbits
Contributor Author

The fn_kwargs change unfortunately did nothing. Compiling 'while' still takes a long time with a large value fn, and 'while' gets recompiled on every update.

@rdyro
Collaborator

rdyro commented Jan 10, 2025

Can you try using a jitted version of opt.update:

opt_update = jax.jit(opt.update, static_argnames=("value_fn",))
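A sketch of why this helps, with a hypothetical `update` standing in for `opt.update`: once the whole update is under `jax.jit`, the compiled executable is cached under the jitted function, the argument shapes/dtypes, and the static `value_fn`, so the local functions inside are only traced on a cache miss. Static arguments must be hashable; a plain Python function hashes by identity, so the same `value_fn` object should be passed on every call:

```python
import jax
import jax.numpy as jnp

traces = []  # one entry per trace of update


def update(params, *, value_fn):
    # Hypothetical stand-in for opt.update: one gradient step on value_fn.
    traces.append(params.shape)
    return params - 0.1 * jax.grad(value_fn)(params)


def quadratic(x):
    return jnp.sum(x ** 2)


# value_fn is static: it becomes part of the cache key instead of being traced.
update_jit = jax.jit(update, static_argnames=("value_fn",))

params = jnp.ones(3)
for _ in range(3):
    params = update_jit(params, value_fn=quadratic)
```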

@bafflingbits
Contributor Author

bafflingbits commented Jan 10, 2025

It worked, but compiled twice?

...
WARNING:2025-01-10 22:20:00,077:jax._src.interpreters.pxla:1906: Compiling update_fn with global shapes and types [ShapedArray(float32[120]), ShapedArray(int32[]), ShapedArray(float32[120]), ShapedArray(float32[120]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[120]), ShapedArray(float32[120]), ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
WARNING:2025-01-10 22:21:32,649:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(update_fn) in 92.570851088 sec
WARNING:2025-01-10 22:22:12,760:jax._src.dispatch:182: Finished XLA compilation of jit(update_fn) in 40.100338459 sec
WARNING:2025-01-10 22:22:12,775:jax._src.interpreters.pxla:1906: Compiling add with global shapes and types [ShapedArray(float32[120]), ShapedArray(float32[120])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-01-10 22:22:12,800:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(add) in 0.023004532 sec
WARNING:2025-01-10 22:22:12,817:jax._src.dispatch:182: Finished XLA compilation of jit(add) in 0.016420603 sec
step 0: val=5.642762e+03 gnorm=3.618964e+03
WARNING:2025-01-10 22:22:12,837:jax._src.dispatch:182: Finished tracing + transforming update_fn for pjit in 0.016927004 sec
WARNING:2025-01-10 22:22:12,838:jax._src.interpreters.pxla:1906: Compiling update_fn with global shapes and types [ShapedArray(float32[120]), ShapedArray(int32[]), ShapedArray(float32[120]), ShapedArray(float32[120]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[120]), ShapedArray(float32[120]), ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
WARNING:2025-01-10 22:23:45,178:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(update_fn) in 92.338912249 sec
WARNING:2025-01-10 22:24:23,111:jax._src.dispatch:182: Finished XLA compilation of jit(update_fn) in 37.927310705 sec
step 1: val=1.699201e+03 gnorm=1.273662e+03
step 2: val=1.594805e+03 gnorm=1.411623e+03
step 3: val=1.393530e+03 gnorm=1.385978e+03
step 4: val=6.974808e+02 gnorm=7.079473e+02
step 5: val=6.971791e+02 gnorm=6.166702e+02
step 6: val=6.945733e+02 gnorm=6.229464e+02
step 7: val=5.359800e+02 gnorm=5.127087e+02
step 8: val=3.473148e+02 gnorm=3.616823e+02
step 9: val=3.242927e+02 gnorm=3.591946e+02

@rdyro
Collaborator

rdyro commented Jan 10, 2025

That's awesome! I suppose JAX couldn't hash the local function scale_by_backtracking_linesearch was defining. Defining it locally makes for more readable code, but it relies on JAX caching, which in this case didn't get a hit.

> It worked, but compiled twice?

Perhaps the state after init is not the same as the state after step 0 (but the same as after steps 1, 2, 3...)

@bafflingbits
Contributor Author

bafflingbits commented Jan 11, 2025

> Perhaps the state after init is not the same as the state after step 0 (but the same as after steps 1, 2, 3...)

The only difference I could see when printing the state from init and the state from update was some weak_type differences. These changes to init_fn make the types match those produced by update_fn and avoid triggering a recompilation.

diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py
index a28a65a..4d7b012 100644
--- a/optax/_src/linesearch.py
+++ b/optax/_src/linesearch.py
@@ -255,11 +255,11 @@ def scale_by_backtracking_linesearch(
       grad = None
     return ScaleByBacktrackingLinesearchState(
         learning_rate=jnp.array(1.0),
-        value=jnp.array(jnp.inf),
+        value=jnp.array(jnp.inf, dtype=params.dtype),
         grad=grad,
         info=BacktrackingLinesearchInfo(
             num_linesearch_steps=0,
-            decrease_error=jnp.array(jnp.inf),
+            decrease_error=jnp.array(jnp.inf, dtype=params.dtype),
         ),
     )
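A minimal sketch of the mismatch this diff removes, with a hypothetical jitted `step`: an array built from a Python float without a dtype is weakly typed, while one with an explicit dtype is not, and `jax.jit` treats the two as different input types even though shape and dtype match:

```python
import jax
import jax.numpy as jnp

traces = []  # one entry per trace of step


@jax.jit
def step(value):
    # Hypothetical jitted function that receives part of the optimizer state.
    traces.append(value.dtype)
    return value * 0.5


weak = jnp.array(1.0)                      # from a Python float: weakly typed
strong = jnp.array(1.0, dtype=weak.dtype)  # explicit dtype: not weakly typed

step(weak)
step(weak)    # identical input type: cache hit
step(strong)  # same shape and dtype, different weak_type: traced again
```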

As for the 'while', I was able to verify that the issue is the use of function-local functions.
I gave the 'while' some global functions that didn't do much, and it no longer caused recompilation. Then I changed the local cond_fn to do the same thing as the global function that worked (hopefully ruling out an issue with outer-scope variables), and once I passed that to 'while', it started recompiling on each call again. So some code rearrangement would be needed to fix the rest.
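One possible rearrangement, sketched under assumptions: `cond_fn` and `body_fn` are created once when the transformation is built, and everything that varies per call travels through the loop carry instead of the enclosing `update` scope. `make_backtracking_step`, its parameters, and the simplified sufficient-decrease test are hypothetical, not optax's API or its exact conditions:

```python
import jax
import jax.numpy as jnp


def make_backtracking_step(value_fn, slope_rtol=1e-4, decrease_factor=0.5,
                           max_steps=30):
    """Hypothetical sketch: builds cond_fn/body_fn once per transformation."""

    def cond_fn(carry):
        # Carry layout: (step, num_steps, params, direction, value, slope).
        step, num_steps, params, direction, value, slope = carry
        new_value = value_fn(params + step * direction)
        # Simplified sufficient-decrease (Armijo) test.
        sufficient = new_value <= value + slope_rtol * step * slope
        return jnp.logical_and(jnp.logical_not(sufficient), num_steps < max_steps)

    def body_fn(carry):
        step, num_steps, *rest = carry
        # Shrink the step; everything else is carried through unchanged.
        return (step * decrease_factor, num_steps + 1, *rest)

    def update(params, grad, value):
        direction = -grad
        slope = jnp.vdot(grad, direction)  # directional derivative (<= 0)
        init = (jnp.asarray(1.0, dtype=params.dtype),
                jnp.asarray(0, dtype=jnp.int32),
                params, direction, value, slope)
        step, num_steps, *_ = jax.lax.while_loop(cond_fn, body_fn, init)
        return step * direction, num_steps

    return update


def quadratic(x):
    return jnp.sum(x ** 2)


update = make_backtracking_step(quadratic)
params = jnp.array([1.0, 2.0])
updates, num_steps = update(params, jax.grad(quadratic)(params), quadratic(params))
```

Because `update` passes the same `cond_fn`/`body_fn` objects to `jax.lax.while_loop` on every call, this mirrors the setup that stopped recompiling in the experiment above; jitting the whole update, as suggested earlier in the thread, is still the simplest way to guarantee a single compilation.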

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants