[core] Fix `use_reentrant` issues #1036
Conversation
Is this ready for review?
Almost, I'll update the PR in a bit with more description.
The documentation is not available anymore as the PR was closed or merged.
With huggingface/transformers#27020 being merged, this PR is ready for review!
Thanks for making the big fix in transformers and updating PEFT to enable it. I have only a few comments, please check them out.
src/peft/utils/other.py (outdated)
        If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
        Keyword arguments to pass to the gradient checkpointing function, e.g. `use_reentrant=True`. Note this is
        only available in the latest transformers versions.
It would be better to specify the exact transformers version, because "latest" is relative.
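(Side note on the versioning question: support can also be detected at runtime by inspecting the method signature rather than relying on a version string. A minimal sketch, assuming a transformers model object; `enable_gc` is a hypothetical helper name, not code from this PR.)

```python
import inspect

def enable_gc(model, gradient_checkpointing_kwargs=None):
    # Older transformers releases do not accept gradient_checkpointing_kwargs,
    # so detect support from the method signature instead of a version number.
    supports_gc_kwargs = "gradient_checkpointing_kwargs" in inspect.signature(
        model.gradient_checkpointing_enable
    ).parameters

    if supports_gc_kwargs and gradient_checkpointing_kwargs is not None:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )
    else:
        model.gradient_checkpointing_enable()
```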
src/peft/utils/other.py (outdated)
    use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
        If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
        Keyword arguments to pass to the gradient checkpointing function, e.g. `use_reentrant=True`. Note this is
Could you explain what `use_reentrant` does?
I referred users to the PyTorch documentation; let me know if you want me to go into more detail.
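(For readers following along, a minimal PyTorch-level sketch of what the flag selects; this is not code from this PR, and the checkpointed function and tensor shapes are illustrative.)

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Some activation-heavy computation whose intermediates we would rather recompute.
    return torch.nn.functional.gelu(x @ x.T)

x = torch.randn(64, 64, requires_grad=True)

# Legacy, reentrant implementation: re-runs the forward in a separate autograd pass.
y_reentrant = checkpoint(block, x, use_reentrant=True)

# Non-reentrant implementation based on saved-tensor hooks; it generally interacts
# better with DDP and with inputs that do not require gradients.
y_non_reentrant = checkpoint(block, x, use_reentrant=False)
```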
Thanks for adjusting the code. It now looks good to me.
Thanks, I'll merge after finishing some experiments with huggingface/transformers#27068 and huggingface/trl#912.
Hello Younes, thank you for fixing the gradient checkpointing related issues as per our discussions. Just a nit: here you meant
Ah yes, correct @pacman100!
Partially fixes: huggingface/trl#835
This PR depends on huggingface/transformers#27020. With that PR, we introduce a new argument in the `gradient_checkpointing_enable()` API so that users can pass `gradient_checkpointing_kwargs`. To fix some issues with DDP and gradient checkpointing, it is recommended to use `use_reentrant=True`, which is the fix for huggingface/trl#835. Therefore I propose to expose an optional `gradient_checkpointing_kwargs` argument in `prepare_model_for_kbit_training`.
cc @BenjaminBossan @pacman100
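A rough usage sketch of the proposed argument (the model id and the kwarg values below are illustrative, not taken from this PR):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

# Illustrative 4-bit model; any quantized causal LM would be prepared the same way.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# gradient_checkpointing_kwargs is forwarded to gradient_checkpointing_enable(),
# so this requires a transformers version that includes huggingface/transformers#27020.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
)
```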