From b5ae0fce886118839c8e99f99785a3c7c4d38549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Thu, 7 Nov 2024 10:02:01 +0100 Subject: [PATCH 01/10] Fix gradient checkpointing and write tests - oerwrite the gradient_checkpointing_enable to provide our ForwardContext during the recomputation of values during backpropagation - 2 bugs remaining: bottleneck adapter for models with the legacy implementation (BERT) & Parallel. Parallel has the problem that we manipulate the batch dimension and this currently leads to an error --- src/adapters/context.py | 5 ++- src/adapters/model_mixin.py | 64 ++++++++++++++++++++++++++ tests/methods/base.py | 69 ++++++++++++++++++++++++++++- tests/methods/test_ia3.py | 6 +++ tests/methods/test_lora.py | 6 +++ tests/methods/test_prefix_tuning.py | 6 +++ tests/methods/test_prompt_tuning.py | 6 +++ tests/methods/test_reft.py | 6 +++ tests/methods/test_unipelt.py | 6 +++ 9 files changed, 170 insertions(+), 4 deletions(-) diff --git a/src/adapters/context.py b/src/adapters/context.py index 70e685d037..db09b8918f 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -1,10 +1,11 @@ import functools import threading +from typing import ContextManager from .composition import parse_composition, parse_heads_from_composition -class AdapterSetup: +class AdapterSetup(ContextManager): """ Represents an adapter setup of a model including active adapters and active heads. This class is intended to be used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will @@ -67,7 +68,7 @@ def get_context_head_setup(cls): return None -class ForwardContext: +class ForwardContext(ContextManager): """ Holds context information during a forward pass through a model. This class should be used via the ``ForwardContext.wrap()`` method. diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 659a6cfcff..342913c8d6 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1,14 +1,18 @@ +import contextlib +import functools import inspect import logging import os from abc import ABC, abstractmethod from collections import defaultdict from copy import deepcopy +from functools import partial from os.path import join from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn +from torch.utils.checkpoint import checkpoint from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig from transformers import GenerationConfig @@ -1447,6 +1451,66 @@ def save_pretrained( # Remove adapters config del self.config.adapters + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": False} + + # >>> START AH Changes <<< + if "use_reentrant" not in gradient_checkpointing_kwargs: + # use_reentrant must be set. + gradient_checkpointing_kwargs["use_reentrant"] = False + else: + if gradient_checkpointing_kwargs["use_reentrant"]: + raise ValueError( + "Gradient checkpointing with use_reentrant=True is not supported. For gradient checkpointing, we need to set context_fn, which is only supported by PyTorch when use_reentrant is set to False." + ) + + def gradient_checkpointing_function(function, *args, **kwargs): + context = ForwardContext(self, *args, **kwargs) + context_fn = lambda: (contextlib.nullcontext(), context) + return checkpoint(function, *args, context_fn=context_fn, **kwargs) + + gradient_checkpointing_func = functools.partial( + gradient_checkpointing_function, **gradient_checkpointing_kwargs + ) + # >>> END AH Changes <<< + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + @inherit_doc class ModelBaseAdaptersMixin(ModelAdaptersMixin): diff --git a/tests/methods/base.py b/tests/methods/base.py index 0d20f32fef..86eb3e08ca 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -1,10 +1,12 @@ import copy import os import tempfile +from typing import Callable import torch import adapters +import adapters.composition as ac from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel from adapters.heads import CausalLMHead from adapters.utils import WEIGHTS_NAME @@ -247,7 +249,7 @@ def run_full_model_load_test(self, adapter_config): self.assertEqual(len(output1), len(output2)) self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4)) - def trainings_run(self, model, lr=1.0, steps=8): + def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulation_steps=1): # setup dataset train_dataset = self.dataset() @@ -257,7 +259,8 @@ def trainings_run(self, model, lr=1.0, steps=8): learning_rate=lr, max_steps=steps, no_cuda=True, - per_device_train_batch_size=2, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, remove_unused_columns=False, ) @@ -370,3 +373,65 @@ def run_reset_test(self, adapter_config): # check forward pass self.assertEqual(len(output_1), len(output_2)) self.assertTrue(torch.allclose(output_1[0], output_2[0], atol=1e-3)) + + def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[adapters.ModelAdaptersMixin], None]): + """ + Test that gradient checkpointing produces the same results as normal training + Args: + adapter_setup_fn: Function that takes a model and sets up the adapter training. Must also add a head (usually via self.add_head(...)). We have this in a separate function to allow complex setups (like training a normal adapter or training parallel setups) + """ + + if not self.do_run_train_tests: + self.skipTest("Skipping training tests. Set `do_run_train_tests=True` to run them.") + if self.config_class not in ADAPTER_MODEL_MAPPING: + self.skipTest("Does not support flex heads.") + + config = self.config() + state_dict_after_training = {} + + for train_with_checkpointing in [True, False]: + # Set random seed + torch.manual_seed(42) + + # Initialize model + model = adapters.AutoAdapterModel.from_config(config) + model.to(torch_device) + adapter_setup_fn(model) + + # Enable gradient checkpointing + if train_with_checkpointing: + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + # Train & store state dict + self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2) + state_dict_after_training[train_with_checkpointing] = copy.deepcopy(model.state_dict()) + + # Check that the state dicts are the same (we know that normal training works as expected, so we only need to check that gradient checkpointing produces the same results.) + for (k1, v1), (k2, v2) in zip( + state_dict_after_training[True].items(), state_dict_after_training[False].items() + ): + v1 = v1.to(v2.device) + self.assertTrue(torch.equal(v1, v2), msg=f"Key {k1} is not equal:\nv1: {v1}\nv2: {v2}") + + def run_gradient_checkpointing_single_adapter_test(self, adapter_config): + def adapter_setup_fn(model): + model.add_adapter("adapter1", config=adapter_config) + self.add_head(model, "adapter1") + model.train_adapter("adapter1") + model.adapter_to("adapter1", torch_device) + + self._run_gradient_checkpointing_test_helper(adapter_setup_fn) + + def run_gradient_checkpointing_test_parallel_adapters(self, adapter_config): + def adapter_setup_fn(model): + model.add_adapter("adapter1", config=adapter_config) + model.add_adapter("adapter2", config=adapter_config) + self.add_head(model, "adapter1") + self.add_head(model, "adapter2") + model.active_adapters = ac.Parallel("adapter1", "adapter2") + model.train_adapter(ac.Parallel("adapter1", "adapter2")) + model.adapter_to("adapter1", torch_device) + model.adapter_to("adapter2", torch_device) + + self._run_gradient_checkpointing_test_helper(adapter_setup_fn) diff --git a/tests/methods/test_ia3.py b/tests/methods/test_ia3.py index 3a30e2448d..ced2dbb003 100644 --- a/tests/methods/test_ia3.py +++ b/tests/methods/test_ia3.py @@ -45,3 +45,9 @@ def test_merge_ia3(self): def test_reset_ia3(self): self.run_reset_test(IA3Config(init_weights="bert")) + + def test_ia3_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(IA3Config()) + + def test_ia3_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(IA3Config()) diff --git a/tests/methods/test_lora.py b/tests/methods/test_lora.py index 067f78c8b8..0fbd2f6808 100644 --- a/tests/methods/test_lora.py +++ b/tests/methods/test_lora.py @@ -313,3 +313,9 @@ def test_merge_lora(self): def test_reset_lora(self): self.run_reset_test(LoRAConfig(init_weights="bert")) + + def test_lora_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(LoRAConfig()) + + def test_lora_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(LoRAConfig()) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index dd443c0d0b..c6f5ade445 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -101,3 +101,9 @@ def test_prefix_tuning_generate(self): input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (1, seq_output_length)) + + def test_prefix_tuning_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig()) + + def test_prefix_tuning_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(PrefixTuningConfig()) diff --git a/tests/methods/test_prompt_tuning.py b/tests/methods/test_prompt_tuning.py index 97015d1319..f3c4b5b657 100644 --- a/tests/methods/test_prompt_tuning.py +++ b/tests/methods/test_prompt_tuning.py @@ -36,3 +36,9 @@ def test_load_full_model_prompt_tuning(self): def test_train_prompt_tuning(self): self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_prompt_tuning_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10)) + + def test_prompt_tuning_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(PromptTuningConfig(prompt_length=10)) diff --git a/tests/methods/test_reft.py b/tests/methods/test_reft.py index 8849221808..8e5cab0f27 100644 --- a/tests/methods/test_reft.py +++ b/tests/methods/test_reft.py @@ -77,3 +77,9 @@ def test_load_full_model_reft(self): def test_train_loreft(self): self.run_train_test(LoReftConfig(), ["refts.{name}."]) + + def test_reft_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(LoReftConfig()) + + def test_reft_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(LoReftConfig()) diff --git a/tests/methods/test_unipelt.py b/tests/methods/test_unipelt.py index d29fa5f18d..2191a31161 100644 --- a/tests/methods/test_unipelt.py +++ b/tests/methods/test_unipelt.py @@ -64,3 +64,9 @@ def test_output_adapter_gating_scores_unipelt(self): self.assertGreaterEqual(len(per_layer_scores), 3) for k, v in per_layer_scores.items(): self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + + def test_unipelt_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(UniPELTConfig()) + + def test_unipelt_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(UniPELTConfig()) From cb07dd44c1040eb9a85194f85425ed130eca5304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 11 Nov 2024 22:51:08 +0100 Subject: [PATCH 02/10] minor fix but doesn't resolve the remaining issues --- src/adapters/model_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 342913c8d6..dfc7022bd1 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1482,7 +1482,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): ) def gradient_checkpointing_function(function, *args, **kwargs): - context = ForwardContext(self, *args, **kwargs) + context = ForwardContext.get_context() context_fn = lambda: (contextlib.nullcontext(), context) return checkpoint(function, *args, context_fn=context_fn, **kwargs) From 8421f634d742d88140b1f2afeca9df687039e476 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 24 Nov 2024 20:01:56 +0100 Subject: [PATCH 03/10] Only run adjust_tensors_for_parallel_ if bsz is different --- src/adapters/composition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 48a6bc8acf..df78966b52 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -234,7 +234,7 @@ def adjust_tensors_for_parallel(hidden_states, *tensors): """ outputs = [] for tensor in tensors: - if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]: + if tensor is not None and hidden_states.shape[0] > tensor.shape[0]: repeats = [1] * len(tensor.shape) repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) @@ -249,7 +249,7 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors): In-place version of adjust_tensors_for_parallel(). """ for tensor in tensors: - if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]: + if tensor is not None and hidden_states.shape[0] > tensor.shape[0]: repeats = [1] * len(tensor.shape) repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) From 94df2feb7ea0e9584116f0d44ab3471e16cba53d Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 25 Nov 2024 21:48:09 +0100 Subject: [PATCH 04/10] remove parallel grad checkpointing test --- tests/methods/base.py | 13 ------------- tests/methods/test_ia3.py | 3 --- tests/methods/test_lora.py | 3 --- tests/methods/test_prefix_tuning.py | 3 --- tests/methods/test_prompt_tuning.py | 3 --- tests/methods/test_reft.py | 3 --- tests/methods/test_unipelt.py | 3 --- 7 files changed, 31 deletions(-) diff --git a/tests/methods/base.py b/tests/methods/base.py index 86eb3e08ca..0c508c3108 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -422,16 +422,3 @@ def adapter_setup_fn(model): model.adapter_to("adapter1", torch_device) self._run_gradient_checkpointing_test_helper(adapter_setup_fn) - - def run_gradient_checkpointing_test_parallel_adapters(self, adapter_config): - def adapter_setup_fn(model): - model.add_adapter("adapter1", config=adapter_config) - model.add_adapter("adapter2", config=adapter_config) - self.add_head(model, "adapter1") - self.add_head(model, "adapter2") - model.active_adapters = ac.Parallel("adapter1", "adapter2") - model.train_adapter(ac.Parallel("adapter1", "adapter2")) - model.adapter_to("adapter1", torch_device) - model.adapter_to("adapter2", torch_device) - - self._run_gradient_checkpointing_test_helper(adapter_setup_fn) diff --git a/tests/methods/test_ia3.py b/tests/methods/test_ia3.py index ced2dbb003..b96dbcd02a 100644 --- a/tests/methods/test_ia3.py +++ b/tests/methods/test_ia3.py @@ -48,6 +48,3 @@ def test_reset_ia3(self): def test_ia3_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(IA3Config()) - - def test_ia3_gradient_checkpointing_parallel_adapters(self): - self.run_gradient_checkpointing_test_parallel_adapters(IA3Config()) diff --git a/tests/methods/test_lora.py b/tests/methods/test_lora.py index 0fbd2f6808..e1ced5188a 100644 --- a/tests/methods/test_lora.py +++ b/tests/methods/test_lora.py @@ -316,6 +316,3 @@ def test_reset_lora(self): def test_lora_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(LoRAConfig()) - - def test_lora_gradient_checkpointing_parallel_adapters(self): - self.run_gradient_checkpointing_test_parallel_adapters(LoRAConfig()) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index c6f5ade445..1878c8a18d 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -104,6 +104,3 @@ def test_prefix_tuning_generate(self): def test_prefix_tuning_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig()) - - def test_prefix_tuning_gradient_checkpointing_parallel_adapters(self): - self.run_gradient_checkpointing_test_parallel_adapters(PrefixTuningConfig()) diff --git a/tests/methods/test_prompt_tuning.py b/tests/methods/test_prompt_tuning.py index f3c4b5b657..f2fd1b0345 100644 --- a/tests/methods/test_prompt_tuning.py +++ b/tests/methods/test_prompt_tuning.py @@ -39,6 +39,3 @@ def test_train_prompt_tuning(self): def test_prompt_tuning_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10)) - - def test_prompt_tuning_gradient_checkpointing_parallel_adapters(self): - self.run_gradient_checkpointing_test_parallel_adapters(PromptTuningConfig(prompt_length=10)) diff --git a/tests/methods/test_reft.py b/tests/methods/test_reft.py index 8e5cab0f27..2a74c2b111 100644 --- a/tests/methods/test_reft.py +++ b/tests/methods/test_reft.py @@ -80,6 +80,3 @@ def test_train_loreft(self): def test_reft_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(LoReftConfig()) - - def test_reft_gradient_checkpointing_parallel_adapters(self): - self.run_gradient_checkpointing_test_parallel_adapters(LoReftConfig()) diff --git a/tests/methods/test_unipelt.py b/tests/methods/test_unipelt.py index 2191a31161..b855670ab4 100644 --- a/tests/methods/test_unipelt.py +++ b/tests/methods/test_unipelt.py @@ -67,6 +67,3 @@ def test_output_adapter_gating_scores_unipelt(self): def test_unipelt_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(UniPELTConfig()) - - def test_unipelt_gradient_checkpointing_parallel_adapters(self): - self.run_gradient_checkpointing_test_parallel_adapters(UniPELTConfig()) From ea5b68f9ae9e0c12a4d2d81c5859d175898de3d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 13 Jan 2025 23:47:33 +0100 Subject: [PATCH 05/10] fix deberta and albert tests. docs & style & fixes - albert: skip unsupported tests - deberta(V2): fix embedding bug with inplace operations. - deberta: fix LoRAMergedLinear Bug with device mismatch --- src/adapters/methods/lora.py | 4 +- .../models/deberta/modeling_deberta.py | 53 ++++++++++++++++++ .../models/deberta_v2/modeling_deberta_v2.py | 55 +++++++++++++++++++ tests/methods/base.py | 1 - 4 files changed, 111 insertions(+), 2 deletions(-) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 8f3bc29401..3245afdd99 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -718,7 +718,9 @@ def pad(self, x, lora, fill_value=None): fill_value = 1 result = x.new_full((*x.shape[:-1], self.out_features), fill_value) result = result.view(-1, self.out_features) - result[:, lora.lora_ind] = x.reshape(-1, self.out_features // 3 * self.get_n_heads(lora)) + # Move lora_ind to the same device as x + lora_ind = lora.lora_ind.to(x.device) + result[:, lora_ind] = x.reshape(-1, self.out_features // 3 * self.get_n_heads(lora)) return result.view((*x.shape[:-1], self.out_features)) def reset_adapter(self): diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 77c6117b19..8b4c87b2c5 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -19,6 +19,7 @@ from torch import nn from transformers.models.deberta.modeling_deberta import ( + DebertaEmbeddings, DebertaOutput, DebertaSelfOutput, DisentangledSelfAttention, @@ -47,6 +48,58 @@ def forward(self, hidden_states, input_tensor): return hidden_states +class DebertaEmbeddingsWithAdapters(DebertaEmbeddings): + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + # >>> START AH Changes <<< + # HuggingFace uses += instead of + which leads to a bug when using model.enable_input_require_grads. Once this is fixed, we can remove + embeddings = embeddings + position_embeddings + # >>> END AH Changes <<< + if self.token_type_embeddings is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + # >>> START AH Changes <<< + embeddings = embeddings + token_type_embeddings + # >>> END AH Changes <<< + if self.embed_proj is not None: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + class DisentangledSelfAttentionWithAdapters(DebertaSelfAttentionAdaptersMixin, DisentangledSelfAttention): """ Disentangled self-attention module diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 2b673c491f..2e7d86ae8a 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -19,6 +19,7 @@ from torch import nn from transformers.models.deberta_v2.modeling_deberta_v2 import ( + DebertaV2Embeddings, DebertaV2Output, DebertaV2SelfOutput, DisentangledSelfAttention, @@ -49,6 +50,60 @@ def forward(self, hidden_states, input_tensor): return hidden_states +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm,Deberta->DebertaV2 +class DebertaV2EmbeddingsWithAdapters(DebertaV2Embeddings): + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + # >>> START AH Changes <<< + # HuggingFace uses += instead of + which leads to a bug when using model.enable_input_require_grads. Once this is fixed, we can remove DebertaV2EmbeddingsWithAdapters. + embeddings = embeddings + position_embeddings + # >>> END AH Changes <<< + if self.token_type_embeddings is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + # >>> START AH Changes <<< + embeddings = embeddings + token_type_embeddings + # >>> END AH Changes <<< + + if self.embed_proj is not None: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + class DisentangledSelfAttentionWithAdapters(DebertaV2SelfAttentionAdaptersMixin, DisentangledSelfAttention): def transpose_for_scores_extended(self, x, attention_heads): new_x_shape = x.size()[:-1] + (attention_heads, -1) diff --git a/tests/methods/base.py b/tests/methods/base.py index 0c508c3108..46df3fdc2a 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -6,7 +6,6 @@ import torch import adapters -import adapters.composition as ac from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel from adapters.heads import CausalLMHead from adapters.utils import WEIGHTS_NAME From 5999bb7f5dd1114117906903953c93f88cc075e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 13 Jan 2025 23:53:51 +0100 Subject: [PATCH 06/10] Add docs and notebook --- docs/training.md | 4 + notebooks/Gradient_Checkpointing_Llama.ipynb | 342 +++++++++++++++++++ notebooks/README.md | 7 +- 3 files changed, 352 insertions(+), 1 deletion(-) create mode 100644 notebooks/Gradient_Checkpointing_Llama.ipynb diff --git a/docs/training.md b/docs/training.md index 78fcd9e757..d4de614392 100644 --- a/docs/training.md +++ b/docs/training.md @@ -223,3 +223,7 @@ trainer = AdapterTrainer( _Adapters_ supports fine-tuning of quantized language models similar to [QLoRA (Dettmers et al., 2023)](https://arxiv.org/pdf/2305.14314.pdf) via the `bitsandbytes` library integrated into Transformers. Quantized training is supported for LoRA-based adapters as well as bottleneck adapters and prefix tuning. Please refer to [this notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) for a hands-on guide. + +## Gradient Checkpointing +Gradient checkpointing is supported for all models (e.g. Llama 1/2/3) except for the models that are not supported by Hugging Face Transformers (like ALBERT). Please refer to [this notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) for a hands-on guide. + diff --git a/notebooks/Gradient_Checkpointing_Llama.ipynb b/notebooks/Gradient_Checkpointing_Llama.ipynb new file mode 100644 index 0000000000..dee4d3abd3 --- /dev/null +++ b/notebooks/Gradient_Checkpointing_Llama.ipynb @@ -0,0 +1,342 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "introduction", + "metadata": {}, + "source": [ + "# Efficient Llama Training with Gradient Checkpointing and _Adapters_\n", + "\n", + "In this notebook, we show how to efficiently fine-tune a **Llama 3** model using **gradient checkpointing** and adapter methods.\n", + "\n", + "**Gradient checkpointing** is a technique to reduce peak memory usage significantly and thus enables training larger models with larger batch sizes. Gradient checkpointing achieves this by trading compute for memory: During the forward pass, gradient checkpointing only stores a subset of activations (thus saving memory). During backpropagation, gradient checkpointing recomputes the activations that were not stored. This can significantly reduce memory requirements at the cost of slightly increased computation time.\n", + "\n", + "In this notebook, we finetune Llama-3 8B on supervised instruction tuning data collected by the [Open Assistant project](https://github.com/LAION-AI/Open-Assistant) for training chatbots.\n", + "\n", + "Another way to reduce memore usage is to use quantization. Have a look a the [QLora notebook](QLoRA_Llama_Finetuning.ipynb) for an example. This gradient checkpointing notebook is based on the QLoRA notebook. While we use a normal LoRA setup in this notebook, you can easily replace LoRA with QLoRA to reduce memory usage even further." + ] + }, + { + "cell_type": "markdown", + "id": "installation", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "We need `adapters`, `datasets` and `pytorch` for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "install", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qq -U adapters datasets torch" + ] + }, + { + "cell_type": "markdown", + "id": "dataset", + "metadata": {}, + "source": [ + "## Load Open Assistant dataset\n", + "\n", + "We use the [`timdettmers/openassistant-guanaco`](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset, which contains a small subset of conversations from the full Open Assistant database." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load_dataset", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['text'],\n", + " num_rows: 9846\n", + " })\n", + " test: Dataset({\n", + " features: ['text'],\n", + " num_rows: 518\n", + " })\n", + "})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"timdettmers/openassistant-guanaco\")\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "id": "model_setup", + "metadata": {}, + "source": [ + "## Load and prepare model\n", + "\n", + "We download the official Llama-2 7B/ Llama-3 8B checkpoint from the HuggingFace Hub. Note that you must request access to this model on the HuggingFace website and use an API token to download it.\n", + "\n", + "The key difference in this notebook is that we'll enable gradient checkpointing to reduce memory usage during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load_model", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "83e60dee3c434bb3a2bc656bd7f4b667", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00\"\n", + "\n", + "modelpath=\"meta-llama/Meta-Llama-3-8B\"\n", + "\n", + "# Load model with gradient checkpointing enabled\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " modelpath, \n", + " device_map=\"auto\",\n", + " torch_dtype=torch.bfloat16,\n", + " token=HUGGINGFACE_ACCESS_TOKEN,\n", + ")\n", + "model.config.use_cache = False\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(modelpath, token=HUGGINGFACE_ACCESS_TOKEN)\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "cell_type": "markdown", + "id": "5cd73b7d", + "metadata": {}, + "source": [ + "If you get a message similar to `WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu and disk.`, then the model itself is too big for your GPU. If you don't have a bigger / additional GPU at hand, you can use a quantization method like we show in the [QLoRA notebook](QLoRA_Llama_Finetuning.ipynb). Adding the quantization_config when loading the model and choosing a quantized `LoRAConfig` in the next step will enable quantized training." + ] + }, + { + "cell_type": "markdown", + "id": "adapter_setup", + "metadata": {}, + "source": [ + "## Initialize adapter\n", + "\n", + "We initialize the adapter functionality and add a LoRA adapter. When using gradient checkpointing with adapters, we need to enable input gradients explicitly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "init_adapter", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "Name Architecture #Param %Param Active Train\n", + "--------------------------------------------------------------------------------\n", + "lora_adapter lora 3,407,872 0.085 1 1\n", + "--------------------------------------------------------------------------------\n", + "Full model 4,015,263,744 100.000 0\n", + "================================================================================\n" + ] + } + ], + "source": [ + "import adapters\n", + "from adapters import LoRAConfig\n", + "\n", + "adapters.init(model)\n", + "\n", + "config = LoRAConfig()\n", + "model.add_adapter(\"lora_adapter\", config=config)\n", + "model.train_adapter(\"lora_adapter\")\n", + "\n", + "# Activate gradient checkpointing\n", + "model.gradient_checkpointing_enable()\n", + "\n", + "# For gradient checkpointing with adapters, it is beneficial to set enable_input_require_grads.\n", + "model.enable_input_require_grads()\n", + "\n", + "print(model.adapter_summary())" + ] + }, + { + "cell_type": "markdown", + "id": "data_prep", + "metadata": {}, + "source": [ + "## Prepare data for training\n", + "\n", + "The dataset is tokenized and truncated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "tokenize", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "def tokenize(element):\n", + " return tokenizer(\n", + " element[\"text\"],\n", + " truncation=True,\n", + " max_length=512,\n", + " add_special_tokens=False,\n", + " )\n", + "\n", + "dataset_tokenized = dataset.map(\n", + " tokenize, \n", + " batched=True, \n", + " num_proc=os.cpu_count(),\n", + " remove_columns=[\"text\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "training", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "We specify training hyperparameters and train the model using the `AdapterTrainer` class. With gradient checkpointing enabled, we can use larger batch sizes than would otherwise be possible." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "training_args", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainingArguments\n", + "\n", + "args = TrainingArguments(\n", + " output_dir=\"output/llama_gradient_checkpointing\",\n", + " per_device_train_batch_size=1,\n", + " per_device_eval_batch_size=1,\n", + " evaluation_strategy=\"steps\",\n", + " logging_steps=10,\n", + " save_steps=500,\n", + " eval_steps=187,\n", + " save_total_limit=3,\n", + " gradient_accumulation_steps=16,\n", + " max_steps=1875,\n", + " learning_rate=0.0002,\n", + " bf16=True,\n", + " warmup_ratio=0.03,\n", + " group_by_length=True,\n", + " lr_scheduler_type=\"constant\",\n", + " optim=\"adamw_torch\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "train", + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import AdapterTrainer\n", + "from transformers import DataCollatorForLanguageModeling\n", + "\n", + "trainer = AdapterTrainer(\n", + " model=model,\n", + " tokenizer=tokenizer,\n", + " data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n", + " train_dataset=dataset_tokenized[\"train\"],\n", + " eval_dataset=dataset_tokenized[\"test\"],\n", + " args=args,\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "inference", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "For inference, we can disable gradient checkpointing since we don't need gradients:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "inference_setup", + "metadata": {}, + "outputs": [], + "source": [ + "# Disable gradient checkpointing for inference\n", + "model.gradient_checkpointing_disable()\n", + "model.config.use_cache = True\n", + "\n", + "def prompt_model(model, text: str):\n", + " batch = tokenizer(f\"### Human: {text}\\n### Assistant:\", return_tensors=\"pt\")\n", + " batch = batch.to(model.device)\n", + " \n", + " model.eval()\n", + " with torch.inference_mode():\n", + " output_tokens = model.generate(**batch, max_new_tokens=50)\n", + "\n", + " return tokenizer.decode(output_tokens[0], skip_special_tokens=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "test_inference", + "metadata": {}, + "outputs": [], + "source": [ + "print(prompt_model(model, \"Explain gradient checkpointing in simple terms\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/README.md b/notebooks/README.md index 052cdafe4d..4766baca7e 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -28,7 +28,6 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us | Notebook | Description | | |:----------------|:---------------------|--:| | [Text Generation](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Text_Generation_Training.ipynb) | How to train an adapter for language generation. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Text_Generation_Training.ipynb) | -| [QLoRA LLama Finetuning](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | | [Training a NER Adapter](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | How to train an adapter on a named entity recoginition task. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | | [Adapter Drop Training](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Drop_Training.ipynb) | How to train an adapter using AdapterDrop | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Drop_Training.ipynb) | | [Inference example for id2label](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | How to use the id2label dictionary for inference | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_id2label_inference.ipynb) | @@ -36,3 +35,9 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us | [Finetuning Whisper with Adapters](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | Fine Tuning Whisper using LoRA | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | | [Adapter Training with ReFT](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | Fine Tuning using ReFT Adapters | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | | [ViT Fine-Tuning with AdapterPlus](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | ViT Fine-Tuning with AdapterPlus | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | + +### Memory Efficient Training +| Notebook | Description | | +|:----------------|:---------------------|--:| +| [QLoRA LLama Finetuning](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | +| [Gradient Checkpointing](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | From b89b941b88926456cb9fc91994489844ceada40d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Tue, 14 Jan 2025 10:15:21 +0100 Subject: [PATCH 07/10] Fix all remaining bugs (T5 and BeIT HF bug) --- notebooks/Gradient_Checkpointing_Llama.ipynb | 3 --- src/adapters/model_mixin.py | 10 ++++------ src/adapters/models/beit/adapter_model.py | 14 ++++++++++++++ src/adapters/models/t5/modeling_t5.py | 14 +++++++++++++- tests/methods/base.py | 6 +++++- 5 files changed, 36 insertions(+), 11 deletions(-) diff --git a/notebooks/Gradient_Checkpointing_Llama.ipynb b/notebooks/Gradient_Checkpointing_Llama.ipynb index dee4d3abd3..b48390d846 100644 --- a/notebooks/Gradient_Checkpointing_Llama.ipynb +++ b/notebooks/Gradient_Checkpointing_Llama.ipynb @@ -185,9 +185,6 @@ "# Activate gradient checkpointing\n", "model.gradient_checkpointing_enable()\n", "\n", - "# For gradient checkpointing with adapters, it is beneficial to set enable_input_require_grads.\n", - "model.enable_input_require_grads()\n", - "\n", "print(model.adapter_summary())" ] }, diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 8e07559713..00714ad873 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1674,12 +1674,10 @@ def gradient_checkpointing_function(function, *args, **kwargs): "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." ) - if getattr(self, "_hf_peft_config_loaded", False): - # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True - # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 - # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate - # the gradients to make sure the gradient flows. - self.enable_input_require_grads() + # >>> START AH Changes <<< + # For adapter training, we always require requires_grad=True for the input embeddings. + self.enable_input_require_grads() + # >>> END AH Changes <<< @inherit_doc diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py index 5667fa098d..578142ea11 100644 --- a/src/adapters/models/beit/adapter_model.py +++ b/src/adapters/models/beit/adapter_model.py @@ -36,6 +36,20 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + # Overwrites the function from: transformers.modeling_utils.PreTrainedModel + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings specifically for BEiT's tuple output format. + """ + + def make_inputs_require_grads(module, input, output): + # >>> START AH Changes <<< + # Handle BEiT's specific tuple output format. Hugging Face's implementation is buggy and doesn't work for BEiT. + output[0].requires_grad_(True) + # >>> END AH Changes <<< + + self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) def forward( self, diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 09b969bb1b..e4af5f04c4 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -419,8 +419,20 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: + # >>> START AH Changes <<< + # Without this change, T5 training with gradient checkpointing will fail for reft. + def create_custom_forward(module): + def custom_forward(*inputs): + # Ensure all inputs are on the same device + inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs) + return module(*inputs) + + return custom_forward + + # >>> END AH Changes <<< + layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, + create_custom_forward(layer_module), hidden_states, causal_mask, position_bias, diff --git a/tests/methods/base.py b/tests/methods/base.py index 46df3fdc2a..e0e3abce19 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -394,13 +394,17 @@ def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[ad # Initialize model model = adapters.AutoAdapterModel.from_config(config) + + # if model doesn't support gradient checkpointing, skip the test + if not model.supports_gradient_checkpointing: + self.skipTest("Model does not support gradient checkpointing") + model.to(torch_device) adapter_setup_fn(model) # Enable gradient checkpointing if train_with_checkpointing: model.gradient_checkpointing_enable() - model.enable_input_require_grads() # Train & store state dict self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2) From e7de20b168fdf3a7751ec5b2ed389c8c7924b4de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Tue, 14 Jan 2025 10:45:22 +0100 Subject: [PATCH 08/10] fix handling of CLIP and fix MT5 like we did for T5 --- src/adapters/model_mixin.py | 13 +++++++++++-- src/adapters/models/mt5/modeling_mt5.py | 16 +++++++++++++++- src/adapters/models/t5/modeling_t5.py | 2 ++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 00714ad873..3042faec42 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1675,8 +1675,17 @@ def gradient_checkpointing_function(function, *args, **kwargs): ) # >>> START AH Changes <<< - # For adapter training, we always require requires_grad=True for the input embeddings. - self.enable_input_require_grads() + # For adapter training, we set requires_grad=True for the input embeddings. Just like Hugging Face does for training with PEFT. + try: + self.enable_input_require_grads() + except NotImplementedError: + # Some models (CLIP) don't have input embeddings, so Hugging Face's implementation raises a NotImplementedError. + logger.warning( + "Model does not have input embeddings. Hugging Face didn't implement the model.enable_input_require_grads() method. But Gradient Checkpointing should nevertheless work. If you, however, encounter errors / weird behaviour, this might be the reason. In this case, please implement the method in the model yourself / open an issue on our GitHub." + ) + except Exception as e: + # Every other exception is unexpected and should be raised. + raise e # >>> END AH Changes <<< diff --git a/src/adapters/models/mt5/modeling_mt5.py b/src/adapters/models/mt5/modeling_mt5.py index 05141a08cf..b317823335 100644 --- a/src/adapters/models/mt5/modeling_mt5.py +++ b/src/adapters/models/mt5/modeling_mt5.py @@ -419,8 +419,22 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: + # >>> START AH Changes <<< + # Without this change, T5 training with gradient checkpointing will fail for reft. + def create_custom_forward(module): + def custom_forward(*inputs): + # Ensure all inputs are on the same device + inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs) + return module(*inputs) + + return custom_forward + + # >>> END AH Changes <<< + layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, + # >>> START AH Changes <<< + create_custom_forward(layer_module), + # >>> END AH Changes <<< hidden_states, causal_mask, position_bias, diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index e4af5f04c4..e401a2b840 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -432,7 +432,9 @@ def custom_forward(*inputs): # >>> END AH Changes <<< layer_outputs = self._gradient_checkpointing_func( + # >>> START AH Changes <<< create_custom_forward(layer_module), + # >>> END AH Changes <<< hidden_states, causal_mask, position_bias, From 0d3f0a1629b69e766b35bbade6317a2479a7a21e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Tue, 14 Jan 2025 11:32:59 +0100 Subject: [PATCH 09/10] fix CLIP --- src/adapters/model_mixin.py | 9 +++------ tests/models/test_clip.py | 5 +++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 3042faec42..68971bef1b 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1679,13 +1679,10 @@ def gradient_checkpointing_function(function, *args, **kwargs): try: self.enable_input_require_grads() except NotImplementedError: - # Some models (CLIP) don't have input embeddings, so Hugging Face's implementation raises a NotImplementedError. - logger.warning( - "Model does not have input embeddings. Hugging Face didn't implement the model.enable_input_require_grads() method. But Gradient Checkpointing should nevertheless work. If you, however, encounter errors / weird behaviour, this might be the reason. In this case, please implement the method in the model yourself / open an issue on our GitHub." + # Some models (CLIP) don't have input embeddings, so Hugging Face's implementation raises a NotImplementedError. We provide the user with some more information. + raise NotImplementedError( + "Model has no enable_input_require_grads method implementation by Hugging Face. Parameter efficient fine-tuning however needs gradients for embeddings. This model therefore doesn't support gradient checkpointing with Adapters nor Hugging Face's PEFT library." ) - except Exception as e: - # Every other exception is unexpected and should be raised. - raise e # >>> END AH Changes <<< diff --git a/tests/models/test_clip.py b/tests/models/test_clip.py index 921e0668f5..cf1297b693 100644 --- a/tests/models/test_clip.py +++ b/tests/models/test_clip.py @@ -37,3 +37,8 @@ def test_initialization(self): [0.0, 1.0], msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + + def test_gradient_checkpointing_enable_disable(self): + # CLIPAdapterModel does not support gradient checkpointing (because enable_input_require_grads is not implemented by Hugging Face, + # which is required for gradient checkpointing with parameter efficient fine-tuning methods). + self.skipTest("CLIPAdapterModel does not support gradient checkpointing") From c7e178b515846bd28ece75474ce4d88e021648ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 20 Jan 2025 23:48:49 +0100 Subject: [PATCH 10/10] Add comments --- src/adapters/model_mixin.py | 1 + tests/methods/base.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index d18b8b36e5..1895671f8d 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1621,6 +1621,7 @@ def save_pretrained( # Remove adapters config del self.config.adapters + # Override PreTrainedModel.gradient_checkpointing_enable(...) method from transformers/modeling_utils.py to support gradient checkpointing for adapter training. def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ Activates gradient checkpointing for the current model. diff --git a/tests/methods/base.py b/tests/methods/base.py index 12abcffbf0..55389b7052 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -388,6 +388,7 @@ def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[ad config = self.config() state_dict_after_training = {} + # Run training twice (with & without gradient checkpointing) to verify both produce identical results (i.e. the same state dict) for train_with_checkpointing in [True, False]: # Set random seed torch.manual_seed(42)