Cannot turn off rematerialization for an individual lang op with hand specified grad transform #1658
Comments
I'm a bit confused here; the randomness is not actually recomputed? That said, let's disable the auto-recomputation of intermediates for now if it causes more problems than it solves.
Changing grad definitions has little effect on the generated computation function because what to compute and how to compute it are separate in Thunder. Grad definitions specify only the "what" part; many trace execution transformations define the "how" part. The recomputation/rematerialization is part of the "how" transformations. Setting `enable_saved_for_backward_recomputation=False` is the first part. The second part to not recomputing dropout is disabling (or making it an option) the "replace_uniform" transformation, as you provide in the patch above. With these two things, the dropout mask in the forward trace can be saved and reused in the backward pass instead of being recomputed.

Returning to the request in the issue title, it's complicated and more involved to do what you're asking for. Dropout is a composite operation in Thunder represented by its decomposition into primitives, and by the time the fusion rematerialization pass gets the traces, dropout is gone and only the prims decomposition is available, with no reference to non-decomposed ops.
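A minimal sketch of the first of those two parts, assuming the `enable_saved_for_backward_recomputation` flag on `thunder.jit` mentioned later in this issue; the model and input here are placeholders:

```python
# Sketch only: turn off the automatic recomputation of saved-for-backward
# intermediates. Disabling the "replace_uniform" transformation still requires
# the separate patch mentioned above and is not shown here.
import torch
import thunder

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.5))
jmodel = thunder.jit(model, enable_saved_for_backward_recomputation=False)

out = jmodel(torch.randn(4, 16))
out.sum().backward()
```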
I think there are a few things we should follow up on from @kevinstephano's example above:

We should probably discuss more over a VC.
🐛 Bug

I am trying to not recompute `torch.nn.functional.dropout` because recomputing the random numbers can be expensive relative to the fusion it appears in. Therefore, I am trying to selectively not rematerialize the computation for dropout, as well as make a custom grad function such that only a byte mask is saved from forward to backward. Without the custom grad function, the autograd transform will save the random `float` numbers instead of the `byte` mask.
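For illustration only, here is the memory-saving idea sketched in plain PyTorch rather than Thunder's grad-registration mechanism: the forward saves a boolean (byte-sized) mask and the backward reuses it, so no random `float` tensor needs to be kept or recomputed.

```python
import torch

class MaskSavingDropout(torch.autograd.Function):
    """Dropout that saves only a bool mask (1 byte per element) for backward."""

    @staticmethod
    def forward(ctx, x, p):
        mask = torch.rand_like(x) >= p          # bool mask instead of float numbers
        ctx.save_for_backward(mask)
        ctx.scale = 1.0 / (1.0 - p)
        return x * mask * ctx.scale

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        return grad_out * mask * ctx.scale, None  # no grad w.r.t. p


x = torch.randn(4, 16, requires_grad=True)
y = MaskSavingDropout.apply(x, 0.5)
y.sum().backward()
```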
To Reproduce
Patch to apply:
If I look at the result, I see that the rematerialization happens even in the presence of a `register_grad` function:

If I add `enable_saved_for_backward_recomputation=False` to my `thunder.jit` call, then rematerialization does not take over:

I would like to be able to selectively tag `dropout` to say "just don't rematerialize this op". However, adding the tag does not work.

cc @t-vi @riccardofelluga
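As a hedged sketch of how the result can be checked, the traces can be printed after a forward/backward run; `thunder.last_traces` and `thunder.last_backward_traces` are assumed to be available here, and exact names may differ between Thunder versions:

```python
# Run the jitted module once, then print the final forward and backward traces
# to see whether the uniform/dropout computation reappears in the backward
# trace (i.e. whether the randomness is being recomputed).
import torch
import thunder

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.5))
jmodel = thunder.jit(model, enable_saved_for_backward_recomputation=False)

y = jmodel(torch.randn(4, 16))
y.sum().backward()

print(thunder.last_traces(jmodel)[-1])           # final forward trace
print(thunder.last_backward_traces(jmodel)[-1])  # final backward trace
```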