Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Prompt Tuning #595

Merged
merged 15 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/_build/
adapter_docs/_build/

# PyBuilder
target/
Expand Down
7 changes: 7 additions & 0 deletions docs/classes/adapter_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ IA3Config
:members:
:inherited-members: Mapping

PromptTuningConfig
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.PromptTuningConfig
:members:
:inherited-members: Mapping

Combined configurations
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
28 changes: 28 additions & 0 deletions docs/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,31 @@ model.reset_adapter()

_Papers:_
- [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022)

## Prompt Tuning
Prompt Tuning is an efficient fine-tuning technique proposed by Lester et al. (2021). Prompt tuning adds tunable tokens, called soft-prompts, that are prepended to the input text.
First, the input sequence ${x_1, x_2, \dots, x_n }$ gets embedded, resulting in the matrix $X_e \in \mathbb{R}^{n \times e}$ where $e$ is the dimension of
the embedding space. The soft-prompts with length $p$ are represented as $P_e \in \mathbb{R}^{p \times e}$.
$P_e$ and $X_e$ get concatenated, forming the input of the following encoder or decoder:

$$
\left[P_e; X_e\right] \in \mathbb{R}^{\left(p + n\right) \times e}
$$

The `PromptTuningConfig` has the properties:
- `prompt_length`: to set the soft-prompts length $p$
- `prompt_init`: to set the weight initialisation method, which is either "random_uniform" or "from_string" to initialize each prompt token with an embedding drawn from the model’s vocabulary.
- `prompt_init_text` as the text use for initialisation if `prompt_init="from_string"`
- `combine`: To define if the prefix should be added before the embedded input sequence or after the BOS token

To add Prompt Tuning to your model, you can use the predefined configs:
```python
from adapters import PromptTuningConfig

config = PromptTuningConfig(prompt_length=10)
model.add_adapter("dummy", config=config)
```

_Papers:_
- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021)

