Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix use_reentrant issues #1036

Merged
merged 8 commits into from
Oct 31, 2023

Conversation

younesbelkada
Copy link
Contributor

@younesbelkada younesbelkada commented Oct 18, 2023

Partially fixes: huggingface/trl#835

This PR depends on huggingface/transformers#27020 - with that PR, we introduce a new argument in the gradient_checkpointing_enable() API in order for users to pass gradient_checkpointing_kwargs. To fix some issues with DDP and gradient_checkpointing, it is recommended to use use_reentrant=True which is the fix for huggingface/trl#835

Therefore I propose to expose an optional argument gradient_checkpointing_kwargs in prepare_model_for_kbit_training

cc @BenjaminBossan @pacman100

@BenjaminBossan
Copy link
Member

Is this ready for review?

@younesbelkada
Copy link
Contributor Author

Almost, I'll update the PR in a bit with more description

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada
Copy link
Contributor Author

huggingface/transformers#27020 being merged, this PR is ready for review!

@younesbelkada younesbelkada marked this pull request as ready for review October 25, 2023 10:20
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the big fix in transformers and updating PEFT to enable it. I have only a few comments, please check them out.

If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
Keyword arguments to pass to the gradient checkpointing function, e.g. `use_reentrant=True`. Note this is
only available in the latest transformers versions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to specify the exact transformers version, because "latest" is relative.

use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
Keyword arguments to pass to the gradient checkpointing function, e.g. `use_reentrant=True`. Note this is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what use_reentrant does?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I referred users to check the pytorch documentation, let me know if you want me to detail more

src/peft/utils/other.py Show resolved Hide resolved
src/peft/utils/other.py Show resolved Hide resolved
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adjusting the code. It now looks good to me.

@younesbelkada
Copy link
Contributor Author

Thanks, I'll merge after finishing some experiments with huggingface/transformers#27068 and huggingface/trl#912

@younesbelkada younesbelkada merged commit bdeb06b into huggingface:main Oct 31, 2023
@younesbelkada younesbelkada deleted the add-usereentrant branch October 31, 2023 15:51
@pacman100
Copy link
Contributor

To fix some issues with DDP and gradient_checkpointing, it is recommended to use use_reentrant=True which is the fix for huggingface/trl#835

Hello Younes, thank you for fixing the gradient checkpointing related issues as per our discussions.

Just a nit, here you meant use_reentrant=False, right?

@younesbelkada
Copy link
Contributor Author

AH yes correct @pacman100 !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[RewardTrainer] Enable gradient checkpointing for all multi-GPU training modes
4 participants