From 09ca532009822f109b6dc6deb1472d156a5bf2f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Thu, 12 Oct 2023 14:32:27 +0200 Subject: [PATCH 01/13] WIP prompt tuning --- src/adapters/__init__.py | 2 + src/adapters/configuration/adapter_config.py | 27 +++ src/adapters/model_mixin.py | 73 ++++-- src/adapters/models/bert/mixin_bert.py | 25 +- src/adapters/prompt_tuning.py | 219 ++++++++++++++++++ .../composition/test_adapter_composition.py | 2 +- tests_adapters/methods/__init__.py | 1 + tests_adapters/methods/test_prompt_tuning.py | 44 ++++ tests_adapters/test_adapter.py | 19 ++ tests_adapters/test_adapter_custom_head.py | 3 +- tests_adapters/test_adapter_hub.py | 3 +- 11 files changed, 398 insertions(+), 20 deletions(-) create mode 100644 src/adapters/prompt_tuning.py create mode 100644 tests_adapters/methods/test_prompt_tuning.py diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index bd78dee5de..8c32027aa7 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -54,6 +54,7 @@ "ModelAdaptersConfig", "ParBnConfig", "PrefixTuningConfig", + "PromptTuningConfig", "SeqBnConfig", "SeqBnInvConfig", "StaticAdapterFusionConfig", @@ -161,6 +162,7 @@ ModelAdaptersConfig, ParBnConfig, PrefixTuningConfig, + PromptTuningConfig, SeqBnConfig, SeqBnInvConfig, StaticAdapterFusionConfig, diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 0d72f08377..765295282a 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -84,6 +84,8 @@ def _get_config_class(config_dict): cls_new = LoRAConfig elif architecture == "union": cls_new = ConfigUnion + elif architecture == "prompt_tuning": + cls_new = PromptTuningConfig else: cls_new = BnConfig @@ -395,6 +397,30 @@ class PrefixTuningConfig(AdapterConfigBase): shared_gating: bool = True +@dataclass(eq=False) +class PromptTuningConfig(AdapterConfigBase): + # TODO: add config + """ + The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf + + Args: + + """ + + prompt_length: int + + prompt_init_text: Optional[str] = None # only necessary when using prompt_init="from_string" + architecture: Optional[str] = "prompt_tuning" + + prompt_init: str = ( # random_uniform, from_string, from_array, TODO: add more from https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/prompts.py + "random_uniform" + ) + combine: str = "prefix" # prefix, prefix_after_bos, suffix + + # TODO: add a parameter for the random uniform scale + # TODO: add more params if necessary + + @dataclass(eq=False) class LoRAConfig(AdapterConfigBase): """ @@ -612,6 +638,7 @@ def __init__( "compacter": CompacterConfig(), "prefix_tuning": PrefixTuningConfig(), "prefix_tuning_flat": PrefixTuningConfig(flat=True), + "prompt_tuning": PromptTuningConfig(prompt_length=10), # TODO: is that alright? 
"lora": LoRAConfig(), "ia3": IA3Config(), "mam": MAMConfig(), diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index be5e30cd93..7b0a22d2a6 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -10,10 +10,11 @@ import torch from torch import nn +from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import ModelOutput from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition -from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig +from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig, ModelAdaptersConfig from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader @@ -22,6 +23,7 @@ from .methods.lora import LoRALayer from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool +from .prompt_tuning import PromptTuningLayer from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config @@ -40,21 +42,20 @@ def init_adapters(self, model_config, adapters_config, **kwargs): if hasattr(super(), "init_adapters"): super().init_adapters(self.config, self.adapters_config, **kwargs) - self.hook_after_embeddings(self._hook_fn) + # self.hook_after_embeddings(self._hook_fn) - def _hook_fn(self, module, args, output): - new_output = self.invertible_adapters_forward(output) - return new_output + # def _hook_fn(self, module, args, output): + # new_output = self.invertible_adapters_forward(output) + # return new_output - def hook_after_embeddings(self, hook_fn: Callable): - """ - Hook a function to be called after the embeddings have been computed. The default implementation does nothing. - Override this method to add a hook. + # TODO: entfernen + # def hook_after_embeddings(self, hook_fn: Callable): + # """ + # Hook a function to be called after the embeddings have been computed. The default implementation does nothing. # + Override this method to add a hook. - Args: - hook_fn (Callable): The function to be called after the embeddings have been computed. - """ - pass + # Args: # hook_fn (Callable): The function to be called after the embeddings have been computed. #""" + # pass def add_invertible_adapter(self, adapter_name: str) -> bool: """ @@ -361,9 +362,28 @@ def loaded_embeddings(self): return self.base_model.loaded_embeddings -class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): +class PromptTuningMixin: + def init_prompt_tuning( + self, model_config: PretrainedConfig, adapters_config: ModelAdaptersConfig, base_model_embeddings: nn.Module + ): + self.prompt_tuning = PromptTuningLayer(model_config, adapters_config, base_model_embeddings) + + def add_prompt_tuning(self, adapter_name: str) -> bool: + return self.prompt_tuning.add_adapter(adapter_name=adapter_name, layer_idx=-1) + + # TODO: delete, etc .... 
+    # TODO: can probably be merged into wherever this gets called
+
+
+class ModelAdaptersMixin(PushAdapterToHubMixin, PromptTuningMixin, ABC):
     """Mixin for transformer models adding support for loading/ saving adapters."""
 