44 changes: 22 additions & 22 deletions docs/model_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```

| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block |
| --------------------------------------- | -| - | - | - | - | - | - |
| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | |
| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | |
| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning |
| --------------------------------------- | -| - | - | - | - | - | - |- |
| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ |
| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | |
| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |

(*) If the used encoder and decoder model class are supported.

Expand Down
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"ModelAdaptersConfig",
"ParBnConfig",
"PrefixTuningConfig",
"PromptTuningConfig",
"SeqBnConfig",
"SeqBnInvConfig",
"StaticAdapterFusionConfig",
Expand Down Expand Up @@ -161,6 +162,7 @@
ModelAdaptersConfig,
ParBnConfig,
PrefixTuningConfig,
PromptTuningConfig,
SeqBnConfig,
SeqBnInvConfig,
StaticAdapterFusionConfig,
Expand Down
30 changes: 30 additions & 0 deletions src/adapters/configuration/adapter_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def _get_config_class(config_dict):
cls_new = LoRAConfig
elif architecture == "union":
cls_new = ConfigUnion
elif architecture == "prompt_tuning":
cls_new = PromptTuningConfig
else:
cls_new = BnConfig

Expand Down Expand Up @@ -395,6 +397,33 @@ class PrefixTuningConfig(AdapterConfigBase):
shared_gating: bool = True


@dataclass(eq=False)
class PromptTuningConfig(AdapterConfigBase):
"""
The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf

Args:
prompt_length (int): The number of tokens in the prompt.
Defaults to 10.
prompt_init (str): The initialization method for the prompt. Can be either "random_uniform" or "from_string".
Defaults to "random_uniform".
prompt_init_text (str): The text to use for prompt initialization if prompt_init="from_string".
random_uniform_scale (float): The scale of the random uniform initialization if prompt_init="random_uniform".
Defaults to 0.5 as in the paper.
combine (str):
The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos".
Defaults to "prefix".
"""

architecture: str = "prompt_tuning"

prompt_length: int = 10
prompt_init: str = "random_uniform"
prompt_init_text: Optional[str] = None
random_uniform_scale = 0.5
combine: str = "prefix"


@dataclass(eq=False)
class LoRAConfig(AdapterConfigBase):
"""
Expand Down Expand Up @@ -612,6 +641,7 @@ def __init__(
"compacter": CompacterConfig(),
"prefix_tuning": PrefixTuningConfig(),
"prefix_tuning_flat": PrefixTuningConfig(flat=True),
"prompt_tuning": PromptTuningConfig(),
"lora": LoRAConfig(),
"ia3": IA3Config(),
"mam": MAMConfig(),
Expand Down
19 changes: 17 additions & 2 deletions src/adapters/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]
context_attributes = [
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
]
# Additional used attributes not exposed to the user
# - prompt_tokens_length: length of the prompt tokens

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
Expand All @@ -102,6 +108,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
Expand All @@ -116,7 +124,14 @@ def wrapper_func(self, *args, **kwargs):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))
return results

if output_context:
context_dict = ctx.__dict__

if output_context:
return results, context_dict
else:
return results
else:
return f(self, *args, **kwargs)

Expand Down
28 changes: 27 additions & 1 deletion src/adapters/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,19 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = kwargs.pop("labels", None)
if labels is not None:
loss_fct = CrossEntropyLoss()
# adjust labels for prompt tuning
if kwargs.get("prompt_tokens_length", 0) > 0:
prompt_length = kwargs.get("prompt_tokens_length")
prompt_labels = torch.full(
(labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
)
labels = torch.cat((prompt_labels, labels), dim=-1)
if attention_mask is not None:
attention_mask = torch.cat(
(torch.ones_like(prompt_labels, dtype=torch.long, device=labels.device), attention_mask),
dim=-1,
)

# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
Expand Down Expand Up @@ -752,7 +765,14 @@ def _get_used_heads(self, head_name: str = None):
return head_modules

def forward_head(
self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
self,
all_outputs,
head_name=None,
cls_output=None,
attention_mask=None,
return_dict=False,
context=None,
**kwargs
):
"""
The forward pass through a prediction head configuration. There are three ways to specify the used prediction
Expand Down Expand Up @@ -800,6 +820,12 @@ def _get_head_input(outputs, cls_out, batch):
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter

# Set prompt tokens length
if context is not None:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if prompt_tokens_length is not None:
kwargs["prompt_tokens_length"] = prompt_tokens_length

if isinstance(self.active_head, BatchSplit):
if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]:
raise ValueError(
Expand Down
10 changes: 10 additions & 0 deletions src/adapters/heads/language_modeling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn

from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput
Expand Down Expand Up @@ -118,6 +119,15 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = labels[..., 1:].contiguous()
else:
logits_for_loss = lm_logits

# adjust labels for prompt tuning
if kwargs.get("prompt_tokens_length", 0) > 0:
prompt_length = kwargs.get("prompt_tokens_length")
prompt_labels = torch.full(
(labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
)
labels = torch.cat((prompt_labels, labels), dim=-1)

loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1))

if return_dict:
Expand Down
1 change: 1 addition & 0 deletions src/adapters/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def filter_func(self, adapter_name):
or ".prefix_tunings.{}.".format(adapter_name) in x
or ".prefix_gates.{}.".format(adapter_name) in x
or ".loras.{}.".format(adapter_name) in x
or ".prompt_tunings.{}.".format(adapter_name) in x
)

# This dict maps the original weight names to the currently used equivalents.
Expand Down
2 changes: 1 addition & 1 deletion src/adapters/methods/adapter_layer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl:
# sequentially feed different parts of the blown-up batch into different adapters
children_states = []
for i, child in enumerate(adapter_setup):
# compute ids of sequences thet should be passed to the ith adapter
# compute ids of sequences that should be passed to the ith adapter
batch_idx = (
sum(adapter_setup.batch_sizes[:i]),
sum(adapter_setup.batch_sizes[: i + 1]),
Expand Down
Loading
Loading