Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX: Adding 2 adapters when target_modules is a str fails (#1111)
* Fix adding 2 adapters when target_modules is a str Problem description Adding two adapters (e.g. LoRA) when using a list for `target_mdules` works but passing a str fails. The issue is that for str, we do a `re.fullmatch`, whereas for list, we just check `endswith`. After adding the first adapter, though, the naming pattern of the modules changes. In the example above, the name for the linear layer changes from `"lin0"` to `"base_model.model.lin0"`, which is why the `fullmatch` fails but the `endswith` still works. Reproduction from peft import LoraConfig, get_peft_model from torch import nn class MLP(nn.Module): def __init__(self, bias=True): super().__init__() self.lin0 = nn.Linear(10, 20, bias=bias) def test_target_modules_list(): config = LoraConfig(target_modules=["lin0"]) test_it(config) print("Adding two adapters with target_module being a list works") def test_target_modules_str(): config = LoraConfig(target_modules="lin0") test_it(config) def test_it(config): model = MLP() model = get_peft_model(model, config, "adapter0") model.add_adapter("adapter1", config) print("Adding two adapters with target_module being a str works") if __name__ == "__main__": # works test_target_modules_list() # ValueError: Target modules lin0 not found in the base model test_target_modules_str() I think that most users would be surprised that: 1. Adding the first adapter works but adding the second fails, even though they use the same config. 2. Using `target_modules=["lin0"]` works but `target_modules="lin0"` fails for the 2nd adapter. Solution We could change the logic of not using `re.fullmatch` for str, but I think that could be tricky to achieve without breaking BC. Instead, I chose to change the inject_adapter call in add_adapter to pass the base model, not the whole peft model. This way, the naming pattern is preserved. Tests I haven't added extra tests for this. The script above could serve as a test. However, it will be sufficient to remove the guard added in #1105: if isinstance(config.target_str, modules): # TODO this should be doable self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.") as that will test exactly this behavior and was how the bug was originally uncovered. Depending on what PR lands first, the guard has to removed in this PR or in #1105. * Enable tests for adding 2 adapters with str
- Loading branch information