+    # Setting this to True will automatically add a prompt tuning layer to the model
+    # This prompt tuning layer is stored in self.prompt_tuning
+    # Since the correct position to call the prompt tuning forward depends on the model type, this has to be called in the post_embedding_forward method
+    supports_prompt_tuning = False
+    prefix_attention_mask = None
+
     def __init__(self, config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
 
@@ -371,6 +391,13 @@ def _link_prefix_to_pool(self, layer):
         if isinstance(layer, PrefixTuningLayer):
             layer.set_pool(self.base_model.prefix_tuning)
 
+    # TODO: provide documentation & maybe better fitting name & move to more fitting place in this class
+    # @abstractmethod
+    def post_embedding_forward(self, module, args, embedding_output):
+        # def post_embedding_forward(self, *args, **kwargs):
+        """Passes the embedding layer output through adapter methods such as invertible adapters or prompt tuning."""
+        raise NotImplementedError
+
     @property
     def model_name(self):
         return self.config.name_or_path
@@ -400,6 +427,10 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr
             self.base_model.prefix_tuning = PrefixTuningPool(self.config, self.adapters_config)
             self.apply_to_adapter_layers(lambda i, layer: self._link_prefix_to_pool(layer))
 
+        # Add Prompt Tuning
+        if self.supports_prompt_tuning:
+            self.init_prompt_tuning(self.config, self.adapters_config, self.get_input_embeddings())
+
         # Initialize adapters from config
         for adapter_name in self.adapters_config:
             self._add_adapter_weights(adapter_name)
@@ -577,6 +608,10 @@ def _add_adapter_weights(self, adapter_name: str):
         if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
             self.add_invertible_adapter(adapter_name)
 
+        # Prompt Tuning
+        if self.supports_prompt_tuning:
+            self.add_prompt_tuning(adapter_name)
+
     def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=None, override_kwargs=None):
         warnings.warn(
             "add_fusion() has been deprecated in favor of add_adapter_fusion(). Please use the newer method instead.",
@@ -1228,6 +1263,16 @@ def save_pretrained(
 
 
 class ModelBaseAdaptersMixin(ModelAdaptersMixin):
     @ForwardContext.wrap
     def forward(self, *args, **kwargs):
+        # TODO: adapt attention mask here
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None and self.prefix_attention_mask is not None:
+            self.prefix_attention_mask = self.prefix_attention_mask.to(attention_mask.device)
+            attention_mask = torch.cat((self.prefix_attention_mask, attention_mask), dim=1)
+            kwargs.update(
+                {
+                    "attention_mask": attention_mask,
+                }
+            )
         return super().forward(*args, **kwargs)
 
 
diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py
index e97c9dd988..b752e3a3f6 100644
--- a/src/adapters/models/bert/mixin_bert.py
+++ b/src/adapters/models/bert/mixin_bert.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple
 
 import torch.nn as nn
 
@@ -70,13 +70,22 @@ def init_adapters(self, model_config, adapters_config):
 class BertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
     """Adds adapters to the BertModel module."""
 
+    supports_prompt_tuning = True
+
     def init_adapters(self, model_config, adapters_config):
         super().init_adapters(model_config, adapters_config)
 
+        # TODO: move somewhere else? e.g. in ModelBaseAdaptersMixin as we can add this and just don't use it in post_embedding_forward if we don't want to support it
+        # how do we handle adding the adapters? If we do not use prompt tuning, it should not be possible to add one either
+        # self.prompt_tuning = PromptTuningLayer(adapters_config)
+
         # Set hook for parallel composition
         for _, layer in self.iter_layers():
             self._set_layer_hook_for_parallel(layer)
 
+        # Register hook for post embedding forward
+        self.embeddings.register_forward_hook(self.post_embedding_forward)
+
     def _set_layer_hook_for_parallel(self, layer: nn.Module):
         def hook(module, input):
             adjust_tensors_for_parallel_(input[0], input[1])
 
@@ -88,5 +97,15 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
         for i, layer in enumerate(self.encoder.layer):
             yield i, layer
 
-    def hook_after_embeddings(self, hook_fn: Callable):
-        return self.embeddings.register_forward_hook(hook_fn)
+    def post_embedding_forward(self, module, args, embedding_output):
+        print(f"args: {module}")
+
+        new_output = self.invertible_adapters_forward(embedding_output)
+        new_output, prefix_attention_mask = self.prompt_tuning.forward(embedding_output)
+
+        # TODO: this does not work on its own. We need to get the attention mask into BertSelfAttention.
+        # That is not straightforward, because we cannot pass prefix_attention_mask to BertSelfAttention directly, since we cannot modify the parameters that are passed in.
+ # Wir können das mittels BertSelfAttentionAdaptersMixin machen, indem: + self.prefix_attention_mask = prefix_attention_mask + + return new_output diff --git a/src/adapters/prompt_tuning.py b/src/adapters/prompt_tuning.py new file mode 100644 index 0000000000..b50c127ae1 --- /dev/null +++ b/src/adapters/prompt_tuning.py @@ -0,0 +1,219 @@ +# https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/train/prompts.py + +import logging +import math +from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from transformers import AutoTokenizer +from transformers.configuration_utils import PretrainedConfig + +from .composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel +from .configuration import ModelAdaptersConfig, PromptTuningConfig +from .layer import AdapterLayerBase + + +logger = logging.getLogger(__name__) + + +Initializer = Callable[[torch.Tensor, Sequence[int]], torch.Tensor] + + +class PromptTuning(nn.Module): + """Generate a Prompt and concatenate it with the input. + + This is the training time version of prompting a model. Calling the injected `prompt` module will generate your + unbatched prompt. This model then replicates it for the batched input and concatenates them together. + + Attributes: + prompt: The module that actually generates the unbatched prompt. + combine: A function that combines the prompt and the embedded input. + """ + + prompt: nn.Module + combination_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + def __init__( + self, + adapter_name: str, + prompt_tuning_config: PromptTuningConfig, + model_config: PretrainedConfig, + base_model_embeddings: nn.Module, + ): + super().__init__() + + self.name = adapter_name + self.model_config = model_config + self.prompt_tuning_config = prompt_tuning_config + self.base_model_embeddings = base_model_embeddings + + embedding_size = getattr(model_config, "embedding_size", model_config.hidden_size) + + self.prompt_embedding = nn.Embedding( + num_embeddings=prompt_tuning_config.prompt_length, embedding_dim=embedding_size + ) + # Initialize prompt tokens + self.prompt_tokens = torch.arange(prompt_tuning_config.prompt_length).long() + + self._init_prompt_embedding() + + if prompt_tuning_config.combine == "prefix": + self.combination_fn = lambda prompt, embedded_input: torch.cat([prompt, embedded_input], dim=1) + elif prompt_tuning_config.combine == "prefix_after_bos": + self.combination_fn = lambda prompt, embedded_input: torch.cat( + [embedded_input[:, 0, np.newaxis], prompt, embedded_input[:, 1:]], dim=1 + ) + else: + raise ValueError( + f"Unknown combination function: {prompt_tuning_config.combine}. " + "Must be one of 'prefix' or 'prefix_after_bos'." 
+ ) + + def _init_prompt_embedding(self) -> None: + if self.prompt_tuning_config.prompt_init == "random_uniform": + # Embedding was created using torch.nn.Embedding which already uses a random uniform distribution for initialization + pass + + elif self.prompt_tuning_config.prompt_init == "from_string": + tokenizer = AutoTokenizer.from_pretrained(self.model_config.tokenizer_name_or_path) + prompt_length = self.prompt_tuning_config.prompt_length + prompt_text = self.prompt_tuning_config.prompt_init_text + if prompt_text is None: + raise ValueError("Prompt text must be provided when using prompt_init='from_string'.") + + tokenized_prompt_text: list[int] = tokenizer(prompt_text)["input_ids"] # type: ignore + + # If the prompt text tokens are shorter than the prompt length, we repeat the prompt text tokens until we reach the prompt length + if len(tokenized_prompt_text) < prompt_length: + num_reps = math.ceil(prompt_length / len(tokenized_prompt_text)) + tokenized_prompt_text = tokenized_prompt_text * num_reps + + # Adjust length of prompt text tokens to match prompt_length + tokenized_prompt_text = tokenized_prompt_text[:prompt_length] + + # Initialize prompt embedding with tokenized prompt text + word_embedding_weights = ( + self.base_model_embeddings(torch.LongTensor(tokenized_prompt_text)).detach().clone() + ) + word_embedding_weights = word_embedding_weights.to(torch.float32) + self.prompt_embedding.weight = nn.Parameter(word_embedding_weights) + + else: + raise ValueError(f"Unknown prompt initialization: {self.prompt_tuning_config.prompt_init}") + + def forward(self, embedded_input): + # Compute prompt embedding + self.prompt_tokens = self.prompt_tokens.to(embedded_input.device) + self.prompt_embedding = self.prompt_embedding.to(embedded_input.device) + prompt = self.prompt_embedding(self.prompt_tokens) + + # Prompt to batch size + batch_size = embedded_input.shape[0] + prompt = torch.tile(torch.unsqueeze(prompt, dim=0), [batch_size] + [1 for _ in prompt.shape]) + + # Merge prompt and input + output = self.combination_fn(prompt, embedded_input) + + # Adapt attention mask + prefix_attention_mask = torch.ones(batch_size, self.prompt_tuning_config.prompt_length) + + return output, prefix_attention_mask + + +class PromptTuningLayer(AdapterLayerBase, nn.Module): + # TODO: add documentation + + def __init__( + self, + model_config: PretrainedConfig, + adapters_config: ModelAdaptersConfig, + base_model_embeddings: nn.Module, + ): + super().__init__() + self.model_config = model_config + self.adapters_config = adapters_config + self.base_model_embeddings = base_model_embeddings + self.prompt_tunings = nn.ModuleDict() + + def forward(self, hidden_states: torch.Tensor): + # TODO: Takes currently only very first prompt tuning adapter + if self.adapters_config.active_setup is not None and len(self.adapters_config.active_setup) > 0: + first_adapter = self.adapters_config.active_setup.first() + if first_adapter in self.prompt_tunings: + hidden_states = self.prompt_tunings[first_adapter](hidden_states) + + return hidden_states + + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + # ignore layer_idx as prompt tunings are only added after the embedding layer + prompt_tuning_config = self.adapters_config.match( + adapter_name, + config_type=PromptTuningConfig, + # layer_idx=self.layer_idx, + # location_key="prompt_tuning", + ) + + if prompt_tuning_config is not None: + adapter = PromptTuning( + adapter_name=adapter_name, + prompt_tuning_config=prompt_tuning_config, # type: ignore + 
model_config=self.model_config, + base_model_embeddings=self.base_model_embeddings, + ) + adapter.train(self.training) # make sure training mode is consistent + self.prompt_tunings[adapter_name] = adapter + return True + + return False + + def delete_adapter(self, adapter_name: str): + if adapter_name in self.prompt_tunings: + del self.prompt_tunings[adapter_name] + + def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): + pass + # TODO + # if unfreeze_adapters: + # for prefix_tuning_name in adapter_setup.flatten(): + # self.pool.enable_prefix(prefix_tuning_name) + # if prefix_tuning_name in self.prefix_gates: + # for param in self.prefix_gates[prefix_tuning_name].parameters(): + # param.requires_grad = unfreeze_adapters + + def freeze_adapter(self, adapter_name: str, freeze: bool = True): + pass + # TODO + # if adapter_name in self.prefixes: + # self.pool.get_prefix(adapter_name)[self.location_key].train(not freeze) + # for param in self.pool.get_prefix(adapter_name)[self.location_key].parameters(): + # param.requires_grad = not freeze + # if adapter_name in self.prefix_gates: + # for param in self.prefix_gates[adapter_name].parameters(): + # param.requires_grad = not freeze + + def get_adapter(self, adapter_name): + # TODO + # return_dict = nn.ModuleDict() + # # Make sure to only return params once + # if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0: + # prefix_module = self.pool.get_prefix(adapter_name) + # if prefix_module is not None: + # return_dict["prefix"] = prefix_module[self.location_key] + # if adapter_name in self.prefix_gates: + # return_dict["gate"] = self.prefix_gates[adapter_name] + # if len(return_dict) > 0: + # return return_dict + + return None + + def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + raise NotImplementedError() + + def add_fusion_layer(self, adapter_names: Union[List, str]): + raise NotImplementedError() + + def delete_fusion_layer(self, adapter_names: Union[List, str]): + raise NotImplementedError() diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index 2670488cb9..4865f32b14 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -5,7 +5,7 @@ import adapters from adapters import PrefixTuningConfig, SeqBnConfig from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition -from tests.test_modeling_common import ids_tensor +from tests_adapters.test_adapter import ids_tensor from transformers import BertConfig, BertForSequenceClassification from transformers.testing_utils import require_torch, torch_device diff --git a/tests_adapters/methods/__init__.py b/tests_adapters/methods/__init__.py index f40a688e58..b1cbe52de4 100644 --- a/tests_adapters/methods/__init__.py +++ b/tests_adapters/methods/__init__.py @@ -22,4 +22,5 @@ from .test_ia3 import IA3TestMixin from .test_lora import LoRATestMixin from .test_prefix_tuning import PrefixTuningTestMixin +from .test_prompt_tuning import PromptTuningTestMixin from .test_unipelt import UniPELTTestMixin diff --git a/tests_adapters/methods/test_prompt_tuning.py b/tests_adapters/methods/test_prompt_tuning.py new file mode 100644 index 0000000000..a9e441aef2 --- /dev/null +++ b/tests_adapters/methods/test_prompt_tuning.py @@ -0,0 +1,44 @@ +import torch + +from adapters import ADAPTER_MODEL_MAPPING, 
AutoAdapterModel, PrefixTuningConfig, PromptTuningConfig +from transformers.testing_utils import require_torch, torch_device + +from .base import AdapterMethodBaseTestMixin + + +@require_torch +class PromptTuningTestMixin(AdapterMethodBaseTestMixin): + def test_add_prompt_tuning(self): + model = self.get_model() + self.run_add_test( + model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."] + ) # TODO: provide parameters in PromptTuningConfig(...) ? + + # TODO: add tests to add different configs (like initialization [random_uniform, from_array, ...] or prefix_prompt vs prefix_prompt_after_bos + + # def test_average_prompt_tuning(self): + # model = self.get_model() + # self.run_average_test(model, PromptTuningConfig(), ["prompt_tunings.{name}."]) + + def test_delete_prompt_tuning(self): + model = self.get_model() + self.run_delete_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_get_prompt_tuning(self): + model = self.get_model() + self.run_get_test( + model, PromptTuningConfig(prompt_length=10), 1 + ) # TODO: last number is number of layers. Is this really 1? + + def test_forward_prompt_tuning(self): + model = self.get_model() + self.run_forward_test(model, PromptTuningConfig(prompt_length=10)) + + def test_load_prompt_tuning(self): + self.run_load_test(PromptTuningConfig(prompt_length=10)) + + def test_load_full_model_prefix_tuning(self): + self.run_full_model_load_test(PromptTuningConfig(prompt_length=10)) + + def test_train_prefix_tuning(self): + self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) diff --git a/tests_adapters/test_adapter.py b/tests_adapters/test_adapter.py index d5e8d2754c..84c45c67dc 100644 --- a/tests_adapters/test_adapter.py +++ b/tests_adapters/test_adapter.py @@ -9,10 +9,29 @@ from transformers.testing_utils import torch_device +global_rng = random.Random() + + def make_config(config_class, **kwargs): return staticmethod(lambda: config_class(**kwargs)) +def ids_tensor(shape, vocab_size, rng=None, name=None): + # Creates a random int32 tensor of the shape within the vocab size + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + + class AdapterTestBase: # If not overriden by subclass, AutoModel should be used. 
model_class = AutoAdapterModel diff --git a/tests_adapters/test_adapter_custom_head.py b/tests_adapters/test_adapter_custom_head.py index ea37b52325..b68662bfc6 100644 --- a/tests_adapters/test_adapter_custom_head.py +++ b/tests_adapters/test_adapter_custom_head.py @@ -5,10 +5,11 @@ from adapters import AutoAdapterModel from adapters.heads import ClassificationHead, PredictionHead -from tests.test_modeling_common import ids_tensor from transformers import AutoConfig from transformers.testing_utils import require_torch, torch_device +from .test_adapter import ids_tensor + class CustomHead(PredictionHead): def __init__( diff --git a/tests_adapters/test_adapter_hub.py b/tests_adapters/test_adapter_hub.py index cf07bd6e0a..24ac67c8b8 100644 --- a/tests_adapters/test_adapter_hub.py +++ b/tests_adapters/test_adapter_hub.py @@ -7,7 +7,6 @@ from adapters import ADAPTER_CONFIG_MAP, AdapterConfigBase, BertAdapterModel, get_adapter_config_hash from adapters.trainer import AdapterTrainer as Trainer from adapters.utils import find_in_index -from tests.test_modeling_common import ids_tensor from transformers import ( # get_adapter_config_hash, AutoModel, AutoTokenizer, @@ -19,6 +18,8 @@ ) from transformers.testing_utils import require_torch, torch_device +from .test_adapter import ids_tensor + SAMPLE_INDEX = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/hub-index.sample.json") From c94e59570e7b9eccc858a210e01e445c9f28e26e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Wed, 18 Oct 2023 13:29:09 +0200 Subject: [PATCH 02/13] bert prompt tuning running, rest WIP --- .gitignore | 4 +- src/adapters/configuration/adapter_config.py | 9 +- .../configuration/model_adapters_config.py | 2 + src/adapters/{ => methods}/prompt_tuning.py | 99 ++++++++----------- src/adapters/model_mixin.py | 70 ++++++------- src/adapters/models/bert/mixin_bert.py | 18 +--- src/adapters/models/bert/modeling_bert.py | 11 +++ tests_adapters/methods/test_adapter_common.py | 9 +- tests_adapters/methods/test_prompt_tuning.py | 6 +- tests_adapters/test_albert.py | 2 + tests_adapters/test_bart.py | 2 + tests_adapters/test_beit.py | 2 + tests_adapters/test_bert.py | 2 + tests_adapters/test_bert_generation.py | 2 + tests_adapters/test_clip.py | 2 + tests_adapters/test_deberta.py | 2 + tests_adapters/test_debertaV2.py | 2 + tests_adapters/test_distilbert.py | 2 + tests_adapters/test_electra.py | 2 + tests_adapters/test_encoder_decoder.py | 2 + tests_adapters/test_gpt2.py | 2 + tests_adapters/test_gptj.py | 2 + tests_adapters/test_llama.py | 2 + tests_adapters/test_mbart.py | 2 + tests_adapters/test_roberta.py | 2 + tests_adapters/test_t5.py | 2 + tests_adapters/test_vit.py | 2 + tests_adapters/test_xlm_roberta.py | 2 + tests_adapters/test_xmod.py | 2 + 29 files changed, 147 insertions(+), 121 deletions(-) rename src/adapters/{ => methods}/prompt_tuning.py (73%) diff --git a/.gitignore b/.gitignore index e0d065aeab..e3e8f2a311 100644 --- a/.gitignore +++ b/.gitignore @@ -73,8 +73,8 @@ instance/ .scrapy # Sphinx documentation -docs/_build/ -docs/_build/ +adapter_docs/_build/ +adapter_docs/_build/ # PyBuilder target/ diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 765295282a..0c089dd28e 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -399,26 +399,23 @@ class PrefixTuningConfig(AdapterConfigBase): @dataclass(eq=False) class PromptTuningConfig(AdapterConfigBase): - # TODO: add 
config + # TODO: documentation """ The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf - Args: - - """ + Args:""" prompt_length: int prompt_init_text: Optional[str] = None # only necessary when using prompt_init="from_string" architecture: Optional[str] = "prompt_tuning" - prompt_init: str = ( # random_uniform, from_string, from_array, TODO: add more from https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/prompts.py + prompt_init: str = ( # random_uniform, from_string, from_array, TODO: ? add more from https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/prompts.py "random_uniform" ) combine: str = "prefix" # prefix, prefix_after_bos, suffix # TODO: add a parameter for the random uniform scale - # TODO: add more params if necessary @dataclass(eq=False) diff --git a/src/adapters/configuration/model_adapters_config.py b/src/adapters/configuration/model_adapters_config.py index aeb89d493f..329c529765 100644 --- a/src/adapters/configuration/model_adapters_config.py +++ b/src/adapters/configuration/model_adapters_config.py @@ -32,6 +32,8 @@ def __init__(self, **kwargs): self.active_setup: Optional[AdapterCompositionBlock] = None self.skip_layers = None + self.prefix_attention_mask_length = None + def __contains__(self, item): return item in self.adapters.keys() diff --git a/src/adapters/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py similarity index 73% rename from src/adapters/prompt_tuning.py rename to src/adapters/methods/prompt_tuning.py index b50c127ae1..1e83b60746 100644 --- a/src/adapters/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -11,9 +11,9 @@ from transformers import AutoTokenizer from transformers.configuration_utils import PretrainedConfig -from .composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel -from .configuration import ModelAdaptersConfig, PromptTuningConfig -from .layer import AdapterLayerBase +from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel +from ..configuration import ModelAdaptersConfig, PromptTuningConfig +from .adapter_layer_base import AdapterLayerBase logger = logging.getLogger(__name__) @@ -48,7 +48,6 @@ def __init__( self.name = adapter_name self.model_config = model_config self.prompt_tuning_config = prompt_tuning_config - self.base_model_embeddings = base_model_embeddings embedding_size = getattr(model_config, "embedding_size", model_config.hidden_size) @@ -58,7 +57,7 @@ def __init__( # Initialize prompt tokens self.prompt_tokens = torch.arange(prompt_tuning_config.prompt_length).long() - self._init_prompt_embedding() + self._init_prompt_embedding(base_model_embeddings) if prompt_tuning_config.combine == "prefix": self.combination_fn = lambda prompt, embedded_input: torch.cat([prompt, embedded_input], dim=1) @@ -72,7 +71,7 @@ def __init__( "Must be one of 'prefix' or 'prefix_after_bos'." 
) - def _init_prompt_embedding(self) -> None: + def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None: if self.prompt_tuning_config.prompt_init == "random_uniform": # Embedding was created using torch.nn.Embedding which already uses a random uniform distribution for initialization pass @@ -95,9 +94,7 @@ def _init_prompt_embedding(self) -> None: tokenized_prompt_text = tokenized_prompt_text[:prompt_length] # Initialize prompt embedding with tokenized prompt text - word_embedding_weights = ( - self.base_model_embeddings(torch.LongTensor(tokenized_prompt_text)).detach().clone() - ) + word_embedding_weights = base_model_embeddings(torch.LongTensor(tokenized_prompt_text)).detach().clone() word_embedding_weights = word_embedding_weights.to(torch.float32) self.prompt_embedding.weight = nn.Parameter(word_embedding_weights) @@ -118,14 +115,16 @@ def forward(self, embedded_input): output = self.combination_fn(prompt, embedded_input) # Adapt attention mask - prefix_attention_mask = torch.ones(batch_size, self.prompt_tuning_config.prompt_length) + prefix_attention_mask_length = self.prompt_tuning_config.prompt_length - return output, prefix_attention_mask + return output, prefix_attention_mask_length class PromptTuningLayer(AdapterLayerBase, nn.Module): # TODO: add documentation + adapter_modules_name = "prompt_tunings" + def __init__( self, model_config: PretrainedConfig, @@ -138,22 +137,11 @@ def __init__( self.base_model_embeddings = base_model_embeddings self.prompt_tunings = nn.ModuleDict() - def forward(self, hidden_states: torch.Tensor): - # TODO: Takes currently only very first prompt tuning adapter - if self.adapters_config.active_setup is not None and len(self.adapters_config.active_setup) > 0: - first_adapter = self.adapters_config.active_setup.first() - if first_adapter in self.prompt_tunings: - hidden_states = self.prompt_tunings[first_adapter](hidden_states) - - return hidden_states - def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: # ignore layer_idx as prompt tunings are only added after the embedding layer prompt_tuning_config = self.adapters_config.match( adapter_name, config_type=PromptTuningConfig, - # layer_idx=self.layer_idx, - # location_key="prompt_tuning", ) if prompt_tuning_config is not None: @@ -169,51 +157,46 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: return False + def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: + raise NotImplementedError() + def delete_adapter(self, adapter_name: str): if adapter_name in self.prompt_tunings: del self.prompt_tunings[adapter_name] + def add_fusion_layer(self, adapter_names: Union[List, str]): + raise NotImplementedError() + + def delete_fusion_layer(self, adapter_names: Union[List, str]): + raise NotImplementedError() + def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): - pass - # TODO - # if unfreeze_adapters: - # for prefix_tuning_name in adapter_setup.flatten(): - # self.pool.enable_prefix(prefix_tuning_name) - # if prefix_tuning_name in self.prefix_gates: - # for param in self.prefix_gates[prefix_tuning_name].parameters(): - # param.requires_grad = unfreeze_adapters + if unfreeze_adapters: + for prompt_tuning_name in adapter_setup.flatten(): + if prompt_tuning_name in self.prompt_tunings: + for param in self.prompt_tunings[prompt_tuning_name].parameters(): + param.requires_grad = True def freeze_adapter(self, adapter_name: str, freeze: bool = True): - pass - # TODO - # 
if adapter_name in self.prefixes: - # self.pool.get_prefix(adapter_name)[self.location_key].train(not freeze) - # for param in self.pool.get_prefix(adapter_name)[self.location_key].parameters(): - # param.requires_grad = not freeze - # if adapter_name in self.prefix_gates: - # for param in self.prefix_gates[adapter_name].parameters(): - # param.requires_grad = not freeze + if adapter_name in self.prompt_tunings: + self.prompt_tunings[adapter_name].train(not freeze) + for param in self.prompt_tunings[adapter_name].parameters(): + param.requires_grad = not freeze def get_adapter(self, adapter_name): - # TODO - # return_dict = nn.ModuleDict() - # # Make sure to only return params once - # if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0: - # prefix_module = self.pool.get_prefix(adapter_name) - # if prefix_module is not None: - # return_dict["prefix"] = prefix_module[self.location_key] - # if adapter_name in self.prefix_gates: - # return_dict["gate"] = self.prefix_gates[adapter_name] - # if len(return_dict) > 0: - # return return_dict - - return None + if adapter_name in self.prompt_tunings: + return self.prompt_tunings[adapter_name] + else: + return None - def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: - raise NotImplementedError() + def forward(self, hidden_states: torch.Tensor): + prefix_attention_mask_length = None + adapter_setup = self.get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() + if first_adapter in self.prompt_tunings: + hidden_states, prefix_attention_mask_length = self.prompt_tunings[first_adapter](hidden_states) - def add_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() + self.adapters_config.prefix_attention_mask_length = prefix_attention_mask_length - def delete_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() + return hidden_states diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 7b0a22d2a6..efe3f62ea5 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -42,21 +42,6 @@ def init_adapters(self, model_config, adapters_config, **kwargs): if hasattr(super(), "init_adapters"): super().init_adapters(self.config, self.adapters_config, **kwargs) - # self.hook_after_embeddings(self._hook_fn) - - # def _hook_fn(self, module, args, output): - # new_output = self.invertible_adapters_forward(output) - # return new_output - - # TODO: entfernen - # def hook_after_embeddings(self, hook_fn: Callable): - # """ - # Hook a function to be called after the embeddings have been computed. The default implementation does nothing. # - Override this method to add a hook. - - # Args: # hook_fn (Callable): The function to be called after the embeddings have been computed. #""" - # pass - def add_invertible_adapter(self, adapter_name: str) -> bool: """ Adds an invertible adapter module for the adapter with the given name. 
If the given adapter does not specify an @@ -133,13 +118,31 @@ def enable_invertible_adapters(self, adapter_names): def invertible_adapters_forward(self, hidden_states, rev=False): # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position - if self.adapters_config.active_setup is not None and len(self.adapters_config.active_setup) > 0: - first_adapter = self.adapters_config.active_setup.first() + adapter_setup = self._get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() if first_adapter in self.invertible_adapters: hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev) - return hidden_states + def _get_active_setup(self): + if hasattr(self, "adapters_config"): + # First check current context before falling back to defined setup + context = AdapterSetup.get_context() + if context is not None: + adapter_setup = context.adapter_setup + else: + adapter_setup = self.adapters_config.active_setup + else: + adapter_setup = None + skip_adapters = adapter_setup is None or ( + self.adapters_config.skip_layers is not None and self.layer_idx in self.adapters_config.skip_layers + ) + if not skip_adapters and (len(adapter_setup.flatten()) > 0): + return adapter_setup + else: + return None + class InvertibleAdaptersWrapperMixin: """ @@ -382,7 +385,6 @@ class ModelAdaptersMixin(PushAdapterToHubMixin, PromptTuningMixin, ABC): # This prompt tuning layer is stoed in self.prompt_tuning # Since the correct position to call the prompt tuning forward depends on the model type, this has to be called in the post_embedding_forward method supports_prompt_tuning = False - prefix_attention_mask = None def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) @@ -391,12 +393,14 @@ def _link_prefix_to_pool(self, layer): if isinstance(layer, PrefixTuningLayer): layer.set_pool(self.base_model.prefix_tuning) - # TODO: provide documentation & maybe better fitting name & move to more fitting place in this class - # @abstractmethod def post_embedding_forward(self, module, args, embedding_output): - # def post_embedding_forward(self, *args, **kwargs): - """function to pass the embedding layer output through""" - raise NotImplementedError + if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): + embedding_output = self.invertible_adapters_forward(embedding_output) + + if self.supports_prompt_tuning: + embedding_output = self.prompt_tuning.forward(embedding_output) + + return embedding_output @property def model_name(self): @@ -681,6 +685,9 @@ def delete_adapter(self, adapter_name: str): del self.base_model.shared_parameters[adapter_name] if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): self.delete_invertible_adapter(adapter_name) + if self.supports_prompt_tuning: + self.prompt_tuning.delete_adapter(adapter_name) + # Reset active adapters if this was the only active adapter if self.active_adapters == Stack(adapter_name): self.active_adapters = None @@ -1001,6 +1008,11 @@ def get_adapter(self, name) -> dict: ) and name in self.invertible_adapters: destination[-1]["invertible"] = self.invertible_adapters[name] + if self.supports_prompt_tuning: + prompt_tuning = self.prompt_tuning.get_adapter(name) + if prompt_tuning is not None: + destination[-1]["prompt"] = prompt_tuning + # use a custom index to ensure numbering is from 0 to N layers for i, (_, layer) in 
enumerate(self.iter_layers()): for module in layer.modules(): @@ -1263,16 +1275,6 @@ def save_pretrained( class ModelBaseAdaptersMixin(ModelAdaptersMixin): @ForwardContext.wrap def forward(self, *args, **kwargs): - # TODO: adapt attention mask here - attention_mask = kwargs.get("attention_mask", None) - if attention_mask is not None and self.prefix_attention_mask is not None: - self.prefix_attention_mask = self.prefix_attention_mask.to(attention_mask.device) - attention_mask = torch.cat((self.prefix_attention_mask, attention_mask), dim=1) - kwargs.update( - { - "attention_mask": attention_mask, - } - ) return super().forward(*args, **kwargs) diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index b752e3a3f6..486162602c 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -17,6 +17,7 @@ class BertSelfAttentionAdaptersMixin: """Adds adapters to the BertSelfAttention module.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config # Wrap layers for LoRA self.query = LoRALinear.wrap(self.query, "selfattn", model_config, adapters_config, attn_key="q") self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") @@ -75,10 +76,6 @@ class BertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, Mo def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) - # TODO: move somewhere else? e.g. in ModelBaseAdaptersMixin as we can add this and just don't use it in post_embedding_forward if we don't want to support it - # wie machen wir das mit dem Hinzufügen der Adaoter? Wenn wir es nicht benutzen soll man auch keine hinzufügen können - # self.prompt_tuning = PromptTuningLayer(adapters_config) - # Set hook for parallel composition for _, layer in self.iter_layers(): self._set_layer_hook_for_parallel(layer) @@ -96,16 +93,3 @@ def hook(module, input): def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer - - def post_embedding_forward(self, module, args, embedding_output): - print(f"args: {module}") - - new_output = self.invertible_adapters_forward(embedding_output) - new_output, prefix_attention_mask = self.prompt_tuning.forward(embedding_output) - - # TODO: das funktioniert so nicht alleine. Wir müssen die attention mask in die BertSelfAttention bekommen. - # Das ist aber nicht so einfach, weil wir prefix_attention_mask nicht direkt BertSelfAttention mitgeben können, da wir die mitgegebenen Parameter nicht verändern können. 
- # Wir können das mittels BertSelfAttentionAdaptersMixin machen, indem: - self.prefix_attention_mask = prefix_attention_mask - - return new_output diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index 539dc74ebf..e334cf587c 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -40,6 +40,17 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + if attention_mask is not None and self.adapters_config.prefix_attention_mask_length is not None: + prefix_attention_mask = torch.ones( + attention_mask.shape[0], + attention_mask.shape[1], + attention_mask.shape[2], + self.adapters_config.prefix_attention_mask_length, + dtype=torch.float32, + ).to(attention_mask.device) + + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/tests_adapters/methods/test_adapter_common.py b/tests_adapters/methods/test_adapter_common.py index 616e6a99e8..5c543dadca 100644 --- a/tests_adapters/methods/test_adapter_common.py +++ b/tests_adapters/methods/test_adapter_common.py @@ -27,7 +27,6 @@ @require_torch class BottleneckAdapterTestMixin(AdapterMethodBaseTestMixin): - adapter_configs_to_test = [ (SeqBnConfig(), ["adapters.{name}."]), (MAMConfig(), ["adapters.{name}.", "prefix_tunings.{name}."]), @@ -211,6 +210,14 @@ def test_adapter_forward(self): with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): self.run_forward_test(model, adapter_config) + def test_invertible_adapter_forward(self): + model = self.get_model() + model.eval() + + for adapter_config, _ in self.inv_adapter_configs_to_test: + with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): + self.run_forward_test(model, adapter_config) + def test_load_adapter(self): self.run_load_test(SeqBnConfig()) diff --git a/tests_adapters/methods/test_prompt_tuning.py b/tests_adapters/methods/test_prompt_tuning.py index a9e441aef2..764e1980fb 100644 --- a/tests_adapters/methods/test_prompt_tuning.py +++ b/tests_adapters/methods/test_prompt_tuning.py @@ -16,9 +16,9 @@ def test_add_prompt_tuning(self): # TODO: add tests to add different configs (like initialization [random_uniform, from_array, ...] 
or prefix_prompt vs prefix_prompt_after_bos - # def test_average_prompt_tuning(self): - # model = self.get_model() - # self.run_average_test(model, PromptTuningConfig(), ["prompt_tunings.{name}."]) + def test_average_prompt_tuning(self): + model = self.get_model() + self.run_average_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) def test_delete_prompt_tuning(self): model = self.get_model() diff --git a/tests_adapters/test_albert.py b/tests_adapters/test_albert.py index 29f8a2b583..054dd31278 100644 --- a/tests_adapters/test_albert.py +++ b/tests_adapters/test_albert.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -42,6 +43,7 @@ class AlbertAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_bart.py b/tests_adapters/test_bart.py index e40c9df521..f9fcfbb83a 100644 --- a/tests_adapters/test_bart.py +++ b/tests_adapters/test_bart.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -43,6 +44,7 @@ class BartAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_beit.py b/tests_adapters/test_beit.py index b2014c6c00..1e83c9e529 100644 --- a/tests_adapters/test_beit.py +++ b/tests_adapters/test_beit.py @@ -9,6 +9,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import VisionAdapterTestBase, make_config @@ -38,6 +39,7 @@ class BeitAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_bert.py b/tests_adapters/test_bert.py index 702e68ab9d..b4e67bc811 100644 --- a/tests_adapters/test_bert.py +++ b/tests_adapters/test_bert.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -40,6 +41,7 @@ class BertAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_bert_generation.py b/tests_adapters/test_bert_generation.py index feb821ca0b..44cbd25f8e 100644 --- a/tests_adapters/test_bert_generation.py +++ b/tests_adapters/test_bert_generation.py @@ -12,6 +12,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -84,6 +85,7 @@ class BertGenerationAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_clip.py b/tests_adapters/test_clip.py index 2ed57268e4..eea58262f8 100644 --- a/tests_adapters/test_clip.py +++ b/tests_adapters/test_clip.py @@ -20,6 +20,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, VisionAdapterTestBase, make_config @@ -78,6 +79,7 @@ class 
CLIPVisionWithProjectionAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_deberta.py b/tests_adapters/test_deberta.py index 96be88d26a..7fd80322ec 100644 --- a/tests_adapters/test_deberta.py +++ b/tests_adapters/test_deberta.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -47,6 +48,7 @@ class DebertaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, ParallelTrainingMixin, diff --git a/tests_adapters/test_debertaV2.py b/tests_adapters/test_debertaV2.py index bc436d996d..b2d564c2e4 100644 --- a/tests_adapters/test_debertaV2.py +++ b/tests_adapters/test_debertaV2.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -47,6 +48,7 @@ class DebertaV2AdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, ParallelTrainingMixin, diff --git a/tests_adapters/test_distilbert.py b/tests_adapters/test_distilbert.py index d401fe220b..2634b390a5 100644 --- a/tests_adapters/test_distilbert.py +++ b/tests_adapters/test_distilbert.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -40,6 +41,7 @@ class DistilBertAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_electra.py b/tests_adapters/test_electra.py index 5566e7be0d..a5c005f509 100644 --- a/tests_adapters/test_electra.py +++ b/tests_adapters/test_electra.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -41,6 +42,7 @@ class ElectraAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_encoder_decoder.py b/tests_adapters/test_encoder_decoder.py index 708a6bfbb2..1a13d8110f 100644 --- a/tests_adapters/test_encoder_decoder.py +++ b/tests_adapters/test_encoder_decoder.py @@ -12,6 +12,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase @@ -50,6 +51,7 @@ class EncoderDecoderAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, EncoderDecoderAdapterTestBase, diff --git a/tests_adapters/test_gpt2.py b/tests_adapters/test_gpt2.py index 620435e532..3ca6783920 100644 --- a/tests_adapters/test_gpt2.py +++ b/tests_adapters/test_gpt2.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -41,6 +42,7 @@ class GPT2AdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_gptj.py b/tests_adapters/test_gptj.py index 
2c2de0dc0c..5d3bce3dd9 100644 --- a/tests_adapters/test_gptj.py +++ b/tests_adapters/test_gptj.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -44,6 +45,7 @@ class GPTJAdapterTest( LoRATestMixin, UniPELTTestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, EmbeddingTestMixin, CompabilityTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_llama.py b/tests_adapters/test_llama.py index 2fd455c174..9b3ab488cd 100644 --- a/tests_adapters/test_llama.py +++ b/tests_adapters/test_llama.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -42,6 +43,7 @@ class LlamaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_mbart.py b/tests_adapters/test_mbart.py index 775e1fdebb..5f726aa9be 100644 --- a/tests_adapters/test_mbart.py +++ b/tests_adapters/test_mbart.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -41,6 +42,7 @@ class MBartAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin, diff --git a/tests_adapters/test_roberta.py b/tests_adapters/test_roberta.py index 5ccbb53eee..6d105ceaec 100644 --- a/tests_adapters/test_roberta.py +++ b/tests_adapters/test_roberta.py @@ -11,6 +11,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -40,6 +41,7 @@ class RobertaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_t5.py b/tests_adapters/test_t5.py index c8717d8b54..b2981b6110 100644 --- a/tests_adapters/test_t5.py +++ b/tests_adapters/test_t5.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -44,6 +45,7 @@ class T5AdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_vit.py b/tests_adapters/test_vit.py index d84d8523ca..2de1b34300 100644 --- a/tests_adapters/test_vit.py +++ b/tests_adapters/test_vit.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import VisionAdapterTestBase, make_config @@ -39,6 +40,7 @@ class ViTAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_xlm_roberta.py b/tests_adapters/test_xlm_roberta.py index f46d9543ac..96268302f7 100644 --- a/tests_adapters/test_xlm_roberta.py +++ b/tests_adapters/test_xlm_roberta.py @@ -9,6 +9,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -36,6 +37,7 @@ class XLMRobertaAdapterTest( 
IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, XLMRobertaAdapterTestBase, diff --git a/tests_adapters/test_xmod.py b/tests_adapters/test_xmod.py index 450c84231d..a8cd02c4d9 100644 --- a/tests_adapters/test_xmod.py +++ b/tests_adapters/test_xmod.py @@ -10,6 +10,7 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -41,6 +42,7 @@ class XmodAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, + PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, From d97e423cdfc7db0b374b3e4323b4156c0c193fbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Wed, 25 Oct 2023 01:26:14 +0200 Subject: [PATCH 03/13] WIP. Still missing: CLIP, GPT2, GPTJ, EncoderDecoder, T5 --- src/adapters/composition.py | 22 ++++++ src/adapters/loading.py | 1 + src/adapters/methods/adapter_layer_base.py | 2 +- src/adapters/methods/prompt_tuning.py | 28 +++---- src/adapters/model_mixin.py | 74 +++++++++---------- src/adapters/models/albert/mixin_albert.py | 9 ++- src/adapters/models/albert/modeling_albert.py | 4 +- src/adapters/models/bart/mixin_bart.py | 12 ++- src/adapters/models/bart/modeling_bart.py | 3 +- src/adapters/models/beit/mixin_beit.py | 2 + src/adapters/models/bert/mixin_bert.py | 2 - src/adapters/models/bert/modeling_bert.py | 13 +--- .../modeling_bert_generation.py | 4 +- src/adapters/models/deberta/mixin_deberta.py | 2 + .../models/deberta/modeling_deberta.py | 5 +- .../models/deberta_v2/mixin_deberta_v2.py | 1 + .../models/deberta_v2/modeling_deberta_v2.py | 7 +- .../models/distilbert/mixin_distilbert.py | 16 ++-- .../models/distilbert/modeling_distilbert.py | 4 +- .../models/electra/modeling_electra.py | 4 +- src/adapters/models/llama/mixin_llama.py | 12 ++- src/adapters/models/llama/modeling_llama.py | 6 +- src/adapters/models/mbart/modeling_mbart.py | 3 +- .../models/roberta/modeling_roberta.py | 4 +- src/adapters/models/vit/mixin_vit.py | 3 + .../xlm_roberta/modeling_xlm_roberta.py | 4 +- src/adapters/models/xmod/mixin_xmod.py | 8 +- src/adapters/models/xmod/modeling_xmod.py | 4 +- 28 files changed, 154 insertions(+), 105 deletions(-) diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 5899b113d6..4a451503ea 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -2,6 +2,8 @@ from collections.abc import Sequence from typing import List, Optional, Set, Union +import torch + class AdapterCompositionBlock(Sequence): def __init__(self, *children): @@ -242,3 +244,23 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors): repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) tensor.set_(new_tensor) + + +def prefix_attention_mask(adapters_config, attention_mask=None, dim=3): + """ + Prepends a given attention mask with a tensor of ones of the length specified in the adapter configuration + `prefix_attention_mask_length`. `prefix_attention_mask_length` is set e.g. by prompt tuning. 
+ """ + if attention_mask is not None and adapters_config.prefix_attention_mask_length is not None: + # Create a tensor of ones with the desired shape + ones_shape = list(attention_mask.shape) + ones_shape[dim] = adapters_config.prefix_attention_mask_length + prefix_attention_mask = torch.ones( + ones_shape, + dtype=attention_mask.dtype, + ).to(attention_mask.device) + + # Concatenate the prefix_attention_mask along the specified dimension + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim) + + return attention_mask diff --git a/src/adapters/loading.py b/src/adapters/loading.py index 57d03ea62f..ddd15c1724 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -307,6 +307,7 @@ def filter_func(self, adapter_name): or ".prefix_tunings.{}.".format(adapter_name) in x or ".prefix_gates.{}.".format(adapter_name) in x or ".loras.{}.".format(adapter_name) in x + or ".prompt_tunings.{}.".format(adapter_name) in x ) # This dict maps the original weight names to the currently used equivalents. diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py index b89b75cb14..2590257324 100644 --- a/src/adapters/methods/adapter_layer_base.py +++ b/src/adapters/methods/adapter_layer_base.py @@ -343,7 +343,7 @@ def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl: # sequentially feed different parts of the blown-up batch into different adapters children_states = [] for i, child in enumerate(adapter_setup): - # compute ids of sequences thet should be passed to the ith adapter + # compute ids of sequences that should be passed to the ith adapter batch_idx = ( sum(adapter_setup.batch_sizes[:i]), sum(adapter_setup.batch_sizes[: i + 1]), diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py index 1e83b60746..fcf35e5c95 100644 --- a/src/adapters/methods/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -1,8 +1,7 @@ # https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/train/prompts.py -import logging import math -from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union +from typing import Callable, Dict, List, Union import numpy as np import torch @@ -11,17 +10,11 @@ from transformers import AutoTokenizer from transformers.configuration_utils import PretrainedConfig -from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack, adjust_tensors_for_parallel +from ..composition import AdapterCompositionBlock from ..configuration import ModelAdaptersConfig, PromptTuningConfig from .adapter_layer_base import AdapterLayerBase -logger = logging.getLogger(__name__) - - -Initializer = Callable[[torch.Tensor, Sequence[int]], torch.Tensor] - - class PromptTuning(nn.Module): """Generate a Prompt and concatenate it with the input. @@ -121,7 +114,16 @@ def forward(self, embedded_input): class PromptTuningLayer(AdapterLayerBase, nn.Module): - # TODO: add documentation + """ + Prompt Tuning implementation. + + Args: + model_config: The model configuration. + adapters_config: The adapter configuration. + base_model_embeddings: + The embedding layer of the base model (used to initialize the prompt embedding if + prompt_init='from_string'). 
+ """ adapter_modules_name = "prompt_tunings" @@ -158,17 +160,17 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: return False def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: - raise NotImplementedError() + pass # TODO: implement def delete_adapter(self, adapter_name: str): if adapter_name in self.prompt_tunings: del self.prompt_tunings[adapter_name] def add_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() + pass # not applicable to prompt tuning def delete_fusion_layer(self, adapter_names: Union[List, str]): - raise NotImplementedError() + pass # not applicable to prompt tuning def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool): if unfreeze_adapters: diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index efe3f62ea5..1010120a51 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -365,26 +365,10 @@ def loaded_embeddings(self): return self.base_model.loaded_embeddings -class PromptTuningMixin: - def init_prompt_tuning( - self, model_config: PretrainedConfig, adapters_config: ModelAdaptersConfig, base_model_embeddings: nn.Module - ): - self.prompt_tuning = PromptTuningLayer(model_config, adapters_config, base_model_embeddings) - - def add_prompt_tuning(self, adapter_name: str) -> bool: - return self.prompt_tuning.add_adapter(adapter_name=adapter_name, layer_idx=-1) - - # TODO: delete, etc .... prompt tuning - # TODO: can probably be merged into whereever this gets called - - -class ModelAdaptersMixin(PushAdapterToHubMixin, PromptTuningMixin, ABC): +class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): """Mixin for transformer models adding support for loading/ saving adapters.""" - # Setting this to True will automatically add a prompt tuning layer to the model - # This prompt tuning layer is stoed in self.prompt_tuning - # Since the correct position to call the prompt tuning forward depends on the model type, this has to be called in the post_embedding_forward method - supports_prompt_tuning = False + add_base_adapters = False def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) @@ -393,15 +377,6 @@ def _link_prefix_to_pool(self, layer): if isinstance(layer, PrefixTuningLayer): layer.set_pool(self.base_model.prefix_tuning) - def post_embedding_forward(self, module, args, embedding_output): - if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): - embedding_output = self.invertible_adapters_forward(embedding_output) - - if self.supports_prompt_tuning: - embedding_output = self.prompt_tuning.forward(embedding_output) - - return embedding_output - @property def model_name(self): return self.config.name_or_path @@ -432,8 +407,8 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr self.apply_to_adapter_layers(lambda i, layer: self._link_prefix_to_pool(layer)) # Add Prompt Tuning - if self.supports_prompt_tuning: - self.init_prompt_tuning(self.config, self.adapters_config, self.get_input_embeddings()) + if self.add_base_adapters: + self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings()) # Initialize adapters from config for adapter_name in self.adapters_config: @@ -465,12 +440,23 @@ def apply_to_adapter_layers(self, fn): if isinstance(module, AdapterLayerBase): fn(i, module) + def apply_to_basemodel_childs(self, fn): + """ + Applies a 
function to all direct children of the model if they are an instance of AdapterLayerBase. + """ + if self.add_base_adapters: + for module in self.base_model.children(): + if isinstance(module, AdapterLayerBase): + # These children don't have a layer index, so we pass -1 + fn(-1, module) + def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): """Sets the model into mode for training the given adapters.""" self.train() self.freeze_model(True) adapter_setup = parse_composition(adapter_setup) self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, True, False)) + self.apply_to_basemodel_childs(lambda i, child: child.enable_adapters(adapter_setup, True, False)) for adapter_name in adapter_setup: if adapter_name in self.base_model.shared_parameters: for param in self.base_model.shared_parameters[adapter_name].values(): @@ -497,6 +483,7 @@ def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBloc self.freeze_model(True) adapter_setup = parse_composition(adapter_setup) self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, unfreeze_adapters, True)) + self.apply_to_basemodel_childs(lambda i, child: child.enable_adapters(adapter_setup, unfreeze_adapters, True)) # use the adapters to be trained by default in every forward pass self.set_active_adapters(adapter_setup) # TODO implement fusion for invertible adapters @@ -580,6 +567,8 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False def _add_adapter_weights(self, adapter_name: str): """Helper method that performs the actual parameter additions when adding a new adapter.""" self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i)) + self.apply_to_basemodel_childs(lambda i, child: child.add_adapter(adapter_name, i)) + # PHM Layer if self.adapters_config.match(adapter_name, BnConfig, location_key="phm_layer"): adapter_module = list(self.get_adapter(adapter_name)[0].values())[0] @@ -612,10 +601,6 @@ def _add_adapter_weights(self, adapter_name: str): if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): self.add_invertible_adapter(adapter_name) - # Prompt Tuning - if self.supports_prompt_tuning: - self.add_prompt_tuning(adapter_name) - def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=None, override_kwargs=None): warnings.warn( "add_fusion() has been deprecated in favor of add_adapter_fusion().
Please use the newer method instead.", @@ -663,6 +648,7 @@ def add_adapter_fusion( self.delete_adapter_fusion(adapter_names) self.adapters_config.add_fusion(adapter_names, config=config) self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names)) + self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names)) if set_active: if not isinstance(adapter_names, list): adapter_names = adapter_names.split(",") @@ -680,13 +666,12 @@ def delete_adapter(self, adapter_name: str): return del self.adapters_config.adapters[adapter_name] self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name)) + self.apply_to_basemodel_childs(lambda i, child: child.delete_adapter(adapter_name)) # PHM Layer if adapter_name in self.base_model.shared_parameters: del self.base_model.shared_parameters[adapter_name] if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): self.delete_invertible_adapter(adapter_name) - if self.supports_prompt_tuning: - self.prompt_tuning.delete_adapter(adapter_name) # Reset active adapters if this was the only active adapter if self.active_adapters == Stack(adapter_name): @@ -713,6 +698,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]): return del self.adapters_config.fusions[adapter_fusion_name] self.apply_to_adapter_layers(lambda i, layer: layer.delete_fusion_layer(adapter_fusion_name)) + self.apply_to_basemodel_childs(lambda i, child: child.delete_fusion_layer(adapter_fusion_name)) # Reset active adapters if this was the active setup if self.active_adapters == adapter_names: self.active_adapters = None @@ -1008,10 +994,9 @@ def get_adapter(self, name) -> dict: ) and name in self.invertible_adapters: destination[-1]["invertible"] = self.invertible_adapters[name] - if self.supports_prompt_tuning: - prompt_tuning = self.prompt_tuning.get_adapter(name) - if prompt_tuning is not None: - destination[-1]["prompt"] = prompt_tuning + prompt_tuning = self.prompt_tuning.get_adapter(name) + if prompt_tuning is not None: + destination[-1]["prompt"] = prompt_tuning # use a custom index to ensure numbering is from 0 to N layers for i, (_, layer) in enumerate(self.iter_layers()): @@ -1162,6 +1147,7 @@ def average_adapter( input_adapters = {name: weight / sum_weights for name, weight in zip(adapter_list, weights)} try: self.apply_to_adapter_layers(lambda i, layer: layer.average_adapter(adapter_name, input_adapters)) + self.apply_to_basemodel_childs(lambda i, child: child.average_adapter(adapter_name, input_adapters)) # PHM Layer if self.adapters_config.match(adapter_name, BnConfig, location_key="phm_layer"): self._average_shared_parameters(adapter_name, input_adapters) @@ -1273,6 +1259,16 @@ def save_pretrained( @inherit_doc class ModelBaseAdaptersMixin(ModelAdaptersMixin): + add_base_adapters = True + + def post_embedding_forward(self, module, args, embedding_output): + if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): + embedding_output = self.invertible_adapters_forward(embedding_output) + + embedding_output = self.prompt_tuning.forward(embedding_output) + + return embedding_output + @ForwardContext.wrap def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py index 21534980af..f21d4170b6 100644 --- a/src/adapters/models/albert/mixin_albert.py +++ b/src/adapters/models/albert/mixin_albert.py @@ -1,4 
+1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -13,6 +13,8 @@ class AlbertAttentionAdaptersMixin: """Adds adapters to the AlbertAttention module of ALBERT.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config + # Wrap layers for LoRA self.query = LoRALinear.wrap(self.query, "selfattn", model_config, adapters_config, attn_key="q") self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") @@ -51,6 +53,8 @@ def init_adapters(self, model_config, adapters_config): for _, layer in self.iter_layers(): self._set_layer_hook_for_parallel(layer) + self.embeddings.register_forward_hook(self.post_embedding_forward) + def _set_layer_hook_for_parallel(self, layer: nn.Module): def hook(module, input): adjust_tensors_for_parallel_(input[0], input[1]) @@ -64,6 +68,3 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for albertLayer in albertLayerGroup.albert_layers: yield i, albertLayer i += 1 - - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) diff --git a/src/adapters/models/albert/modeling_albert.py b/src/adapters/models/albert/modeling_albert.py index df3e7523f0..6ecd02e257 100644 --- a/src/adapters/models/albert/modeling_albert.py +++ b/src/adapters/models/albert/modeling_albert.py @@ -23,7 +23,7 @@ from transformers.models.albert.modeling_albert import AlbertAttention, AlbertLayer from transformers.pytorch_utils import apply_chunking_to_forward -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from .mixin_albert import AlbertAttentionAdaptersMixin, AlbertEncoderLayerAdaptersMixin @@ -35,6 +35,8 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index 5ef20aaa86..bd094549a1 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -20,6 +20,8 @@ class BartAttentionAdaptersMixin: """Adds adapters to the BartAttention module.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config + # Wrap layers for LoRA self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") @@ -34,6 +36,7 @@ class BartEncoderLayerAdaptersMixin: """Adds adapters to the BartEncoderLayer module of BART.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config # Wrap layers for LoRA self.fc1 = LoRALinear.wrap(self.fc1, "intermediate", model_config, adapters_config) self.fc2 = LoRALinear.wrap(self.fc2, "output", model_config, adapters_config) @@ -58,8 +61,7 @@ def init_adapters(self, model_config, adapters_config): class BartEncoderAdaptersMixin(InvertibleAdaptersMixin): """Adds adapters to the 
BartEncoder module of BART.""" - def hook_after_embeddings(self, hook_fn: Callable): - return self.layernorm_embedding.register_forward_hook(hook_fn) + pass class BartDecoderAdaptersMixin: @@ -77,6 +79,10 @@ class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMi invertible_adapters_base_name = "encoder" + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + self.encoder.layernorm_embedding.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: if hasattr(self, "encoder"): for i, layer in enumerate(self.encoder.layers): diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index cb15b385bd..cc50b76c28 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -21,7 +21,7 @@ from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -176,6 +176,7 @@ def forward( returned tensors for more detail. """ adjust_tensors_for_parallel_(hidden_states, attention_mask) + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index 2c129f085c..f4d82672a4 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -11,6 +11,7 @@ class BeitSelfAttentionAdaptersMixin: def init_adapters(self, model_config, adapters_config): self.location_key = "self" + self.adapters_config = adapters_config # Wrap layers for LoRA self.query = LoRALinear.wrap(self.query, "selfattn", model_config, adapters_config, attn_key="q") @@ -47,6 +48,7 @@ class BeitModelAdaptersMixin(ModelBaseAdaptersMixin): def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) + self.embeddings.register_forward_hook(self.post_embedding_forward) def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index 486162602c..657598f1cc 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -71,8 +71,6 @@ def init_adapters(self, model_config, adapters_config): class BertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): """Adds adapters to the BertModel module.""" - supports_prompt_tuning = True - def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index e334cf587c..0f459e3dbf 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -25,7 +25,7 @@ from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, 
prefix_attention_mask from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -40,16 +40,7 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - if attention_mask is not None and self.adapters_config.prefix_attention_mask_length is not None: - prefix_attention_mask = torch.ones( - attention_mask.shape[0], - attention_mask.shape[1], - attention_mask.shape[2], - self.adapters_config.prefix_attention_mask_length, - dtype=torch.float32, - ).to(attention_mask.device) - - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py index 8f083fe295..ce85a2c5af 100644 --- a/src/adapters/models/bert_generation/modeling_bert_generation.py +++ b/src/adapters/models/bert_generation/modeling_bert_generation.py @@ -27,7 +27,7 @@ BertGenerationSelfOutput, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -52,6 +52,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/deberta/mixin_deberta.py b/src/adapters/models/deberta/mixin_deberta.py index cee8530f02..04d02242dd 100644 --- a/src/adapters/models/deberta/mixin_deberta.py +++ b/src/adapters/models/deberta/mixin_deberta.py @@ -6,6 +6,8 @@ class DebertaSelfAttentionAdaptersMixin: """Adds adapters to the BertSelfAttention module.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config + # Wrap layers for LoRA self.in_proj = LoRAMergedLinear.wrap(self.in_proj, "selfattn", model_config, adapters_config) diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 8197c19fb6..49a493ead1 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -24,7 +24,7 @@ XSoftmax, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta import DebertaSelfAttentionAdaptersMixin @@ -94,6 +94,9 @@ def forward( """ + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=2) # type: ignore + if query_states is None: qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) diff --git a/src/adapters/models/deberta_v2/mixin_deberta_v2.py b/src/adapters/models/deberta_v2/mixin_deberta_v2.py index f60e8788fb..f400f46795 100644 --- 
a/src/adapters/models/deberta_v2/mixin_deberta_v2.py +++ b/src/adapters/models/deberta_v2/mixin_deberta_v2.py @@ -6,6 +6,7 @@ class DebertaV2SelfAttentionAdaptersMixin: """Adds adapters to the BertSelfAttention module.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config # Wrap layers for LoRA self.query_proj = LoRALinear.wrap(self.query_proj, "selfattn", model_config, adapters_config, attn_key="q") self.key_proj = LoRALinear.wrap(self.key_proj, "selfattn", model_config, adapters_config, attn_key="k") diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 082e77a721..3bfa909bdf 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -24,7 +24,7 @@ XSoftmax, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta_v2 import DebertaV2SelfAttentionAdaptersMixin @@ -88,9 +88,10 @@ def forward( rel_embeddings (`torch.FloatTensor`): The embedding of relative distances. It's a tensor of shape [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*]. - - """ + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=2) # type: ignore + if query_states is None: query_states = hidden_states query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads) diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py index 44bcbb0b16..d7a0f549bc 100644 --- a/src/adapters/models/distilbert/mixin_distilbert.py +++ b/src/adapters/models/distilbert/mixin_distilbert.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -24,6 +24,7 @@ class DistilBertTransfomerBlockAdaptersMixin: """Adds adapters to the TransformerBlock module of DistilBert.""" def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config # Wrap layers for LoRA self.ffn.lin1 = LoRALinear.wrap(self.ffn.lin1, "intermediate", model_config, adapters_config) self.ffn.lin2 = LoRALinear.wrap(self.ffn.lin2, "output", model_config, adapters_config) @@ -44,15 +45,10 @@ def forward(self, *args, **kwargs): class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): """Adds adapters to the DistilBert module.""" + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + self.embeddings.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.transformer.layer): yield i, layer - - def _hook_fn(self, module, input): - new_input = self.invertible_adapters_forward(input) - return new_input - - def hook_after_embeddings(self, hook_fn: Callable): - # PyTorch's built-in pre-forward hook does not pass the input ids. - # Therefore, we need to use a custom hook. 
- self.transformer.pre_forward_fn = hook_fn diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py index 6dfb62eb1c..cb688c46be 100644 --- a/src/adapters/models/distilbert/modeling_distilbert.py +++ b/src/adapters/models/distilbert/modeling_distilbert.py @@ -27,7 +27,7 @@ from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin @@ -118,6 +118,8 @@ def forward( torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. """ adjust_tensors_for_parallel_(x, attn_mask) + attn_mask = prefix_attention_mask(self.adapters_config, attn_mask, dim=1) # type: ignore + # Self-Attention sa_output = self.attention( query=x, diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py index 35552782ce..2570c5a3e0 100644 --- a/src/adapters/models/electra/modeling_electra.py +++ b/src/adapters/models/electra/modeling_electra.py @@ -6,7 +6,7 @@ from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -21,6 +21,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index 22223edaf4..d7a37665f4 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -10,6 +10,7 @@ class LlamaAttentionMixin: def init_adapters(self, model_config, adapters_config): + self.adapters_config = adapters_config self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") @@ -28,9 +29,12 @@ def init_adapters(self, model_config, adapters_config): class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.embed_tokens.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.layers): yield i, layer - - def hook_after_embeddings(self, hook_fn: Callable): - return self.embed_tokens.register_forward_hook(hook_fn) diff --git a/src/adapters/models/llama/modeling_llama.py 
b/src/adapters/models/llama/modeling_llama.py index f16b65e9c6..0cc6ec2953 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -25,7 +25,7 @@ import torch.utils.checkpoint from torch import nn -from adapters.composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from adapters.composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging @@ -47,6 +47,10 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=2) # type: ignore + position_ids = prefix_attention_mask(self.adapters_config, position_ids, dim=1) # type: ignore + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) diff --git a/src/adapters/models/mbart/modeling_mbart.py b/src/adapters/models/mbart/modeling_mbart.py index 5c43212a28..1cfbcadcb6 100644 --- a/src/adapters/models/mbart/modeling_mbart.py +++ b/src/adapters/models/mbart/modeling_mbart.py @@ -21,7 +21,7 @@ from transformers.models.mbart.modeling_mbart import MBartAttention, MBartDecoderLayer, MBartEncoderLayer -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask from ..bart.mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -176,6 +176,7 @@ def forward( returned tensors for more detail. 
""" adjust_tensors_for_parallel_(hidden_states, attention_mask) + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py index 47a8ed35a9..867a1df8cf 100644 --- a/src/adapters/models/roberta/modeling_roberta.py +++ b/src/adapters/models/roberta/modeling_roberta.py @@ -24,7 +24,7 @@ from transformers.models.roberta.modeling_roberta import RobertaOutput, RobertaSelfAttention, RobertaSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -40,6 +40,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py index 07598ad8ae..945fbd4d2e 100644 --- a/src/adapters/models/vit/mixin_vit.py +++ b/src/adapters/models/vit/mixin_vit.py @@ -52,6 +52,9 @@ class ViTModelAdaptersMixin(ModelBaseAdaptersMixin): def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index a8d22284b7..2cd75394d5 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -28,7 +28,7 @@ XLMRobertaSelfOutput, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -44,6 +44,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys diff --git a/src/adapters/models/xmod/mixin_xmod.py b/src/adapters/models/xmod/mixin_xmod.py index eac7e4b418..bef4371f34 100644 --- a/src/adapters/models/xmod/mixin_xmod.py +++ b/src/adapters/models/xmod/mixin_xmod.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn @@ -25,6 +25,9 @@ def init_adapters(self, model_config, adapters_config): for _, layer in self.iter_layers(): del layer.output.adapter_modules + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + def _set_layer_hook_for_parallel(self, layer: nn.Module): def hook(module, input): # hook[1] is lang_ids tensor @@ -37,9 
+40,6 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.encoder.layer): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) - def forward(self, *args, **kwargs): if "lang_ids" in kwargs and kwargs["lang_ids"] is not None: raise ValueError( diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py index b772321667..c4f43efce7 100644 --- a/src/adapters/models/xmod/modeling_xmod.py +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -23,7 +23,7 @@ from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -39,6 +39,8 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys From 87d2fe3a80f5c90f071e8cb858de70651617bd2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sun, 29 Oct 2023 19:03:56 +0100 Subject: [PATCH 04/13] Add documentation, fix prefix attention mask value - move prefix_attention_mask_length in ForwardContext --- .gitignore | 2 +- docs/classes/adapter_config.rst | 7 +++ docs/methods.md | 28 +++++++++++ docs/model_overview.md | 44 ++++++++--------- src/adapters/composition.py | 22 --------- .../configuration/model_adapters_config.py | 2 - src/adapters/context.py | 7 ++- src/adapters/methods/prompt_tuning.py | 6 ++- src/adapters/model_mixin.py | 1 + src/adapters/models/albert/mixin_albert.py | 2 - src/adapters/models/albert/modeling_albert.py | 5 +- src/adapters/models/bart/mixin_bart.py | 2 - src/adapters/models/bart/modeling_bart.py | 5 +- src/adapters/models/beit/mixin_beit.py | 1 - src/adapters/models/bert/mixin_bert.py | 1 - src/adapters/models/bert/modeling_bert.py | 5 +- .../modeling_bert_generation.py | 5 +- src/adapters/models/deberta/mixin_deberta.py | 2 - .../models/deberta/modeling_deberta.py | 7 +-- .../models/deberta_v2/mixin_deberta_v2.py | 1 - .../models/deberta_v2/modeling_deberta_v2.py | 7 +-- .../models/distilbert/mixin_distilbert.py | 1 - .../models/distilbert/modeling_distilbert.py | 5 +- .../models/electra/modeling_electra.py | 5 +- src/adapters/models/llama/mixin_llama.py | 1 - src/adapters/models/llama/modeling_llama.py | 6 +-- src/adapters/models/mbart/modeling_mbart.py | 5 +- .../models/roberta/modeling_roberta.py | 5 +- .../xlm_roberta/modeling_xlm_roberta.py | 5 +- src/adapters/models/xmod/modeling_xmod.py | 5 +- src/adapters/utils.py | 48 ++++++++++++++++++- 31 files changed, 155 insertions(+), 93 deletions(-) diff --git a/.gitignore b/.gitignore index e3e8f2a311..5e4361fa20 100644 --- a/.gitignore +++ b/.gitignore @@ -73,7 +73,7 @@ instance/ .scrapy # Sphinx documentation -adapter_docs/_build/ +docs/_build/ adapter_docs/_build/ # PyBuilder diff --git a/docs/classes/adapter_config.rst b/docs/classes/adapter_config.rst index 44fff5a0f2..d817aad7c5 100644 --- a/docs/classes/adapter_config.rst +++ b/docs/classes/adapter_config.rst @@ -55,6 +55,13 @@ IA3Config :members: :inherited-members: Mapping 
+PromptTuningConfig +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.PromptTuningConfig + :members: + :inherited-members: Mapping + Combined configurations ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/methods.md b/docs/methods.md index 04cebbaaac..06ad700e68 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -267,3 +267,31 @@ model.reset_adapter() _Papers:_ - [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022) + +## Prompt Tuning +Prompt Tuning is an efficient fine-tuning technique proposed by Lester et al. (2021). Prompt tuning adds tunable tokens, called soft-prompts, that are prepended to the input text. +First, the input sequence $\{x_1, x_2, \dots, x_n\}$ gets embedded, resulting in the matrix $X_e \in \mathbb{R}^{n \times e}$ where $e$ is the dimension of +the embedding space. The soft-prompt of length $p$ is represented as $P_e \in \mathbb{R}^{p \times e}$. +$P_e$ and $X_e$ get concatenated, forming the input of the following encoder or decoder: + +$$ +\left[P_e; X_e\right] \in \mathbb{R}^{\left(p + n\right) \times e} +$$ + +The `PromptTuningConfig` has the properties: +- `prompt_length`: to set the soft-prompt length $p$ +- `prompt_init`: to set the weight initialisation method, either "random_uniform" or "from_string"; the latter initializes each prompt token with an embedding drawn from the model’s vocabulary. + - `prompt_init_text`: the text used for initialisation if `prompt_init="from_string"` +- `combine`: to define whether the soft-prompt should be added before the embedded input sequence or after the BOS token + +To add Prompt Tuning to your model, you can use the predefined configs: +```python +from adapters import PromptTuningConfig + +config = PromptTuningConfig(prompt_length=10) +model.add_adapter("dummy", config=config) +``` + +_Papers:_ +- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021) + diff --git a/docs/model_overview.md b/docs/model_overview.md index 8198ea64d0..f2e7712e08 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -10,28 +10,28 @@ The table below further shows which model architectures support which adaptation E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters. ``` -| Model | (Bottleneck)
Adapters | Prefix
Tuning | LoRA | Compacter | Adapter
Fusion | Invertible
Adapters | Parallel
block | -| --------------------------------------- | -| - | - | - | - | - | - | -| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | -| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | -| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Model | (Bottleneck)
Adapters | Prefix
Tuning | LoRA | Compacter | Adapter
Fusion | Invertible
Adapters | Parallel
block | Prompt
Tuning | +| --------------------------------------- | -| - | - | - | - | - | - |- | +| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | +| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | +| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | +| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | (*) If the used encoder and decoder model class are supported. diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 4a451503ea..5899b113d6 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -2,8 +2,6 @@ from collections.abc import Sequence from typing import List, Optional, Set, Union -import torch - class AdapterCompositionBlock(Sequence): def __init__(self, *children): @@ -244,23 +242,3 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors): repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) tensor.set_(new_tensor) - - -def prefix_attention_mask(adapters_config, attention_mask=None, dim=3): - """ - Prepends a given attention mask with a tensor of ones of the length specified in the adapter configuration - `prefix_attention_mask_length`. `prefix_attention_mask_length` is set e.g. by prompt tuning. 
- """ - if attention_mask is not None and adapters_config.prefix_attention_mask_length is not None: - # Create a tensor of ones with the desired shape - ones_shape = list(attention_mask.shape) - ones_shape[dim] = adapters_config.prefix_attention_mask_length - prefix_attention_mask = torch.ones( - ones_shape, - dtype=attention_mask.dtype, - ).to(attention_mask.device) - - # Concatenate the prefix_attention_mask along the specified dimension - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim) - - return attention_mask diff --git a/src/adapters/configuration/model_adapters_config.py b/src/adapters/configuration/model_adapters_config.py index 329c529765..aeb89d493f 100644 --- a/src/adapters/configuration/model_adapters_config.py +++ b/src/adapters/configuration/model_adapters_config.py @@ -32,8 +32,6 @@ def __init__(self, **kwargs): self.active_setup: Optional[AdapterCompositionBlock] = None self.skip_layers = None - self.prefix_attention_mask_length = None - def __contains__(self, item): return item in self.adapters.keys() diff --git a/src/adapters/context.py b/src/adapters/context.py index 784ed579ea..6b04d3b74f 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -78,7 +78,12 @@ class ForwardContext: # thread-local storage that holds a stack of active contexts storage = threading.local() - context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"] + context_attributes = [ + "adapter_gating_scores", + "adapter_fusion_attentions", + "adapter_input_parallelized", + "prefix_attention_mask_length", + ] def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py index fcf35e5c95..0a2a8af392 100644 --- a/src/adapters/methods/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -12,6 +12,7 @@ from ..composition import AdapterCompositionBlock from ..configuration import ModelAdaptersConfig, PromptTuningConfig +from ..context import ForwardContext from .adapter_layer_base import AdapterLayerBase @@ -97,7 +98,6 @@ def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None: def forward(self, embedded_input): # Compute prompt embedding self.prompt_tokens = self.prompt_tokens.to(embedded_input.device) - self.prompt_embedding = self.prompt_embedding.to(embedded_input.device) prompt = self.prompt_embedding(self.prompt_tokens) # Prompt to batch size @@ -199,6 +199,8 @@ def forward(self, hidden_states: torch.Tensor): if first_adapter in self.prompt_tunings: hidden_states, prefix_attention_mask_length = self.prompt_tunings[first_adapter](hidden_states) - self.adapters_config.prefix_attention_mask_length = prefix_attention_mask_length + context = ForwardContext.get_context() + if context is not None: + context.prefix_attention_mask_length = prefix_attention_mask_length return hidden_states diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 1010120a51..ae9678ddbc 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -955,6 +955,7 @@ def forward_context(self, context: ForwardContext, *args, **kwargs): context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False) context.adapter_gating_scores = defaultdict(dict) context.adapter_fusion_attentions = defaultdict(dict) + context.prefix_attention_mask_length = kwargs.get("output_prefix_attention_mask_length", None) 
def get_fusion_regularization_loss(self): reg_loss = None diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py index f21d4170b6..985f92ec4a 100644 --- a/src/adapters/models/albert/mixin_albert.py +++ b/src/adapters/models/albert/mixin_albert.py @@ -13,8 +13,6 @@ class AlbertAttentionAdaptersMixin: """Adds adapters to the AlbertAttention module of ALBERT.""" def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config - # Wrap layers for LoRA self.query = LoRALinear.wrap(self.query, "selfattn", model_config, adapters_config, attn_key="q") self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") diff --git a/src/adapters/models/albert/modeling_albert.py b/src/adapters/models/albert/modeling_albert.py index 6ecd02e257..a9688ca7a3 100644 --- a/src/adapters/models/albert/modeling_albert.py +++ b/src/adapters/models/albert/modeling_albert.py @@ -23,7 +23,8 @@ from transformers.models.albert.modeling_albert import AlbertAttention, AlbertLayer from transformers.pytorch_utils import apply_chunking_to_forward -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from .mixin_albert import AlbertAttentionAdaptersMixin, AlbertEncoderLayerAdaptersMixin @@ -35,7 +36,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index bd094549a1..25016d2cae 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -20,8 +20,6 @@ class BartAttentionAdaptersMixin: """Adds adapters to the BartAttention module.""" def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config - # Wrap layers for LoRA self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index cc50b76c28..10ba279849 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -21,7 +21,8 @@ from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...utils import prefix_attention_mask from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -176,7 +177,7 @@ def forward( returned tensors for more detail. 
""" adjust_tensors_for_parallel_(hidden_states, attention_mask) - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, prefix_value=1) # type: ignore residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index f4d82672a4..e44fea0b82 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -11,7 +11,6 @@ class BeitSelfAttentionAdaptersMixin: def init_adapters(self, model_config, adapters_config): self.location_key = "self" - self.adapters_config = adapters_config # Wrap layers for LoRA self.query = LoRALinear.wrap(self.query, "selfattn", model_config, adapters_config, attn_key="q") diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index 657598f1cc..2ba92e2db8 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -17,7 +17,6 @@ class BertSelfAttentionAdaptersMixin: """Adds adapters to the BertSelfAttention module.""" def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config # Wrap layers for LoRA self.query = LoRALinear.wrap(self.query, "selfattn", model_config, adapters_config, attn_key="q") self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k") diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index 0f459e3dbf..1410e8fdd4 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -25,7 +25,8 @@ from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -40,7 +41,7 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py index ce85a2c5af..0d51385031 100644 --- a/src/adapters/models/bert_generation/modeling_bert_generation.py +++ b/src/adapters/models/bert_generation/modeling_bert_generation.py @@ -27,7 +27,8 @@ BertGenerationSelfOutput, ) -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -52,7 +53,7 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git 
a/src/adapters/models/deberta/mixin_deberta.py b/src/adapters/models/deberta/mixin_deberta.py index 04d02242dd..cee8530f02 100644 --- a/src/adapters/models/deberta/mixin_deberta.py +++ b/src/adapters/models/deberta/mixin_deberta.py @@ -6,8 +6,6 @@ class DebertaSelfAttentionAdaptersMixin: """Adds adapters to the BertSelfAttention module.""" def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config - # Wrap layers for LoRA self.in_proj = LoRAMergedLinear.wrap(self.in_proj, "selfattn", model_config, adapters_config) diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 49a493ead1..9d0d9ac760 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -24,7 +24,8 @@ XSoftmax, ) -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta import DebertaSelfAttentionAdaptersMixin @@ -94,8 +95,8 @@ def forward( """ - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=2) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore if query_states is None: qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) diff --git a/src/adapters/models/deberta_v2/mixin_deberta_v2.py b/src/adapters/models/deberta_v2/mixin_deberta_v2.py index f400f46795..f60e8788fb 100644 --- a/src/adapters/models/deberta_v2/mixin_deberta_v2.py +++ b/src/adapters/models/deberta_v2/mixin_deberta_v2.py @@ -6,7 +6,6 @@ class DebertaV2SelfAttentionAdaptersMixin: """Adds adapters to the BertSelfAttention module.""" def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config # Wrap layers for LoRA self.query_proj = LoRALinear.wrap(self.query_proj, "selfattn", model_config, adapters_config, attn_key="q") self.key_proj = LoRALinear.wrap(self.key_proj, "selfattn", model_config, adapters_config, attn_key="k") diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 3bfa909bdf..c180f38c98 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -24,7 +24,8 @@ XSoftmax, ) -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta_v2 import DebertaV2SelfAttentionAdaptersMixin @@ -89,8 +90,8 @@ def forward( The embedding of relative distances. It's a tensor of shape [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*]. 
""" - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=2) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore if query_states is None: query_states = hidden_states diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py index d7a0f549bc..4694b0ca63 100644 --- a/src/adapters/models/distilbert/mixin_distilbert.py +++ b/src/adapters/models/distilbert/mixin_distilbert.py @@ -24,7 +24,6 @@ class DistilBertTransfomerBlockAdaptersMixin: """Adds adapters to the TransformerBlock module of DistilBert.""" def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config # Wrap layers for LoRA self.ffn.lin1 = LoRALinear.wrap(self.ffn.lin1, "intermediate", model_config, adapters_config) self.ffn.lin2 = LoRALinear.wrap(self.ffn.lin2, "output", model_config, adapters_config) diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py index cb688c46be..495e344275 100644 --- a/src/adapters/models/distilbert/modeling_distilbert.py +++ b/src/adapters/models/distilbert/modeling_distilbert.py @@ -27,7 +27,8 @@ from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...utils import prefix_attention_mask from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin @@ -118,7 +119,7 @@ def forward( torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. 
""" adjust_tensors_for_parallel_(x, attn_mask) - attn_mask = prefix_attention_mask(self.adapters_config, attn_mask, dim=1) # type: ignore + attn_mask = prefix_attention_mask(attn_mask, dim=1, prefix_value=1) # type: ignore # Self-Attention sa_output = self.attention( diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py index 2570c5a3e0..ea8a07ac84 100644 --- a/src/adapters/models/electra/modeling_electra.py +++ b/src/adapters/models/electra/modeling_electra.py @@ -6,7 +6,8 @@ from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -21,7 +22,7 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index d7a37665f4..d1ad3ddce3 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -10,7 +10,6 @@ class LlamaAttentionMixin: def init_adapters(self, model_config, adapters_config): - self.adapters_config = adapters_config self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q") self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k") self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v") diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 0cc6ec2953..f16b65e9c6 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -25,7 +25,7 @@ import torch.utils.checkpoint from torch import nn -from adapters.composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask +from adapters.composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging @@ -47,10 +47,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=2) # type: ignore - position_ids = prefix_attention_mask(self.adapters_config, position_ids, dim=1) # type: ignore - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) diff --git a/src/adapters/models/mbart/modeling_mbart.py b/src/adapters/models/mbart/modeling_mbart.py index 1cfbcadcb6..91535fdae1 100644 --- a/src/adapters/models/mbart/modeling_mbart.py +++ b/src/adapters/models/mbart/modeling_mbart.py @@ -21,7 +21,8 @@ from transformers.models.mbart.modeling_mbart import 
MBartAttention, MBartDecoderLayer, MBartEncoderLayer -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...utils import prefix_attention_mask from ..bart.mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -176,7 +177,7 @@ def forward( returned tensors for more detail. """ adjust_tensors_for_parallel_(hidden_states, attention_mask) - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask, dim=3) # type: ignore + attention_mask = prefix_attention_mask(attention_mask, prefix_value=1) # type: ignore residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py index 867a1df8cf..8946ac2743 100644 --- a/src/adapters/models/roberta/modeling_roberta.py +++ b/src/adapters/models/roberta/modeling_roberta.py @@ -24,7 +24,8 @@ from transformers.models.roberta.modeling_roberta import RobertaOutput, RobertaSelfAttention, RobertaSelfOutput -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -40,7 +41,7 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index 2cd75394d5..90f0a2ebfc 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -28,7 +28,8 @@ XLMRobertaSelfOutput, ) -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -44,7 +45,7 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py index c4f43efce7..91be6cc44b 100644 --- a/src/adapters/models/xmod/modeling_xmod.py +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -23,7 +23,8 @@ from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput -from ...composition import adjust_tensors_for_parallel, prefix_attention_mask +from ...composition import adjust_tensors_for_parallel +from ...utils import prefix_attention_mask from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -39,7 +40,7 @@ def forward( 
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(self.adapters_config, attention_mask) # type: ignore + attention_mask = prefix_attention_mask(attention_mask) # type: ignore mixed_query_layer = self.query(hidden_states) diff --git a/src/adapters/utils.py b/src/adapters/utils.py index cca9ee0e6b..9f1cbfe3a8 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -21,6 +21,8 @@ from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile +import torch + import requests from filelock import FileLock from huggingface_hub import HfApi, HfFolder, snapshot_download @@ -36,6 +38,7 @@ from transformers.utils.hub import torch_cache_home from . import __version__ +from .context import ForwardContext logger = logging.getLogger(__name__) @@ -287,7 +290,6 @@ def get_from_cache( # Prevent parallel downloads of the same file with a lock. lock_path = cache_path + ".lock" with FileLock(lock_path): - # If the download just completed while the lock was activated. if os.path.exists(cache_path) and not force_download: # Even if returning early like here, the lock will be released. @@ -819,3 +821,47 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf return None else: raise ValueError("Please specify either 'ah' or 'hf' as source.") + + +def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): + """ + Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length` + attribute in the ForwardContext. + + Args: + attention_mask: + The attention mask to add the prefix to. + dim (int): + The dimension along which to concatenate the prefix_attention_mask. Defaults to 3. + prefix_value (int): + The value to use for the prefix_attention_mask. Defaults to 0, however some models, e.g. DistilBert, use + different values. BERT like models invert their extended_attention_mask, hence they use 0 as value for not + masked tokens. This inversion is usually done in the forward method of the model in 2 different ways: + 1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. 
ALBERT + does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min` + """ + + forward_context = ForwardContext.get_context() + + print(f"In prefix_attention_mask: {attention_mask}") + + if ( + attention_mask is not None + and forward_context is not None + and forward_context.prefix_attention_mask_length is not None + ): + # Create a tensor of ones with the desired shape + ones_shape = list(attention_mask.shape) + ones_shape[dim] = forward_context.prefix_attention_mask_length + + prefix_attention_mask = torch.full( + ones_shape, + prefix_value, + dtype=attention_mask.dtype, + ).to(attention_mask.device) + + # Concatenate the prefix_attention_mask along the specified dimension + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim) + print(f"attention_mask after concat: {attention_mask}") + + return attention_mask From 6613b343a1997f0afb8be15a433e0e7e995d2c33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sun, 29 Oct 2023 19:26:58 +0100 Subject: [PATCH 05/13] remove left over prints --- src/adapters/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 9f1cbfe3a8..7c52f90ad2 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -843,8 +843,6 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): forward_context = ForwardContext.get_context() - print(f"In prefix_attention_mask: {attention_mask}") - if ( attention_mask is not None and forward_context is not None @@ -862,6 +860,5 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): # Concatenate the prefix_attention_mask along the specified dimension attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim) - print(f"attention_mask after concat: {attention_mask}") return attention_mask From 5ec546c740540e996c10fa56ca5ade5e0951f71a Mon Sep 17 00:00:00 2001 From: calpt Date: Thu, 2 Nov 2023 21:51:21 +0100 Subject: [PATCH 06/13] import fix --- src/adapters/model_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index ae9678ddbc..dccb9e95ce 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -23,7 +23,7 @@ from .methods.lora import LoRALayer from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool -from .prompt_tuning import PromptTuningLayer +from .methods.prompt_tuning import PromptTuningLayer from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config From 291c265b4fc2a490c1bfaff20e2f13aa895604de Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 6 Nov 2023 14:35:04 +0100 Subject: [PATCH 07/13] Fix prompt tuning training for LM & tagging flex heads --- src/adapters/context.py | 14 ++++++++-- src/adapters/heads/base.py | 28 ++++++++++++++++++- src/adapters/heads/language_modeling.py | 10 +++++++ src/adapters/methods/prompt_tuning.py | 2 +- src/adapters/model_mixin.py | 6 ++-- src/adapters/models/albert/adapter_model.py | 5 +++- src/adapters/models/bart/adapter_model.py | 5 +++- src/adapters/models/beit/adapter_model.py | 5 +++- src/adapters/models/bert/adapter_model.py | 5 +++- .../models/bert_generation/adapter_model.py | 5 +++- src/adapters/models/clip/adapter_model.py | 5 +++- 
src/adapters/models/deberta/adapter_model.py | 5 +++- .../models/deberta_v2/adapter_model.py | 5 +++- .../models/distilbert/adapter_model.py | 5 +++- src/adapters/models/electra/adapter_model.py | 5 +++- src/adapters/models/gpt2/adapter_model.py | 5 +++- src/adapters/models/gptj/adapter_model.py | 5 +++- src/adapters/models/llama/adapter_model.py | 5 +++- src/adapters/models/mbart/adapter_model.py | 5 +++- src/adapters/models/roberta/adapter_model.py | 5 +++- src/adapters/models/t5/adapter_model.py | 5 +++- src/adapters/models/vit/adapter_model.py | 5 +++- .../models/xlm_roberta/adapter_model.py | 5 +++- src/adapters/models/xmod/adapter_model.py | 5 +++- src/adapters/utils.py | 4 +-- tests_adapters/methods/test_prompt_tuning.py | 6 ++-- 26 files changed, 132 insertions(+), 33 deletions(-) diff --git a/src/adapters/context.py b/src/adapters/context.py index 6b04d3b74f..70e685d037 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -82,8 +82,9 @@ class ForwardContext: "adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized", - "prefix_attention_mask_length", ] + # Additional used attributes not exposed to the user + # - prompt_tokens_length: length of the prompt tokens def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. @@ -107,6 +108,8 @@ def wrap(cls, f): def wrapper_func(self, *args, **kwargs): if self.adapters_config is not None: with cls(self, *args, **kwargs) as ctx: + # whether to output the context attributes + output_context = kwargs.pop("output_context", False) kwargs = { k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes } @@ -121,7 +124,14 @@ def wrapper_func(self, *args, **kwargs): for attr in cls.context_attributes: if getattr(ctx, "output_" + attr, False): results[attr] = dict(getattr(ctx, attr)) - return results + + if output_context: + context_dict = ctx.__dict__ + + if output_context: + return results, context_dict + else: + return results else: return f(self, *args, **kwargs) diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index d9c7386fec..d907aba600 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -325,6 +325,19 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal labels = kwargs.pop("labels", None) if labels is not None: loss_fct = CrossEntropyLoss() + # adjust labels for prompt tuning + if kwargs.get("prompt_tokens_length", 0) > 0: + prompt_length = kwargs.get("prompt_tokens_length") + prompt_labels = torch.full( + (labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device + ) + labels = torch.cat((prompt_labels, labels), dim=-1) + if attention_mask is not None: + attention_mask = torch.cat( + (torch.ones_like(prompt_labels, dtype=torch.long, device=labels.device), attention_mask), + dim=-1, + ) + # Only keep active parts of the loss if attention_mask is not None: active_loss = attention_mask.view(-1) == 1 @@ -752,7 +765,14 @@ def _get_used_heads(self, head_name: str = None): return head_modules def forward_head( - self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs + self, + all_outputs, + head_name=None, + cls_output=None, + attention_mask=None, + return_dict=False, + context=None, + **kwargs ): """ The forward pass through a prediction head configuration. 
There are three ways to specify the used prediction @@ -800,6 +820,12 @@ def _get_head_input(outputs, cls_out, batch): if inv_adapter: kwargs["invertible_adapter"] = inv_adapter + # Set prompt tokens length + if context is not None: + prompt_tokens_length = context.get("prompt_tokens_length", None) + if prompt_tokens_length is not None: + kwargs["prompt_tokens_length"] = prompt_tokens_length + if isinstance(self.active_head, BatchSplit): if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]: raise ValueError( diff --git a/src/adapters/heads/language_modeling.py b/src/adapters/heads/language_modeling.py index 3e0cda610a..bf91e5be08 100644 --- a/src/adapters/heads/language_modeling.py +++ b/src/adapters/heads/language_modeling.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput @@ -118,6 +119,15 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal labels = labels[..., 1:].contiguous() else: logits_for_loss = lm_logits + + # adjust labels for prompt tuning + if kwargs.get("prompt_tokens_length", 0) > 0: + prompt_length = kwargs.get("prompt_tokens_length") + prompt_labels = torch.full( + (labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device + ) + labels = torch.cat((prompt_labels, labels), dim=-1) + loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1)) if return_dict: diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py index 0a2a8af392..5fd008e133 100644 --- a/src/adapters/methods/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -201,6 +201,6 @@ def forward(self, hidden_states: torch.Tensor): context = ForwardContext.get_context() if context is not None: - context.prefix_attention_mask_length = prefix_attention_mask_length + context.prompt_tokens_length = prefix_attention_mask_length return hidden_states diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index dccb9e95ce..dda4f6dcec 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -5,16 +5,15 @@ from abc import ABC, abstractmethod from collections import defaultdict from os.path import join -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn -from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import ModelOutput from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition -from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig, ModelAdaptersConfig +from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader @@ -955,7 +954,6 @@ def forward_context(self, context: ForwardContext, *args, **kwargs): context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False) context.adapter_gating_scores = defaultdict(dict) context.adapter_fusion_attentions = defaultdict(dict) - context.prefix_attention_mask_length = kwargs.get("output_prefix_attention_mask_length", None) def get_fusion_regularization_loss(self): reg_loss = None diff --git 
a/src/adapters/models/albert/adapter_model.py b/src/adapters/models/albert/adapter_model.py index 9a1f45ed2e..8261e68760 100644 --- a/src/adapters/models/albert/adapter_model.py +++ b/src/adapters/models/albert/adapter_model.py @@ -64,7 +64,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.albert( + outputs, context = self.albert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -77,7 +77,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index 3fc3dfd73c..f6fc12a587 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -76,7 +76,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs = self.model( + outputs, context = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -95,7 +95,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # sequence classification based on last token in sequence x = outputs[0] # last hidden state if input_ids is not None and x.shape[1] == input_ids.shape[1]: diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py index 22d3e6aa4d..ceeda7b82c 100644 --- a/src/adapters/models/beit/adapter_model.py +++ b/src/adapters/models/beit/adapter_model.py @@ -47,7 +47,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.beit( + outputs, context = self.beit( pixel_values, bool_masked_pos=bool_masked_pos, head_mask=head_mask, @@ -57,7 +57,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/bert/adapter_model.py b/src/adapters/models/bert/adapter_model.py index 4ff1aaf61a..02ad9411c4 100644 --- a/src/adapters/models/bert/adapter_model.py +++ b/src/adapters/models/bert/adapter_model.py @@ -66,7 +66,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.bert( + outputs, context = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -79,7 +79,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/bert_generation/adapter_model.py b/src/adapters/models/bert_generation/adapter_model.py index c251af1517..1fe0152a6a 100644 --- a/src/adapters/models/bert_generation/adapter_model.py +++ b/src/adapters/models/bert_generation/adapter_model.py @@ -62,7 +62,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.bert( + outputs, context = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -78,7 +78,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/clip/adapter_model.py b/src/adapters/models/clip/adapter_model.py index 6191cd3001..39382757e5 100644 --- a/src/adapters/models/clip/adapter_model.py +++ b/src/adapters/models/clip/adapter_model.py @@ -44,7 +44,7 @@ def forward( output_adapter_fusion_attentions=False, **kwargs ): - outputs = self.clip( + outputs, context = self.clip( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, @@ -56,7 +56,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context if head or AdapterSetup.get_context_head_setup() or self.active_head: head_outputs = self.forward_head( diff --git a/src/adapters/models/deberta/adapter_model.py b/src/adapters/models/deberta/adapter_model.py index 0d74e58593..4b44991d66 100644 --- a/src/adapters/models/deberta/adapter_model.py +++ b/src/adapters/models/deberta/adapter_model.py @@ -57,7 +57,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.deberta( + outputs, context = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -69,7 +69,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/deberta_v2/adapter_model.py b/src/adapters/models/deberta_v2/adapter_model.py index bc2f6e6ed2..a980d99177 100644 --- a/src/adapters/models/deberta_v2/adapter_model.py +++ b/src/adapters/models/deberta_v2/adapter_model.py @@ -60,7 +60,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.deberta( + outputs, context = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -72,7 +72,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/distilbert/adapter_model.py b/src/adapters/models/distilbert/adapter_model.py index 9a2294ac89..ee7fb57bf9 100644 --- a/src/adapters/models/distilbert/adapter_model.py +++ b/src/adapters/models/distilbert/adapter_model.py @@ -85,7 +85,7 @@ def forward( else None ) - distilbert_output = self.distilbert( + distilbert_output, context = self.distilbert( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, @@ -96,7 +96,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context outputs = self.forward_head( distilbert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py index 2d7994d3a4..6dbd02569f 100644 --- a/src/adapters/models/electra/adapter_model.py +++ b/src/adapters/models/electra/adapter_model.py @@ -66,7 +66,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.electra( + outputs, context = self.electra( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -79,7 +79,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context head_inputs = outputs diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py index b4e1b53d54..15501a9315 100644 --- a/src/adapters/models/gpt2/adapter_model.py +++ b/src/adapters/models/gpt2/adapter_model.py @@ -68,7 +68,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs, context = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -85,7 +85,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context batch_size = outputs[0].shape[0] diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py index 625cd0febc..46b5844171 100644 --- a/src/adapters/models/gptj/adapter_model.py +++ b/src/adapters/models/gptj/adapter_model.py @@ -66,7 +66,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs, context = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -81,7 +81,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context batch_size = outputs[0].shape[0] diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 43cec8abbf..1ff84e9c66 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -68,7 +68,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs, context = self.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -81,7 +81,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context batch_size = outputs[0].shape[0] diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py index ae86c35c17..58bfc7bdb9 100644 --- a/src/adapters/models/mbart/adapter_model.py +++ b/src/adapters/models/mbart/adapter_model.py @@ -76,7 +76,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs = self.model( + outputs, context = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -95,7 +95,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # sequence classification based on last token in sequence x = outputs[0] # last hidden state if input_ids is not None and x.shape[1] == input_ids.shape[1]: diff --git a/src/adapters/models/roberta/adapter_model.py b/src/adapters/models/roberta/adapter_model.py index 13bf8b8102..3a08f33639 100644 --- a/src/adapters/models/roberta/adapter_model.py +++ b/src/adapters/models/roberta/adapter_model.py @@ -66,7 +66,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.roberta( + outputs, context = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -79,7 +79,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 66441727c7..942b86fd3c 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -82,7 +82,7 @@ def forward( # decoder_input_ids from input_ids if no decoder_input_ids are provided decoder_input_ids = self._shift_right(input_ids) - model_output = self.transformer( + model_output, context = self.transformer( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -101,7 +101,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context sequence_output = model_output[0] # ToDo move head to device for parallel forward pass diff --git a/src/adapters/models/vit/adapter_model.py b/src/adapters/models/vit/adapter_model.py index 33eaaf2ea0..254a5ab0d7 100644 --- a/src/adapters/models/vit/adapter_model.py +++ b/src/adapters/models/vit/adapter_model.py @@ -47,7 +47,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.vit( + outputs, context = self.vit( pixel_values, head_mask=head_mask, output_attentions=output_attentions, @@ -57,7 +57,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/xlm_roberta/adapter_model.py b/src/adapters/models/xlm_roberta/adapter_model.py index ab1ca81f79..33963d5f1e 100644 --- a/src/adapters/models/xlm_roberta/adapter_model.py +++ b/src/adapters/models/xlm_roberta/adapter_model.py @@ -68,7 +68,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.roberta( + outputs, context = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -81,7 +81,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py index 31ca7acd3b..d61578f158 100644 --- a/src/adapters/models/xmod/adapter_model.py +++ b/src/adapters/models/xmod/adapter_model.py @@ -73,7 +73,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.roberta( + outputs, context = self.roberta( input_ids, lang_ids=lang_ids, attention_mask=attention_mask, @@ -87,7 +87,10 @@ def forward( output_adapter_gating_scores=output_adapter_gating_scores, output_adapter_fusion_attentions=output_adapter_fusion_attentions, adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, ) + # required e.g. 
for prompt tuning in all models + kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 7c52f90ad2..0e3b20cabe 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -846,11 +846,11 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): if ( attention_mask is not None and forward_context is not None - and forward_context.prefix_attention_mask_length is not None + and getattr(forward_context, "prompt_tokens_length", None) is not None ): # Create a tensor of ones with the desired shape ones_shape = list(attention_mask.shape) - ones_shape[dim] = forward_context.prefix_attention_mask_length + ones_shape[dim] = forward_context.prompt_tokens_length prefix_attention_mask = torch.full( ones_shape, diff --git a/tests_adapters/methods/test_prompt_tuning.py b/tests_adapters/methods/test_prompt_tuning.py index 764e1980fb..ab7b950661 100644 --- a/tests_adapters/methods/test_prompt_tuning.py +++ b/tests_adapters/methods/test_prompt_tuning.py @@ -1,7 +1,5 @@ -import torch - -from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, PrefixTuningConfig, PromptTuningConfig -from transformers.testing_utils import require_torch, torch_device +from adapters import PromptTuningConfig +from transformers.testing_utils import require_torch from .base import AdapterMethodBaseTestMixin From c574590689b72ef4582e34bdf9aa4abaec5f8fe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Tue, 14 Nov 2023 13:46:31 +0100 Subject: [PATCH 08/13] Fix models that don't support prompt tuning --- src/adapters/model_mixin.py | 11 +++++++---- src/adapters/models/clip/mixin_clip.py | 13 +++++++++++-- .../encoder_decoder/mixin_encoder_decoder.py | 2 ++ src/adapters/models/gpt2/mixin_gpt2.py | 14 ++++++++++++-- src/adapters/models/gptj/mixin_gptj.py | 11 +++++++++-- src/adapters/models/llama/mixin_llama.py | 7 +++++++ src/adapters/models/t5/mixin_t5.py | 6 ++++++ src/adapters/models/t5/modeling_t5.py | 2 +- tests_adapters/test_clip.py | 2 -- tests_adapters/test_encoder_decoder.py | 2 -- tests_adapters/test_gpt2.py | 2 -- tests_adapters/test_gptj.py | 2 -- tests_adapters/test_llama.py | 2 -- tests_adapters/test_t5.py | 2 -- 14 files changed, 55 insertions(+), 23 deletions(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index dda4f6dcec..1791bf7a29 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -368,6 +368,7 @@ class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): """Mixin for transformer models adding support for loading/ saving adapters.""" add_base_adapters = False + support_prompt_tuning = True # If False, the prompt tuning layer is not added to the model. If True, the prompt tuning layer is added if add_base_adapters is True. 
def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) @@ -407,7 +408,8 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr # Add Prompt Tuning if self.add_base_adapters: - self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings()) + if self.support_prompt_tuning: + self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings()) # Initialize adapters from config for adapter_name in self.adapters_config: @@ -993,9 +995,10 @@ def get_adapter(self, name) -> dict: ) and name in self.invertible_adapters: destination[-1]["invertible"] = self.invertible_adapters[name] - prompt_tuning = self.prompt_tuning.get_adapter(name) - if prompt_tuning is not None: - destination[-1]["prompt"] = prompt_tuning + if self.support_prompt_tuning: + prompt_tuning = self.prompt_tuning.get_adapter(name) + if prompt_tuning is not None: + destination[-1]["prompt"] = prompt_tuning # use a custom index to ensure numbering is from 0 to N layers for i, (_, layer) in enumerate(self.iter_layers()): diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 36eae84b0f..0afaf787a0 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -60,8 +60,16 @@ def hook(module, input): class CLIPTextTransformerAdaptersMixin(InvertibleAdaptersMixin): """Adds adapters to the CLIPTextTransformer module of CLIP.""" - def hook_after_embeddings(self, hook_fn: Callable): - return self.embeddings.register_forward_hook(hook_fn) + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.embeddings.register_forward_hook(self.post_embedding_forward) + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output class CLIPTextModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): @@ -86,6 +94,7 @@ class CLIPModelAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWr """Adds adapters to the CLIPModel class.""" invertible_adapters_base_name = "text_model" + support_prompt_tuning = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.text_model.encoder.layers): diff --git a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py index ba2df24a84..50257d1536 100644 --- a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py +++ b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py @@ -17,6 +17,8 @@ class EncoderDecoderModelAdaptersMixin( ): """Adds adapters to the EncoderDecoderModel class.""" + support_prompt_tuning = False + def init_adapters(self, model_config, adapters_config): if not isinstance(self.encoder, ModelAdaptersMixin) or not isinstance(self.decoder, ModelAdaptersMixin): return diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py index e86c2967a9..f277823621 100644 --- a/src/adapters/models/gpt2/mixin_gpt2.py +++ b/src/adapters/models/gpt2/mixin_gpt2.py @@ -55,9 +55,19 @@ def init_adapters(self, model_config, adapters_config): class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + 
support_prompt_tuning = False + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Register hook for post embedding forward + self.drop.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.base_model.h): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.drop.register_forward_hook(hook_fn) + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py index 333c1b9358..732dfd417e 100644 --- a/src/adapters/models/gptj/mixin_gptj.py +++ b/src/adapters/models/gptj/mixin_gptj.py @@ -38,12 +38,19 @@ def init_adapters(self, model_config, adapters_config): class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + support_prompt_tuning = False + def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) + # Register hook for post embedding forward + self.drop.register_forward_hook(self.post_embedding_forward) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.base_model.h): yield i, layer - def hook_after_embeddings(self, hook_fn: Callable): - return self.drop.register_forward_hook(hook_fn) + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index d1ad3ddce3..db2ea68b33 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -28,6 +28,8 @@ def init_adapters(self, model_config, adapters_config): class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + support_prompt_tuning = False + def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) @@ -37,3 +39,8 @@ def init_adapters(self, model_config, adapters_config): def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.layers): yield i, layer + + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py index a5c39acaa6..d1bcfbac4a 100644 --- a/src/adapters/models/t5/mixin_t5.py +++ b/src/adapters/models/t5/mixin_t5.py @@ -82,11 +82,17 @@ def init_adapters(self, model_config, adapters_config): if not self.is_decoder: InvertibleAdaptersMixin.init_adapters(self, self.config, adapters_config) + def post_embedding_forward(self, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + class T5ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin): """Adds adapters to the T5Model class.""" invertible_adapters_base_name = "encoder" + support_prompt_tuning = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: global_i = 
0 diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 3440a4bb73..a37fed77a3 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -351,7 +351,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) if not self.is_decoder: - hidden_states = self.invertible_adapters_forward(hidden_states) + hidden_states = self.post_embedding_forward(hidden_states) for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): layer_head_mask = head_mask[i] diff --git a/tests_adapters/test_clip.py b/tests_adapters/test_clip.py index eea58262f8..2ed57268e4 100644 --- a/tests_adapters/test_clip.py +++ b/tests_adapters/test_clip.py @@ -20,7 +20,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, VisionAdapterTestBase, make_config @@ -79,7 +78,6 @@ class CLIPVisionWithProjectionAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_encoder_decoder.py b/tests_adapters/test_encoder_decoder.py index 1a13d8110f..708a6bfbb2 100644 --- a/tests_adapters/test_encoder_decoder.py +++ b/tests_adapters/test_encoder_decoder.py @@ -12,7 +12,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase @@ -51,7 +50,6 @@ class EncoderDecoderAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, EncoderDecoderAdapterTestBase, diff --git a/tests_adapters/test_gpt2.py b/tests_adapters/test_gpt2.py index 3ca6783920..620435e532 100644 --- a/tests_adapters/test_gpt2.py +++ b/tests_adapters/test_gpt2.py @@ -11,7 +11,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -42,7 +41,6 @@ class GPT2AdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_gptj.py b/tests_adapters/test_gptj.py index 5d3bce3dd9..2c2de0dc0c 100644 --- a/tests_adapters/test_gptj.py +++ b/tests_adapters/test_gptj.py @@ -11,7 +11,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -45,7 +44,6 @@ class GPTJAdapterTest( LoRATestMixin, UniPELTTestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, EmbeddingTestMixin, CompabilityTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_llama.py b/tests_adapters/test_llama.py index 9b3ab488cd..2fd455c174 100644 --- a/tests_adapters/test_llama.py +++ b/tests_adapters/test_llama.py @@ -10,7 +10,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -43,7 +42,6 @@ class LlamaAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, AdapterFusionModelTestMixin, diff --git a/tests_adapters/test_t5.py b/tests_adapters/test_t5.py index b2981b6110..c8717d8b54 100644 --- a/tests_adapters/test_t5.py +++ b/tests_adapters/test_t5.py @@ -10,7 +10,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - 
PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -45,7 +44,6 @@ class T5AdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, EmbeddingTestMixin, CompabilityTestMixin, From 2c5934bf6bef1318cab748e43300d6f6082bd47e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Tue, 14 Nov 2023 21:38:12 +0100 Subject: [PATCH 09/13] fix _tied_weights_keys --- src/adapters/model_mixin.py | 2 ++ src/adapters/models/bart/adapter_model.py | 6 +++++- src/adapters/models/mbart/adapter_model.py | 6 +++++- src/adapters/models/t5/adapter_model.py | 6 +++++- 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 1791bf7a29..73caa244d7 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -370,6 +370,8 @@ class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): add_base_adapters = False support_prompt_tuning = True # If False, the prompt tuning layer is not added to the model. If True, the prompt tuning layer is added if add_base_adapters is True. + _tied_weights_keys = ["prompt_tuning.base_model_embeddings.*"] + def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index f6fc12a587..6b3b053917 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -26,7 +26,11 @@ "BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING ) class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "prompt_tuning.base_model_embeddings.*", + ] def __init__(self, config: BartConfig, **kwargs): super().__init__(config, **kwargs) diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py index 58bfc7bdb9..79a1a746d7 100644 --- a/src/adapters/models/mbart/adapter_model.py +++ b/src/adapters/models/mbart/adapter_model.py @@ -26,7 +26,11 @@ "MBART Model with the option to add multiple flexible prediction heads on top.", MBART_START_DOCSTRING ) class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "prompt_tuning.base_model_embeddings.*", + ] def __init__(self, config: MBartConfig, **kwargs): super().__init__(config, **kwargs) diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 942b86fd3c..685046619d 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -22,7 +22,11 @@ @add_start_docstrings("T5 Model with the option to add multiple flexible prediction heads on top.", T5_START_DOCSTRING) class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + 
"prompt_tuning.base_model_embeddings.*", + ] _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", From 7368121bfa4fe867412185a5251f3843dd76fb9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Tue, 14 Nov 2023 21:45:01 +0100 Subject: [PATCH 10/13] remove unnecessary imports --- src/adapters/models/clip/mixin_clip.py | 2 +- src/adapters/models/gpt2/mixin_gpt2.py | 2 +- src/adapters/models/gptj/mixin_gptj.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 0afaf787a0..12d48e999d 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py index f277823621..00dded8961 100644 --- a/src/adapters/models/gpt2/mixin_gpt2.py +++ b/src/adapters/models/gpt2/mixin_gpt2.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn diff --git a/src/adapters/models/gptj/mixin_gptj.py b/src/adapters/models/gptj/mixin_gptj.py index 732dfd417e..2fbffc3923 100644 --- a/src/adapters/models/gptj/mixin_gptj.py +++ b/src/adapters/models/gptj/mixin_gptj.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Tuple +from typing import Iterable, Tuple import torch.nn as nn From 058d6b3065d984021ee0db8a23c7198ed2433ce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Wed, 15 Nov 2023 00:17:10 +0100 Subject: [PATCH 11/13] Fix tests & add average_adapter for prompt tuning --- src/adapters/configuration/adapter_config.py | 32 ++++++++++++-------- src/adapters/methods/prompt_tuning.py | 28 +++++++++++++++-- src/adapters/model_mixin.py | 1 - src/adapters/models/clip/adapter_model.py | 2 ++ src/adapters/models/clip/mixin_clip.py | 3 ++ src/adapters/models/gpt2/adapter_model.py | 2 ++ src/adapters/models/gptj/adapter_model.py | 2 ++ src/adapters/models/llama/adapter_model.py | 2 ++ src/adapters/models/t5/adapter_model.py | 1 - tests_adapters/test_adapter_config.py | 1 + 10 files changed, 56 insertions(+), 18 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 0c089dd28e..b4aafd364f 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -399,23 +399,29 @@ class PrefixTuningConfig(AdapterConfigBase): @dataclass(eq=False) class PromptTuningConfig(AdapterConfigBase): - # TODO: documentation """ The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf - Args:""" - - prompt_length: int - - prompt_init_text: Optional[str] = None # only necessary when using prompt_init="from_string" - architecture: Optional[str] = "prompt_tuning" + Args: + prompt_length (int): The number of tokens in the prompt. + Defaults to 10. + prompt_init (str): The initialization method for the prompt. Can be either "random_uniform" or "from_string". + Defaults to "random_uniform". + prompt_init_text (str): The text to use for prompt initialization if prompt_init="from_string". + random_uniform_scale (float): The scale of the random uniform initialization if prompt_init="random_uniform". + Defaults to 0.5 as in the paper. 
+ combine (str): + The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos". + Defaults to "prefix". + """ - prompt_init: str = ( # random_uniform, from_string, from_array, TODO: ? add more from https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/prompts.py - "random_uniform" - ) - combine: str = "prefix" # prefix, prefix_after_bos, suffix + architecture: str = "prompt_tuning" - # TODO: add a parameter for the random uniform scale + prompt_length: int = 10 + prompt_init: str = "random_uniform" + prompt_init_text: Optional[str] = None + random_uniform_scale = 0.5 + combine: str = "prefix" @dataclass(eq=False) @@ -635,7 +641,7 @@ def __init__( "compacter": CompacterConfig(), "prefix_tuning": PrefixTuningConfig(), "prefix_tuning_flat": PrefixTuningConfig(flat=True), - "prompt_tuning": PromptTuningConfig(prompt_length=10), # TODO: is that alright? + "prompt_tuning": PromptTuningConfig(), "lora": LoRAConfig(), "ia3": IA3Config(), "mam": MAMConfig(), diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py index 5fd008e133..fd9db60e0a 100644 --- a/src/adapters/methods/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -67,8 +67,11 @@ def __init__( def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None: if self.prompt_tuning_config.prompt_init == "random_uniform": - # Embedding was created using torch.nn.Embedding which already uses a random uniform distribution for initialization - pass + nn.init.uniform_( + self.prompt_embedding.weight, + a=-self.prompt_tuning_config.random_uniform_scale, + b=self.prompt_tuning_config.random_uniform_scale, + ) elif self.prompt_tuning_config.prompt_init == "from_string": tokenizer = AutoTokenizer.from_pretrained(self.model_config.tokenizer_name_or_path) @@ -160,7 +163,26 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: return False def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool: - pass # TODO: implement + # add new adapter + if self.add_adapter(adapter_name, -1): + # average weights + avg_state_dict = {} + for name, weight in input_adapters.items(): + if name in self.prompt_tunings: + module = self.prompt_tunings[name] + for k, v in module.state_dict().items(): + if k in avg_state_dict: + avg_state_dict[k] += weight * v + else: + avg_state_dict[k] = weight * v + else: + self.delete_adapter(adapter_name) # clean up before raising error + raise ValueError("Adapter {} not found.".format(name)) + # load averaged weights + self.prompt_tunings[adapter_name].load_state_dict(avg_state_dict) + return True + + return False def delete_adapter(self, adapter_name: str): if adapter_name in self.prompt_tunings: diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 73caa244d7..b41a75cedd 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -369,7 +369,6 @@ class ModelAdaptersMixin(PushAdapterToHubMixin, ABC): add_base_adapters = False support_prompt_tuning = True # If False, the prompt tuning layer is not added to the model. If True, the prompt tuning layer is added if add_base_adapters is True. 
- _tied_weights_keys = ["prompt_tuning.base_model_embeddings.*"] def __init__(self, config, *args, **kwargs): diff --git a/src/adapters/models/clip/adapter_model.py b/src/adapters/models/clip/adapter_model.py index 39382757e5..5aa15d417b 100644 --- a/src/adapters/models/clip/adapter_model.py +++ b/src/adapters/models/clip/adapter_model.py @@ -18,6 +18,8 @@ @add_start_docstrings(CLIP_START_DOCSTRING) class CLIPAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, CLIPPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since CLIP does not yet support prompt tuning + def __init__(self, config): super().__init__(config) diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 12d48e999d..8a269a590b 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -76,6 +76,7 @@ class CLIPTextModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapp """Adds adapters to the CLIPTextModel class.""" invertible_adapters_base_name = "text_model" + support_prompt_tuning = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.text_model.encoder.layers): @@ -85,6 +86,8 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: class CLIPVisionModelAdaptersMixin(ModelBaseAdaptersMixin): """Adds adapters to the a CLIPVisionModel class.""" + support_prompt_tuning = False + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.vision_model.encoder.layers): yield i, layer diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py index 15501a9315..cc5709b53e 100644 --- a/src/adapters/models/gpt2/adapter_model.py +++ b/src/adapters/models/gpt2/adapter_model.py @@ -33,6 +33,8 @@ GPT2_START_DOCSTRING, ) class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel): + _tied_weights_keys = [] # needs to be empty since GPT2 does not yet support prompt tuning + def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py index 46b5844171..a4bd8f32a1 100644 --- a/src/adapters/models/gptj/adapter_model.py +++ b/src/adapters/models/gptj/adapter_model.py @@ -33,6 +33,8 @@ GPTJ_START_DOCSTRING, ) class GPTJAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTJPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since GPT-J does not yet support prompt tuning + def __init__(self, config): super().__init__(config) self.transformer = GPTJModel(config) diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 1ff84e9c66..7b9ce69083 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -32,6 +32,8 @@ LLAMA_START_DOCSTRING, ) class LlamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, LlamaPreTrainedModel): + _tied_weights_keys = [] # needs to be empty since LLaMA does not yet support prompt tuning + def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 685046619d..d981815bd9 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -25,7 +25,6 @@ class 
T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdapte _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", - "prompt_tuning.base_model_embeddings.*", ] _keys_to_ignore_on_load_unexpected = [ diff --git a/tests_adapters/test_adapter_config.py b/tests_adapters/test_adapter_config.py index fe47c0c25b..302fe8f697 100644 --- a/tests_adapters/test_adapter_config.py +++ b/tests_adapters/test_adapter_config.py @@ -33,6 +33,7 @@ def test_config_immutable(self): def set_attr(config: AdapterConfigBase): config.non_linearity = "dummy" config.r = -1 # for LoRA + config.prompt_length = -1 # for PromptTuning for config in ADAPTER_CONFIG_MAP.values(): if isinstance(config, ConfigUnion): From d4a817ca5cd4b9bc0a82f0f11a1cddedb2b25603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sun, 19 Nov 2023 18:08:54 +0100 Subject: [PATCH 12/13] Renamed tests and remove leftover TODOs --- tests_adapters/methods/test_prompt_tuning.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests_adapters/methods/test_prompt_tuning.py b/tests_adapters/methods/test_prompt_tuning.py index ab7b950661..d0b12d259c 100644 --- a/tests_adapters/methods/test_prompt_tuning.py +++ b/tests_adapters/methods/test_prompt_tuning.py @@ -8,11 +8,7 @@ class PromptTuningTestMixin(AdapterMethodBaseTestMixin): def test_add_prompt_tuning(self): model = self.get_model() - self.run_add_test( - model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."] - ) # TODO: provide parameters in PromptTuningConfig(...) ? - - # TODO: add tests to add different configs (like initialization [random_uniform, from_array, ...] or prefix_prompt vs prefix_prompt_after_bos + self.run_add_test(model, PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) def test_average_prompt_tuning(self): model = self.get_model() @@ -24,9 +20,7 @@ def test_delete_prompt_tuning(self): def test_get_prompt_tuning(self): model = self.get_model() - self.run_get_test( - model, PromptTuningConfig(prompt_length=10), 1 - ) # TODO: last number is number of layers. Is this really 1? 
+ self.run_get_test(model, PromptTuningConfig(prompt_length=10), 1) def test_forward_prompt_tuning(self): model = self.get_model() @@ -35,8 +29,8 @@ def test_forward_prompt_tuning(self): def test_load_prompt_tuning(self): self.run_load_test(PromptTuningConfig(prompt_length=10)) - def test_load_full_model_prefix_tuning(self): + def test_load_full_model_prompt_tuning(self): self.run_full_model_load_test(PromptTuningConfig(prompt_length=10)) - def test_train_prefix_tuning(self): + def test_train_prompt_tuning(self): self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) From fdcc9c131a77066c9a666c2fb84bd0857568c6bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sun, 19 Nov 2023 19:35:04 +0100 Subject: [PATCH 13/13] Remove Bart and MBart support --- docs/model_overview.md | 4 ++-- src/adapters/models/bart/adapter_model.py | 1 - src/adapters/models/bart/mixin_bart.py | 6 ++++++ src/adapters/models/bart/modeling_bart.py | 2 -- src/adapters/models/mbart/adapter_model.py | 1 - src/adapters/models/mbart/modeling_mbart.py | 2 -- tests_adapters/test_bart.py | 2 -- tests_adapters/test_mbart.py | 2 -- 8 files changed, 8 insertions(+), 12 deletions(-) diff --git a/docs/model_overview.md b/docs/model_overview.md index f2e7712e08..a5ba7c4e8c 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -13,7 +13,7 @@ The table below further shows which model architectures support which adaptation | Model | (Bottleneck)
<br>Adapters | Prefix<br>Tuning | LoRA | Compacter | Adapter<br>Fusion | Invertible<br>Adapters | Parallel<br>block | Prompt<br>
Tuning | | --------------------------------------- | -| - | - | - | - | - | - |- | | [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -26,7 +26,7 @@ The table below further shows which model architectures support which adaptation | [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index 6b3b053917..ddb94e6fe9 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -29,7 +29,6 @@ class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdap _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", - "prompt_tuning.base_model_embeddings.*", ] def __init__(self, config: BartConfig, **kwargs): diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index b63a28233d..d269d72b43 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -76,6 +76,7 @@ class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMi """Adds adapters to the BartModel class.""" invertible_adapters_base_name = "encoder" + support_prompt_tuning = False def init_adapters(self, model_config, adapters_config): super().init_adapters(model_config, adapters_config) @@ -91,6 +92,11 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.decoder.layers): yield i, layer + def post_embedding_forward(self, module, args, embedding_output): + embedding_output = self.invertible_adapters_forward(embedding_output) + # Prompt tuning not yet supported + return embedding_output + class BartDecoderWrapperAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin): """Adds adapters to the BartDecoderWrapper class.""" diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index 796c6c36b5..28bf37bd7c 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -22,7 +22,6 @@ from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel -from ...utils import prefix_attention_mask from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -182,7 +181,6 @@ def forward( returned tensors for more detail. 
""" adjust_tensors_for_parallel_(hidden_states, attention_mask) - attention_mask = prefix_attention_mask(attention_mask, prefix_value=1) # type: ignore residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py index 79a1a746d7..5b57eb2cb0 100644 --- a/src/adapters/models/mbart/adapter_model.py +++ b/src/adapters/models/mbart/adapter_model.py @@ -29,7 +29,6 @@ class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAda _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", - "prompt_tuning.base_model_embeddings.*", ] def __init__(self, config: MBartConfig, **kwargs): diff --git a/src/adapters/models/mbart/modeling_mbart.py b/src/adapters/models/mbart/modeling_mbart.py index cf884dfa13..0f8f0d5335 100644 --- a/src/adapters/models/mbart/modeling_mbart.py +++ b/src/adapters/models/mbart/modeling_mbart.py @@ -22,7 +22,6 @@ from transformers.models.mbart.modeling_mbart import MBartAttention, MBartDecoderLayer, MBartEncoderLayer from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel -from ...utils import prefix_attention_mask from ..bart.mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -182,7 +181,6 @@ def forward( returned tensors for more detail. """ adjust_tensors_for_parallel_(hidden_states, attention_mask) - attention_mask = prefix_attention_mask(attention_mask, prefix_value=1) # type: ignore residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) diff --git a/tests_adapters/test_bart.py b/tests_adapters/test_bart.py index f9fcfbb83a..e40c9df521 100644 --- a/tests_adapters/test_bart.py +++ b/tests_adapters/test_bart.py @@ -11,7 +11,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -44,7 +43,6 @@ class BartAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, CompabilityTestMixin, diff --git a/tests_adapters/test_mbart.py b/tests_adapters/test_mbart.py index 5f726aa9be..775e1fdebb 100644 --- a/tests_adapters/test_mbart.py +++ b/tests_adapters/test_mbart.py @@ -10,7 +10,6 @@ IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, ) from .test_adapter import AdapterTestBase, make_config @@ -42,7 +41,6 @@ class MBartAdapterTest( IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin, - PromptTuningTestMixin, UniPELTTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin,