[core] Fix use_reentrant issues (#1036)
* fix use_reentrant issues

* fix

* fixup

* address comments.

* add warnings

* oops

* fix

* quality
younesbelkada authored Oct 31, 2023
1 parent 884b1ac commit bdeb06b
Showing 1 changed file with 39 additions and 11 deletions.
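
For context before the diff, here is a minimal PyTorch sketch (illustrative only, not part of this commit) of the `use_reentrant` behavior the fix is about: with the reentrant checkpoint implementation, gradients only flow through a checkpointed segment if at least one of its inputs requires gradients, which is why PEFT hooks the input embeddings of a frozen base model; with `use_reentrant=False`, parameter gradients are computed even when every input is frozen, so the hook is unnecessary and the diff below only installs it when reentrant checkpointing is in use.

import torch
from torch.utils.checkpoint import checkpoint

# Illustrative only: a frozen input, like the embeddings of a frozen/quantized base model
layer = torch.nn.Linear(4, 4)
frozen_input = torch.randn(2, 4)  # requires_grad is False

# Non-reentrant checkpointing: gradients still reach the layer's parameters
out = checkpoint(layer, frozen_input, use_reentrant=False)
out.sum().backward()
print(layer.weight.grad is not None)  # True, no input hook needed

# With use_reentrant=True instead, the same call warns that no input requires
# grad and the backward pass cannot produce parameter gradients.
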
src/peft/utils/other.py: 39 additions & 11 deletions
@@ -63,18 +63,29 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     return tuple(result)


-def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
     r"""
     Note this method only works for `transformers` models.
     This method wraps the entire protocol for preparing a model before running a training. This includes:
         1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
         head to fp32
     Args:
-        model, (`transformers.PreTrainedModel`):
+        model (`transformers.PreTrainedModel`):
             The loaded model from `transformers`
+        use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
+            Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of
+            `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method.
+            Note this is only available in the latest transformers versions (> 4.34.1).
     """
     loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
     is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
     for name, param in model.named_parameters():
         # freeze base model's layers
         param.requires_grad = False
@@ -86,19 +97,36 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
                 param.data = param.data.to(torch.float32)

     if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
-        # For backward compatibility
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        else:
+        # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
+        if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
+            # For backward compatibility
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:

-            def make_inputs_require_grad(module, input, output):
-                output.requires_grad_(True)
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)

-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

-        # enable gradient checkpointing for memory efficiency
-        model.gradient_checkpointing_enable()
+        # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
+        _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+            inspect.signature(model.gradient_checkpointing_enable).parameters
+        )
+
+        if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
+            warnings.warn(
+                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
+                " if you want to use that feature, please upgrade to the latest version of transformers.",
+                FutureWarning,
+            )
+
+        gc_enable_kwargs = (
+            {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
+        )
+
+        # enable gradient checkpointing for memory efficiency
+        model.gradient_checkpointing_enable(**gc_enable_kwargs)
     return model
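
A hedged usage sketch of the new argument (not from the PR; the checkpoint name and LoRA settings below are placeholders): after this change, `use_reentrant` can be forwarded through `prepare_model_for_kbit_training` to `model.gradient_checkpointing_enable()` on transformers versions newer than 4.34.1, while older versions drop the kwargs with a FutureWarning as shown above.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Placeholder base model; any transformers causal LM loaded in 4-bit works the same way
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_4bit=True)

model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    # Forwarded to gradient_checkpointing_enable(); with use_reentrant=False the
    # input-embedding hook in the code above is skipped entirely
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
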

