From a6055d0b2c05938e4dfcc6c9429ad19dbc670a24 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 15 Nov 2023 17:34:28 +0100 Subject: [PATCH] Fix training T5 adapter models with Trainer (#599) --- src/adapters/models/__init__.py | 11 ++++++++--- src/adapters/models/t5/mixin_t5.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 11da5d325e..dd48552d23 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -18,7 +18,12 @@ from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin from .llama.mixin_llama import LlamaModelAdapterMixin -from .t5.mixin_t5 import T5BlockAdaptersMixin, T5ModelAdaptersMixin, T5ModelAdaptersWithHeadsMixin +from .t5.mixin_t5 import ( + T5BlockAdaptersMixin, + T5ForCondiditionalGenerationWithHeadsMixin, + T5ForQuestionAnsweringWithHeadsMixin, + T5ModelAdaptersMixin, +) from .vit.mixin_vit import ViTIntermediateAdaptersMixin, ViTModelAdaptersMixin from .xmod.mixin_xmod import XmodModelAdaptersMixin @@ -57,8 +62,8 @@ "RobertaModel": BertModelAdaptersMixin, "T5Block": T5BlockAdaptersMixin, "T5Model": T5ModelAdaptersMixin, - "T5ForConditionalGeneration": T5ModelAdaptersWithHeadsMixin, - "T5ForQuestionAnswering": T5ModelAdaptersWithHeadsMixin, + "T5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin, + "T5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin, "T5EncoderModel": T5ModelAdaptersMixin, "ViTIntermediate": ViTIntermediateAdaptersMixin, "ViTModel": ViTModelAdaptersMixin, diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py index a5c39acaa6..832dfd185d 100644 --- a/src/adapters/models/t5/mixin_t5.py +++ b/src/adapters/models/t5/mixin_t5.py @@ -1,5 +1,6 @@ -from typing import Iterable, Tuple +from typing import Iterable, Optional, Tuple +import torch import torch.nn as nn from ...methods.bottleneck import BottleneckLayer @@ -99,5 +100,28 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: yield i, layer -class T5ModelAdaptersWithHeadsMixin(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin): - pass +# Stating "labels" and "input_ids" explicitly is required for training using Trainer class +class T5ForCondiditionalGenerationWithHeadsMixin(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin): + def forward( + self, + *args, + input_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ): + return super().forward(*args, input_ids=input_ids, labels=labels, **kwargs) + + +# Stating "start_positions"/"end_positions" and "input_ids" explicitly is required for training using Trainer class +class T5ForQuestionAnsweringWithHeadsMixin(ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin): + def forward( + self, + *args, + input_ids: Optional[torch.LongTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + **kwargs, + ): + return super().forward( + *args, input_ids=input_ids, start_positions=start_positions, end_positions=end_positions, **kwargs + )