backtracking linesearch huge slowdown: recompiling functions on every run of update_fn?
#1171
Comments
while_loop should definitely not recompile the functions every time. Can you try the context manager https://jax.readthedocs.io/en/latest/_autosummary/jax.log_compiles.html to see whether you're hitting recompilation many times?
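For reference, here is a minimal sketch of the suggested check. It assumes a recent JAX version where jax.log_compiles works as a context manager. The objective value_fn is a hypothetical stand-in, not the reporter's function. Each compilation inside the with block is logged, so a function that recompiles on every optimizer step shows up as a repeated message.

```python
import jax
import jax.numpy as jnp


def value_fn(x):
    # Hypothetical stand-in for an expensive objective.
    return jnp.sum(jnp.sin(x) ** 2)


value_fn_jit = jax.jit(value_fn)
x = jnp.linspace(0.0, 1.0, 8)

# While this block is active, JAX emits a log message (through Python's
# logging module) whenever it compiles a computation.
with jax.log_compiles(True):
    first = value_fn_jit(x)   # first call: compiled, so a message is logged
    second = value_fn_jit(x)  # same shapes/dtypes: cache hit, no new compile
```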
Okay, I added the context manager. I see a lot of logs during startup, and then the following when I compile my function:
The first time optimizer.update is called, I see:
The second time optimizer.update is called, I see it compile some stuff again (the only long one though is the
I let it run for more steps, and it keeps recompiling and taking ~90 s per step.
OK, it looks like you correctly identified that it really is recompiling a function the second time (jax.log_compiles is a good way to check)! Are you sure the function being recompiled isn't the value function? If you add a regular Python print statement to your value function, it should only print when the function is being compiled. Something is invalidating the compilation cache (e.g., a static argument), but from the log you attached, the difference doesn't seem to be in the argument shapes and types. Can you produce a minimal repro?
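Here is a small sketch of the print trick, using a hypothetical toy value_fn. Python side effects such as print run only while JAX traces the function. They fire on a cache miss and stay silent when the compiled version is reused. The trace_calls list is an illustrative addition that makes the count checkable.

```python
import jax
import jax.numpy as jnp

trace_calls = []  # records every time the Python body of value_fn runs


def value_fn(x):
    # Python side effects run only while JAX traces this function,
    # not on every execution of the compiled code.
    print("tracing value_fn")
    trace_calls.append(x.shape)
    return jnp.sum(x ** 2)


value_and_grad = jax.jit(jax.value_and_grad(value_fn))

x = jnp.ones(4)
v, g = value_and_grad(x)  # first call: traced, so the message is printed
traces_after_first_call = len(trace_calls)
for _ in range(3):
    v, g = value_and_grad(x)  # cache hits: nothing is printed
```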
I added the print as you suggested, and it shows up twice (apparently the function is rerun when analyzing it for the gradient, which makes sense), but the print doesn't show up when calling the update. It's not really minimal, but here is a stripped-down version (optimizing a small value function, which still lets you see the 'while' being recompiled). Changing the "if 1" to "if 0" in the setup will use my huge value function instead (isolated, but not minimized, in another file). Side note: GitHub would not let me attach a .tgz, failing with the error "We don’t support that file type. Try again with GIF, ... , TGZ, ...". Weird.
The 'while' being recompiled makes sense to me, since the … However, I have no idea why the details of the value function would matter here.
Hmm, maybe this in … where … Or maybe it is just …
EDIT (3): deleted. A bad test led to nonsense; the correct results are in the next post.
The …
Can you try using a jitted version of opt.update:
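Here is a sketch of what jitting the update can look like, assuming the update receives the objective as a Python callable. toy_update is a hypothetical stand-in that mimics the structure under discussion: cond_fn and body_fn are defined locally and close over value_fn. It is not optax's actual update_fn, and its simple halving rule is not optax's line-search condition. Marking value_fn as static lets jax.jit cache the compiled update per objective function, so value_fn is traced only on the first call.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def value_fn(x):
    global trace_count
    trace_count += 1  # Python side effect: counts traces, not executions
    return jnp.sum((x - 3.0) ** 2)


def toy_update(params, value_fn):
    """Hypothetical stand-in for a line-search update (not optax's code):
    halve the step size until the objective decreases, at most 10 times."""
    value, grad = jax.value_and_grad(value_fn)(params)

    # Locally defined loop functions that close over value_fn, params, grad.
    def cond_fn(carry):
        step, i = carry
        not_decreased = value_fn(params - step * grad) >= value
        return jnp.logical_and(not_decreased, i < 10)

    def body_fn(carry):
        step, i = carry
        return step * 0.5, i + 1

    step, _ = jax.lax.while_loop(cond_fn, body_fn, (jnp.float32(1.0), 0))
    return params - step * grad


# value_fn is a Python callable, so it is passed as a static argument; jit then
# caches the compiled update per (value_fn object, argument shapes/dtypes).
update_jit = jax.jit(toy_update, static_argnames="value_fn")

params = jnp.zeros(2)
params = update_jit(params, value_fn=value_fn)
traces_after_first_call = trace_count
for _ in range(3):
    params = update_jit(params, value_fn=value_fn)
```

Passing a different value_fn object, or new arguments with different shapes or dtypes, would still trigger a fresh trace and compile.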
It worked, but compiled twice?
That's awesome! I suppose JAX couldn't hash the local function
Perhaps the state after init is not the same as the state after step 0 (but is the same as the state after steps 1, 2, 3, ...).
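Here is one way this hypothesis could play out, sketched with a hypothetical state dictionary. jax.jit caches compiled code per input structure, shape, and dtype. Suppose init returns an int32 counter while the update returns a float one. The first step is then traced with one input signature and every later step with another, which gives exactly two compilations. This only illustrates the mechanism. It is not a diagnosis of the actual optax state.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(state):
    global trace_count
    trace_count += 1  # runs only when jit traces step (i.e. on a cache miss)
    return {"x": state["x"] * 0.5, "count": state["count"] + 1.0}


# Hypothetical init: the counter starts out as int32 ...
state = {"x": jnp.ones(3), "count": jnp.array(0, dtype=jnp.int32)}
# ... but step returns it as a float, so the input signature changes once.
for _ in range(4):
    state = step(state)
```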
The only difference I could see printing out state from init and state from update were some
As for the 'while', I was able to verify that the issue is the use of function-local functions.
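Here is a minimal sketch of the effect, using a hypothetical update whose cond_fn and body_fn are defined locally and close over value_fn. When called eagerly, each call creates new function objects, so lax.while_loop traces them again every time, and value_fn along with them. Wrapping the same update in jax.jit traces it once and reuses the result.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def value_fn(x):
    global trace_count
    trace_count += 1  # incremented only when JAX traces value_fn
    return jnp.sum(x ** 2)


def update(x):
    # cond_fn and body_fn are created anew on every call to update and close
    # over value_fn from the enclosing scope.
    def cond_fn(carry):
        i, _ = carry
        return i < 3

    def body_fn(carry):
        i, y = carry
        return i + 1, y - 0.1 * value_fn(y)

    _, y = jax.lax.while_loop(cond_fn, body_fn, (0, x))
    return y


x = jnp.ones(2)

# Eager calls: every call builds new closures, so value_fn is traced again.
update(x)
eager_after_one = trace_count
update(x)
update(x)
eager_after_three = trace_count

# Jitted calls: update (and value_fn inside it) is traced once, then cached.
update_jit = jax.jit(update)
update_jit(x)
jit_after_one = trace_count
update_jit(x)
update_jit(x)
```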
When I try to optimize a large function, scale_by_backtracking_linesearch slows to a crawl, as if it were recompiling my function every single time. I eventually tracked it down to this:
_src/linesearch.py : scale_by_backtracking_linesearch : update_fn
If I change it to:
I go from minutes per step to many steps per second. (I am just running on a CPU, nothing fancy here.)
I'm not well versed in the JAX stuff, so I'm not actually sure what is going on. However, my value_fn takes a long time to jit (minutes), and the docs say jax.lax.while_loop will compile the functions. So, as a guess, it appears that compiling body_fn (which doesn't take value_fn as an argument, but gets it from the surrounding update_fn scope) somehow leads to value_fn being re-analyzed or recompiled each time update is called?

To be honest, this structure looks a bit strange in the first place. cond_fn and body_fn are defined locally in update_fn, so even if this weren't hitting value_fn, it still seems odd to redefine and recompile cond_fn and body_fn each time. It looks like this was done mostly to get values from the surrounding scope instead of passing them as function parameters. Again, I'm not well versed in the JAX stuff, so maybe I'm misunderstanding how things are scheduled and passed around, but rearranging the code may alleviate that as well.
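The same caching behavior applies to jax.jit itself. Compiled code is cached per function object. A jitted function that is redefined on every call is therefore traced again each time, while one defined once and reused is traced only once. Here is a minimal sketch, with hypothetical make_step and step functions:

```python
import jax
import jax.numpy as jnp

trace_count = 0


def make_step():
    # A new function object (with its own jit cache) is created on every call.
    @jax.jit
    def step(x):
        global trace_count
        trace_count += 1  # runs only when JAX traces step
        return x * 0.5

    return step


x = jnp.ones(3)
for _ in range(3):
    make_step()(x)  # new function each time: traced again on each call
redefined_traces = trace_count

step = make_step()  # define once, reuse
for _ in range(3):
    step(x)
reused_traces = trace_count - redefined_traces
```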