Skip to content

Commit

Permalink
FIX: Adding 2 adapters when target_modules is a str fails (#1111)
Browse files Browse the repository at this point in the history
* Fix adding 2 adapters when target_modules is a str

Problem description

Adding two adapters (e.g. LoRA) when using a list for `target_mdules`
works but passing a str fails. The issue is that for str, we do a
`re.fullmatch`, whereas for list, we just check `endswith`. After adding
the first adapter, though, the naming pattern of the modules changes. In
the example above, the name for the linear layer changes from `"lin0"`
to `"base_model.model.lin0"`, which is why the `fullmatch` fails but the
`endswith` still works.

Reproduction

from peft import LoraConfig, get_peft_model
from torch import nn

class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)

def test_target_modules_list():
    config = LoraConfig(target_modules=["lin0"])
    test_it(config)
    print("Adding two adapters with target_module being a list works")

def test_target_modules_str():
    config = LoraConfig(target_modules="lin0")
    test_it(config)

def test_it(config):
    model = MLP()
    model = get_peft_model(model, config, "adapter0")
    model.add_adapter("adapter1", config)
    print("Adding two adapters with target_module being a str works")

if __name__ == "__main__":
    # works
    test_target_modules_list()
    # ValueError: Target modules lin0 not found in the base model
    test_target_modules_str()

I think that most users would be surprised that:

1. Adding the first adapter works but adding the second fails, even
   though they use the same config.
2. Using `target_modules=["lin0"]` works but `target_modules="lin0"`
   fails for the 2nd adapter.

Solution

We could change the logic of not using `re.fullmatch` for str, but I
think that could be tricky to achieve without breaking BC. Instead, I
chose to change the inject_adapter call in add_adapter to pass the base
model, not the whole peft model. This way, the naming pattern is
preserved.

Tests

I haven't added extra tests for this. The script above could serve as a
test. However, it will be sufficient to remove the guard added in #1105:

    if isinstance(config.target_str, modules):
        # TODO this should be doable
        self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")

as that will test exactly this behavior and was how the bug was
originally uncovered. Depending on what PR lands first, the guard has to
removed in this PR or in #1105.

* Enable tests for adding 2 adapters with str
  • Loading branch information
BenjaminBossan authored Nov 14, 2023
1 parent 94877b5 commit ad75617
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 9 deletions.
2 changes: 1 addition & 1 deletion src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
self.base_model.add_adapter(adapter_name, peft_config)
else:
self.peft_config[adapter_name] = peft_config
self.base_model.inject_adapter(self, adapter_name)
self.base_model.inject_adapter(self.base_model.model, adapter_name)
except Exception: # somthing went wrong, roll back
if adapter_name in self.peft_config:
del self.peft_config[adapter_name]
Expand Down
8 changes: 0 additions & 8 deletions tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,10 +827,6 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
return

model = self.transformers_class.from_pretrained(model_id)
if isinstance(config.target_modules, str):
# TODO this should be doable
self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")

adapter_to_delete = "delete_me"
model = get_peft_model(model, config)
model.add_adapter(adapter_to_delete, config)
Expand Down Expand Up @@ -869,10 +865,6 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
return

model = self.transformers_class.from_pretrained(model_id)
if isinstance(config.target_modules, str):
# TODO this should be doable
self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")

adapter_to_delete = "delete_me"
model = get_peft_model(model, config)
model.add_adapter(adapter_to_delete, config)
Expand Down

0 comments on commit ad75617

Please sign in to comment.