From 83028178ecc996573ea5fc4bc50fe663bdac4599 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 20 Jan 2025 18:28:33 +0100
Subject: [PATCH] FIX Add missing attributes to MultiheadAttention (#2335)

See the initial report here:
https://github.com/huggingface/peft/issues/761#issuecomment-2600936330.

For MHA to work in all circumstances, for instance in eval mode, it
requires us to expose a couple more attributes that had been missed so
far. These are now added.
---
 src/peft/tuners/lora/layer.py | 27 +++++++++++++++++++++++++++
 tests/test_initialization.py  | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 1a94c8cec5..557fcfd188 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -1405,6 +1405,33 @@ def batch_first(self) -> bool:
     def head_dim(self) -> int:
         return self.get_base_layer().head_dim
 
+    @property
+    def in_proj_weight(self) -> nn.Parameter:
+        return self.get_base_layer().in_proj_weight
+
+    @property
+    def in_proj_bias(self) -> nn.Parameter:
+        return self.get_base_layer().in_proj_bias
+
+    @property
+    def out_proj(self) -> nn.Module:
+        return self.get_base_layer().out_proj.get_base_layer()
+
+    @property
+    def bias_k(self) -> Optional[nn.Parameter]:
+        return self.get_base_layer().bias_k
+
+    @property
+    def bias_v(self) -> Optional[nn.Parameter]:
+        return self.get_base_layer().bias_v
+
+    def merge_masks(self, *args, **kwargs) -> tuple[Optional[torch.Tensor], Optional[int]]:
+        return self.get_base_layer().merge_masks(*args, **kwargs)
+
+    @property
+    def add_zero_attn(self) -> bool:
+        return self.get_base_layer().add_zero_attn
+
     def update_layer(self, *args, **kwargs) -> None:
         super().update_layer(*args, **kwargs)
         # Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
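The properties above simply delegate to the wrapped nn.MultiheadAttention, so code that introspects the attention module from the outside (for example nn.TransformerEncoderLayer's eval-mode fast path, which reads in_proj_weight, in_proj_bias, bias_k/bias_v and add_zero_attn and calls merge_masks) keeps working once LoRA is applied. Below is a minimal usage sketch, not part of the patch; it assumes a PEFT version that includes the LoRA support for nn.MultiheadAttention amended here, and the SelfAttnBlock module and its dimensions are made up for illustration.

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class SelfAttnBlock(nn.Module):
    # Hypothetical toy module; its "mha" submodule is the LoRA target.
    def __init__(self, embed_dim=16, num_heads=4):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.mha(x, x, x)
        return out


peft_model = get_peft_model(SelfAttnBlock(), LoraConfig(target_modules=["mha"]))

# The delegated attributes resolve to the underlying nn.MultiheadAttention.
wrapped = peft_model.base_model.mha
print(wrapped.in_proj_weight.shape)  # torch.Size([48, 16]), i.e. (3 * embed_dim, embed_dim)
print(wrapped.add_zero_attn)         # False (the nn.MultiheadAttention default)
print(wrapped.bias_k is None)        # True unless add_bias_kv=True was passed

# Eval-mode inference; external callers such as nn.TransformerEncoderLayer read
# these attributes and call merge_masks on whatever module sits in their
# self_attn slot, which is why the delegation above matters.
peft_model.eval()
with torch.no_grad():
    out = peft_model(torch.rand(2, 10, 16))  # (batch, seq, embed) with batch_first=True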
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 5842cc6225..510f12892e 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -1145,6 +1145,7 @@ def test_mha_with_dora_raises(self, mha_cls):
             get_peft_model(model, config)
 
     def test_mha_exposes_attributes(self, mha_cls):
+        # MHA requires a bunch of attributes to be exposed, try to check them exhaustively here
         model = mha_cls()
         embed_dim = model.mha.embed_dim
         kdim = model.mha.kdim
@@ -1154,6 +1155,12 @@ def test_mha_exposes_attributes(self, mha_cls):
         dropout = model.mha.dropout
         batch_first = model.mha.batch_first
         head_dim = model.mha.head_dim
+        in_proj_weight = model.mha.in_proj_weight
+        in_proj_bias = model.mha.in_proj_bias
+        out_proj = model.mha.out_proj
+        bias_k = model.mha.bias_k
+        bias_v = model.mha.bias_v
+        add_zero_attn = model.mha.add_zero_attn
 
         config = LoraConfig(target_modules=["mha"])
         peft_model = get_peft_model(model, config)
@@ -1165,6 +1172,39 @@ def test_mha_exposes_attributes(self, mha_cls):
         assert peft_model.base_model.mha.dropout == dropout
         assert peft_model.base_model.mha.batch_first == batch_first
         assert peft_model.base_model.mha.head_dim == head_dim
+        if in_proj_weight is not None:
+            assert torch.allclose(peft_model.base_model.mha.in_proj_weight, in_proj_weight)
+        else:
+            assert peft_model.base_model.mha.in_proj_weight is None
+        if in_proj_bias is not None:
+            assert torch.allclose(peft_model.base_model.mha.in_proj_bias, in_proj_bias)
+        else:
+            assert peft_model.base_model.mha.in_proj_bias is None
+        assert peft_model.base_model.mha.out_proj is out_proj
+        if bias_k is not None:
+            assert torch.allclose(peft_model.base_model.mha.bias_k, bias_k)
+        else:
+            assert peft_model.base_model.mha.bias_k is None
+        if bias_v is not None:
+            assert torch.allclose(peft_model.base_model.mha.bias_v, bias_v)
+        else:
+            assert peft_model.base_model.mha.bias_v is None
+        assert peft_model.base_model.mha.add_zero_attn == add_zero_attn
+
+    def test_mha_merge_masks_method(self, mha_cls):
+        # MHA requires a merge_masks method to be exposed, check that it works
+        model = mha_cls()
+        config = LoraConfig(target_modules=["mha"])
+        peft_model = get_peft_model(model, config)
+
+        attn_mask = torch.randint(0, 2, (10, 10))
+        key_padding_mask = torch.randint(0, 2, (10, 10))
+        query = torch.rand(10, 10, 10)
+        merged_mask0, mask_type0 = model.mha.merge_masks(attn_mask, key_padding_mask, query)
+        merged_mask1, mask_type1 = peft_model.base_model.mha.merge_masks(attn_mask, key_padding_mask, query)
+
+        assert torch.allclose(merged_mask0, merged_mask1)
+        assert mask_type0 == mask_type1
 
     def test_lora_with_bias_extra_params(self):
         # lora with lora_bias=True
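For reference, here is a sketch, not part of the patch, of the merge_masks delegation with conventional mask shapes (the new test above uses same-sized integer masks, which obscures the shape conventions). The Block module is hypothetical, standing in for the mha_cls fixture used in the tests, and the same PEFT LoRA support for nn.MultiheadAttention is assumed.

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class Block(nn.Module):
    # Hypothetical module standing in for the mha_cls fixture used in the tests.
    def __init__(self, embed_dim=8, num_heads=2):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.mha(x, x, x)
        return out


peft_model = get_peft_model(Block(), LoraConfig(target_modules=["mha"]))

# Conventional shapes: attn_mask (L, S), key_padding_mask (N, S), query (N, L, E).
attn_mask = torch.zeros(5, 5, dtype=torch.bool)
key_padding_mask = torch.zeros(3, 5, dtype=torch.bool)
query = torch.rand(3, 5, 8)

# merge_masks is forwarded to the base nn.MultiheadAttention, so the wrapped
# module produces the same merged 4D mask and mask type as an unwrapped one.
merged_mask, mask_type = peft_model.base_model.mha.merge_masks(attn_mask, key_padding_mask, query)
print(merged_mask.shape)  # torch.Size([3, 2, 5, 5]), i.e. (N, num_heads, L, S)
print(mask_type)          # 2 when an attn_mask is present (1 for key_padding_mask only)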