FIX Bug with modules_to_save loading if substring #2334

Merged
9 changes: 5 additions & 4 deletions src/peft/utils/save_and_load.py
@@ -335,10 +335,11 @@ def set_peft_model_state_dict(
     state_dict = {}
     if getattr(model, "modules_to_save", None) is not None:
         for key, value in peft_model_state_dict.items():
-            if any(module_name in key for module_name in model.modules_to_save):
-                for module_name in model.modules_to_save:
-                    if module_name in key:
-                        key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
+            if any(f".{module_name}." in key for module_name in model.modules_to_save):
+                # sort to make order deterministic, but should not affect overall logic
+                for module_name in sorted(model.modules_to_save):
+                    if f".{module_name}." in key:
+                        key = key.replace(f".{module_name}.", f".{module_name}.modules_to_save.{adapter_name}.")
+                        break
             state_dict[key] = value
     else:
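Why the change matters, as a minimal standalone sketch (illustrative only, not the PEFT internals; the key layout and the adapter name "default" are assumed for the example): with the old bare substring check, the entry "classifier" also matches the state_dict key of "classifier2" and mangles it, whereas matching ".{module_name}." as a full path component and breaking after the first hit rewrites the key exactly once.

    modules_to_save = ["classifier", "classifier2"]  # the real code uses a set, so iteration order is arbitrary
    adapter_name = "default"
    key = "base_model.model.classifier2.weight"  # assumed key layout, for illustration

    # old logic: bare substring match, no break
    old_key = key
    for module_name in modules_to_save:
        if module_name in old_key:
            old_key = old_key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
    print(old_key)  # base_model.model.classifier.modules_to_save.default2.weight  <- mangled

    # new logic: match the name as a full path component, sort for determinism, stop after one hit
    new_key = key
    for module_name in sorted(modules_to_save):
        if f".{module_name}." in new_key:
            new_key = new_key.replace(f".{module_name}.", f".{module_name}.modules_to_save.{adapter_name}.")
            break
    print(new_key)  # base_model.model.classifier2.modules_to_save.default.weight
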
100 changes: 99 additions & 1 deletion tests/test_other.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.


 import pytest
 import torch
 from torch import nn
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, PeftModel, get_peft_model
 from peft.utils.other import ModulesToSaveWrapper


@@ -199,3 +200,100 @@ def test_transient_attribute_access_non_existing_adapter(self, mlp):
         model.base_model.model.lin1._active_adapter = "does-not-exist"
         with pytest.raises(AttributeError, match="has no attribute 'weight'"):
             model.lin1.weight
+
+
+class TestModulesToSaveNameSubstringBug:
+    """Test a bug that could occur with multiple modules_to_save where one module's name is a substring of another
+    module's name.
+
+    This bug was the result of an error in the logic that modifies the state_dict for modules_to_save in
+    set_peft_model_state_dict. The logic checked whether an entry from modules_to_save (a set of strings) is a
+    substring of a key of the state_dict. If it was, a new name was assigned to that key in the state_dict, which
+    would allow the weight to be loaded later.
+
+    The issue stemming from the substring check occurs if there are multiple modules_to_save and one of them has a
+    name that is a substring of another. E.g. if one is named "classifier" and another is named "classifier2", there
+    could be a false match.
+
+    This bug was reported in #2289.
+    """
+
+    def get_model(self):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = nn.Linear(5, 4)
+                # important: "classifier" is a substring of "classifier2", "classifier3", "classifier4"
+                self.classifier = nn.Linear(4, 2)
+                self.classifier2 = nn.Linear(4, 2)
+                self.classifier3 = nn.Linear(4, 2)
+                self.classifier4 = nn.Linear(4, 2)
+
+            def forward(self, x):
+                x = self.lin(x)
+                return self.classifier(x) + self.classifier2(x) + self.classifier3(x) + self.classifier4(x)
+
+        torch.manual_seed(0)
+        return MyModule()
+
+    @pytest.fixture
+    def path_merged_and_unmerged(self, tmp_path):
+        # Create 2 checkpoints:
+        # 1. merged: the model after calling merge_and_unload
+        # 2. unmerged: the PEFT model saved without calling merge_and_unload
+        path = tmp_path / "model.pt"

+        lora_config = LoraConfig(
+            target_modules=["lin"],
+            # important: "classifier" is a substring of "classifier2", "classifier3", "classifier4"
+            modules_to_save=["classifier", "classifier2", "classifier3", "classifier4"],
+        )
+        model = get_peft_model(self.get_model(), lora_config)
+        # mock training
+        for _ in range(5):
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+            output = model(torch.randn(10, 5))
+            loss = output.sum()
+            loss.backward()
+            optimizer.step()
+
+        # save the peft model without merging
+        path_unmerged = tmp_path / "unmerged"
+        model.save_pretrained(path_unmerged)
+
+        # merge the model and save state_dict
+        path_merged = tmp_path / "merged"
+        merged = model.merge_and_unload()
+        state_dict = merged.state_dict()
+        torch.save(state_dict, path_merged)
+
+        return path_merged, path_unmerged
+
+    def test_load_merged_and_unmerged_same_weights(self, path_merged_and_unmerged):
+        # Note that this test is quasi-flaky: it has a 1 in 4 chance of passing even without the bugfix. It passes
+        # when "classifier" happens to be the last element of the set model.modules_to_save, and the order of the
+        # set is random. It is not possible to just run this test multiple times to reduce the probability of this
+        # happening, because within the same process, the hash order is consistent. With the bugfix, this doesn't
+        # matter, as the test will always pass, but if there is a regression, there is a 1 in 4 chance of not
+        # catching it. Since the CI runs many tests, it is overall very unlikely that none of them would catch it.
+        # If you see this test failing in CI, be aware that some of the passing tests may thus only pass owing to
+        # randomness.
+        path_merged, path_unmerged = path_merged_and_unmerged
+
+        # load the merged model directly
+        state_dict = torch.load(path_merged, weights_only=True)
+        model = self.get_model()
+        model.load_state_dict(state_dict)
+        sd_merged = model.state_dict()
+        del model
+
+        # load the unmerged model and merge it
+        unmerged = PeftModel.from_pretrained(self.get_model(), path_unmerged)
+        sd_unmerged = unmerged.merge_and_unload().state_dict()
+
+        assert sd_merged.keys() == sd_unmerged.keys()
+        for key in sd_merged.keys():
+            param_merged = sd_merged[key]
+            param_unmerged = sd_unmerged[key]
+            assert torch.allclose(param_merged, param_unmerged)
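
A small aside on the "1 in 4" remark in the test comment (illustrative, not part of the PR): the iteration order of a set of strings depends on the per-process hash seed, so it is fixed within one process but can differ between runs, which is why the buggy path only failed for some orderings; sorting the names, as the fix does, removes that nondeterminism.

    names = {"classifier", "classifier2", "classifier3", "classifier4"}
    print(list(names))    # order may differ between Python processes (PYTHONHASHSEED)
    print(sorted(names))  # always ['classifier', 'classifier2', 'classifier3', 'classifier4']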