From 79fd694ad76cce5b11dc37c30e7f2f78feab6929 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 20 Jan 2025 12:00:35 +0100
Subject: [PATCH] FIX Bug with modules_to_save loading if substring

Fixes #2289

This bug was the result of an error in the logic of modifying the
state_dict for modules_to_save in set_peft_model_state_dict. The logic
checked whether an entry from modules_to_save (a set of strings) is a
substring of a key of the state_dict. If it was, a new name was
assigned to that key in the state_dict, which allows the corresponding
weight to be loaded later.

The issue stemming from the substring check occurs when there are
multiple modules_to_save and one of them has a name that is a substring
of another. E.g. if one is named "classifier" and the other is named
"classifier2", there could be a false match.
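To illustrate, a minimal sketch of the old check (the key below is a
made-up example, with "default" standing in for the adapter name):

    modules_to_save = {"classifier", "classifier2"}
    key = "base_model.model.classifier2.weight"

    # old check: plain substring matching
    "classifier" in key  # True -- a false match on the "classifier2" key

    # if "classifier" happens to be visited first, the rename mangles the key
    key.replace("classifier", "classifier.modules_to_save.default")
    # -> "base_model.model.classifier.modules_to_save.default2.weight"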
This PR fixes the issue by enclosing the module name in dots, i.e. we
now check whether ".classifier." is a substring instead, which avoids
false matches.
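With the fix, the same made-up key resolves correctly:

    # new check: the name must be a full path component, delimited by dots
    ".classifier." in key   # False -- no more false match
    ".classifier2." in key  # True

    key.replace(".classifier2.", ".classifier2.modules_to_save.default.")
    # -> "base_model.model.classifier2.modules_to_save.default.weight"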
What made this bug even harder to debug was that modules_to_save is a
set and therefore has no predetermined order, so the bug was flaky. To
address this, modules_to_save is now sorted before iterating over it.
Sorting does not contribute to resolving the bug, but it makes the bug
non-flaky, which makes future debugging easier.
---
 src/peft/utils/save_and_load.py |   9 +--
 tests/test_other.py             | 100 +++++++++++++++++++++++++++++++-
 2 files changed, 104 insertions(+), 5 deletions(-)

diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index 5337afeb9f..e42e4a5e8c 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -335,10 +335,11 @@ def set_peft_model_state_dict(
     state_dict = {}
     if getattr(model, "modules_to_save", None) is not None:
         for key, value in peft_model_state_dict.items():
-            if any(module_name in key for module_name in model.modules_to_save):
-                for module_name in model.modules_to_save:
-                    if module_name in key:
-                        key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
+            if any(f".{module_name}." in key for module_name in model.modules_to_save):
+                # sort to make the order deterministic; this should not affect the overall logic
+                for module_name in sorted(model.modules_to_save):
+                    if f".{module_name}." in key:
+                        key = key.replace(f".{module_name}.", f".{module_name}.modules_to_save.{adapter_name}.")
                         break
             state_dict[key] = value
     else:
diff --git a/tests/test_other.py b/tests/test_other.py
index 75d8a7565c..7ee521f1c3 100644
--- a/tests/test_other.py
+++ b/tests/test_other.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 import pytest
 import torch
 from torch import nn
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
 
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, PeftModel, get_peft_model
 from peft.utils.other import ModulesToSaveWrapper
 
 
@@ -199,3 +200,100 @@ def test_transient_attribute_access_non_existing_adapter(self, mlp):
         model.base_model.model.lin1._active_adapter = "does-not-exist"
         with pytest.raises(AttributeError, match="has no attribute 'weight'"):
             model.lin1.weight
+
+
+class TestModulesToSaveNameSubstringBug:
+    """Test a bug that could occur with multiple modules_to_save where one module's name is a substring of another
+    module's name.
+
+    This bug was the result of an error in the logic of modifying the state_dict for modules_to_save in
+    set_peft_model_state_dict. The logic checked whether an entry from modules_to_save (a set of strings) is a
+    substring of a key of the state_dict. If it was, a new name was assigned to that key in the state_dict, which
+    allows the corresponding weight to be loaded later.
+
+    The issue stemming from the substring check occurs if there are multiple modules_to_save and one of them has a
+    name that is a substring of another. E.g. if one is named "classifier" and the other is named "classifier2",
+    there could be a false match.
+
+    This bug was reported in #2289.
+    """
+
+    def get_model(self):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = nn.Linear(5, 4)
+                # important: "classifier" is a substring of "classifier2", "classifier3", "classifier4"
+                self.classifier = nn.Linear(4, 2)
+                self.classifier2 = nn.Linear(4, 2)
+                self.classifier3 = nn.Linear(4, 2)
+                self.classifier4 = nn.Linear(4, 2)
+
+            def forward(self, x):
+                x = self.lin(x)
+                return self.classifier(x) + self.classifier2(x) + self.classifier3(x) + self.classifier4(x)
+
+        torch.manual_seed(0)
+        return MyModule()
+
+    @pytest.fixture
+    def path_merged_and_unmerged(self, tmp_path):
+        # Create 2 checkpoints:
+        # 1. merged: the model after calling merge_and_unload
+        # 2. unmerged: the PEFT model saved without calling merge_and_unload
+        lora_config = LoraConfig(
+            target_modules=["lin"],
+            # important: "classifier" is a substring of "classifier2", "classifier3", "classifier4"
+            modules_to_save=["classifier", "classifier2", "classifier3", "classifier4"],
+        )
+        model = get_peft_model(self.get_model(), lora_config)
+
+        # mock training
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        for _ in range(5):
+            optimizer.zero_grad()
+            output = model(torch.randn(10, 5))
+            loss = output.sum()
+            loss.backward()
+            optimizer.step()
+
+        # save the PEFT model without merging
+        path_unmerged = tmp_path / "unmerged"
+        model.save_pretrained(path_unmerged)
+
+        # merge the model and save its state_dict
+        path_merged = tmp_path / "merged"
+        merged = model.merge_and_unload()
+        state_dict = merged.state_dict()
+        torch.save(state_dict, path_merged)
+
+        return path_merged, path_unmerged
+
+    def test_load_merged_and_unmerged_same_weights(self, path_merged_and_unmerged):
+        # Note that this test is quasi-flaky: it has a 1 in 4 chance of passing even without the bugfix. It passes
+        # when "classifier" happens to be the last element of the set model.modules_to_save; the order of the set is
+        # random. It is not possible to just run this test multiple times to minimize the probability of this
+        # happening, because within the same process, the hash order is consistent. With the bugfix, this doesn't
+        # matter, as the test will always pass; but if there is a regression, there is a 1 in 4 chance of not
+        # catching it. Since the CI runs many tests, it is overall very unlikely that none will catch it. If you see
+        # this test failing in CI, be aware that some of the passing tests may pass only owing to randomness.
+        path_merged, path_unmerged = path_merged_and_unmerged
+
+        # load the merged model directly
+        state_dict = torch.load(path_merged, weights_only=True)
+        model = self.get_model()
+        model.load_state_dict(state_dict)
+        sd_merged = model.state_dict()
+        del model
+
+        # load the unmerged model and merge it
+        unmerged = PeftModel.from_pretrained(self.get_model(), path_unmerged)
+        sd_unmerged = unmerged.merge_and_unload().state_dict()
+
+        assert sd_merged.keys() == sd_unmerged.keys()
+        for key in sd_merged.keys():
+            param_merged = sd_merged[key]
+            param_unmerged = sd_unmerged[key]
+            assert torch.allclose(param_merged, param_unmerged)