Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix use_reentrant issues #1036

Merged
merged 8 commits into from
Oct 31, 2023
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,29 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
return tuple(result)


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
r"""
Note this method only works for `transformers` models.

This method wraps the entire protocol for preparing a model before running a training. This includes:
1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
head to fp32

Args:
model, (`transformers.PreTrainedModel`):
model (`transformers.PreTrainedModel`):
The loaded model from `transformers`
use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of
`torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method.
Note this is only available in the latest transformers versions (> 4.34.1).
"""
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}

for name, param in model.named_parameters():
# freeze base model's layers
param.requires_grad = False
Expand All @@ -86,19 +97,36 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
param.data = param.data.to(torch.float32)

if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
# For backward compatibility
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
# When having `use_reentrant` + gradient_checkpointing, there is no need for this hack
if "use_reentrant" not in gradient_checkpointing_kwargs or not gradient_checkpointing_kwargs["use_reentrant"]:
# For backward compatibility
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:

def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

# enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
# To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(model.gradient_checkpointing_enable).parameters
)

if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
warnings.warn(
"gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
" if you want to use that feature, please upgrade to the latest version of transformers.",
FutureWarning,
)

gc_enable_kwargs = (
{} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
)

# enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable(**gc_enable_kwargs)
return model


Expand Down