diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 62aedce273..e33c52e21e 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -63,18 +63,29 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     return tuple(result)
 
 
-def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
     r"""
+    Note this method only works for `transformers` models.
+
    This method wraps the entire protocol for preparing a model before running a training. This includes:
        1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
        head to fp32
 
    Args:
-        model, (`transformers.PreTrainedModel`):
+        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
+        use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
+        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
+            Keyword arguments to pass to the gradient checkpointing function; please refer to the documentation of
+            `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method.
+            Note this is only available in the latest transformers versions (> 4.34.1).
    """
    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False
@@ -86,19 +97,36 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
                param.data = param.data.to(torch.float32)
 
    if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
-        # For backward compatibility
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        else:
+        # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
+        if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
+            # For backward compatibility
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:
 
-            def make_inputs_require_grad(module, input, output):
-                output.requires_grad_(True)
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)
 
-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
-        # enable gradient checkpointing for memory efficiency
-        model.gradient_checkpointing_enable()
+        # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
+        _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+            inspect.signature(model.gradient_checkpointing_enable).parameters
+        )
+
+        if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
+            warnings.warn(
+                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
+                " If you want to use that feature, please upgrade to the latest version of transformers.",
+                FutureWarning,
+            )
+
+        gc_enable_kwargs = (
+            {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
+        )
+
+        # enable gradient checkpointing for memory efficiency
+        model.gradient_checkpointing_enable(**gc_enable_kwargs)
 
    return model