diff --git a/docs/source/package_reference/vera.md b/docs/source/package_reference/vera.md
index 9677df2742..9f7bb19a38 100644
--- a/docs/source/package_reference/vera.md
+++ b/docs/source/package_reference/vera.md
@@ -20,9 +20,10 @@ rendered properly in your Markdown viewer.
 
 When saving the adapter parameters, it's possible to eschew storing the low rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default).
 
+To handle different shapes of adapted layers, VeRA initializes shared A and B matrices with the largest required size for each dimension. During the forward pass, submatrices A and B for a given layer are sliced out from these shared matrices and used as described in the paper. For example, adapting two linear layers of shapes (100, 20) and (80, 50) will create A and B matrices of shapes (rank, 50) and (100, rank) respectively. Then, to adapt a layer of shape (100, 20), submatrices A and B of shapes (rank, 20) and (100, rank) will be extracted.
+
 VeRA currently has the following constraints:
 
-- All targeted parameters must have the same shape.
 - Only `nn.Linear` layers are supported.
 - Quantized layers are not supported.
diff --git a/examples/sequence_classification/VeRA.ipynb b/examples/sequence_classification/VeRA.ipynb
index b917618db3..e3786fff45 100644
--- a/examples/sequence_classification/VeRA.ipynb
+++ b/examples/sequence_classification/VeRA.ipynb
@@ -94,7 +94,7 @@
     "    task_type=\"SEQ_CLS\", \n",
     "    r=rank,\n",
     "    d_initial=0.1,\n",
-    "    target_modules=[\"query\", \"value\"],\n",
+    "    target_modules=[\"query\", \"value\", \"intermediate.dense\"],\n",
     "    save_projection=True,\n",
     ")\n",
     "head_lr = 1e-2\n",
@@ -205,7 +205,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "trainable params: 610,754 || all params: 125,257,924 || trainable%: 0.48759709605278145\n"
+      "trainable params: 647,714 || all params: 125,294,884 || trainable%: 0.5170\n"
      ]
     }
    ],
@@ -255,76 +255,76 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "  0%|          | 0/29 [00:00<?, ?it/s]"
diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py
--- a/src/peft/tuners/vera/layer.py
+++ b/src/peft/tuners/vera/layer.py
@@ ... @@ def get_delta_weight(self, adapter) -> torch.Tensor:
             lambda_d = lambda_d.float()
             lambda_b = lambda_b.float()
 
+        sliced_A = vera_A[:, : self.in_features]
+        sliced_B = vera_B[: self.out_features, :]
         lambda_b = lambda_b.unsqueeze(-1)
         lambda_d = lambda_d.unsqueeze(-1)
-        output_tensor = transpose((lambda_b * vera_B) @ (lambda_d * vera_A), self.fan_in_fan_out)
+        output_tensor = transpose((lambda_b * sliced_B) @ (lambda_d * sliced_A), self.fan_in_fan_out)
 
         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
@@ -252,9 +254,15 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 vera_A = self.vera_A[active_adapter]
                 vera_B = self.vera_B[active_adapter]
 
+                # As adapted layers may have different shapes and VeRA contains a single shared pair of A and B matrices,
+                # we initialize these matrices with the largest required size for each dimension.
+                # During the forward pass, required submatrices are sliced out from the shared vera_A and vera_B.
+                sliced_A = vera_A[:, : self.in_features]
+                sliced_B = vera_B[: self.out_features, :]
+
                 dropout = self.vera_dropout[active_adapter]
                 x = x.to(lambda_d.dtype)
-                result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B)
+                result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), sliced_A), sliced_B)
 
         result = result.to(previous_dtype)
         return result
diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py
index 2ecd1c9ab8..a47112d94e 100644
--- a/src/peft/tuners/vera/model.py
+++ b/src/peft/tuners/vera/model.py
@@ -101,13 +101,11 @@ class VeraModel(BaseTuner):
     def __init__(self, model, config, adapter_name) -> None:
         super().__init__(model, config, adapter_name)
 
-    def _find_first_dim(self, config) -> tuple[int, int]:
+    def _find_dim(self, config) -> tuple[int, int]:
         """
-        Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension.
+        Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA.
 
         This will be used for determining the size of the shared vera_A and vera_B matrices.
-
-        This will throw an error if there are multiple layers of the same type with different shapes.
         """
         model_config = getattr(self.model, "config", {"model_type": "custom"})
         if hasattr(model_config, "to_dict"):
@@ -116,7 +114,7 @@ def _find_first_dim(self, config) -> tuple[int, int]:
         peft_config = self._prepare_adapter_config(config, model_config)
         peft_config = _maybe_include_all_linear_layers(peft_config, self.model)
 
-        first_shape = None
+        largest_shape = None
         for key, module in self.model.named_modules():
             if not self._check_target_module_exists(peft_config, key):
                 continue
@@ -128,24 +126,21 @@ def _find_first_dim(self, config) -> tuple[int, int]:
             else:
                 continue
 
-            if first_shape is None:
-                first_shape = module_shape
+            if largest_shape is None:
+                largest_shape = module_shape
                 continue
 
-            if module_shape != first_shape:
-                raise ValueError(
-                    "Multiple target layers with different dimensions were specified. VeRA only supports a "
-                    f"single dimension size. Expected shape {first_shape}, got {module_shape}."
-                )
+            if module_shape != largest_shape:
+                largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape))
 
-        if first_shape is None:
+        if largest_shape is None:
             msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`."
             raise ValueError(msg)
 
-        return first_shape
+        return largest_shape
 
     def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None:
-        first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config)
+        linear_out_dim, linear_in_dim = self._find_dim(config)
 
         # use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them.
         self.vera_A = BufferDict({}, persistent=config.save_projection)
@@ -153,8 +148,8 @@ def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None:
 
         # deterministic init of vera_A and vera_B if we know the key
         generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key)
-        vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator)
-        vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator)
+        vera_A = _kaiming_init((config.r, linear_in_dim), generator=generator)
+        vera_B = _kaiming_init((linear_out_dim, config.r), generator=generator)
 
         self.vera_A[adapter_name] = vera_A
         self.vera_B[adapter_name] = vera_B
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 9b5517f9c2..06eb097b36 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -338,6 +338,7 @@
     ("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}),
     ("Vanilla MLP 2 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0"]}),
     ("Vanilla MLP 3 VeRA", "MLP", VeraConfig, {"target_modules": ["lin1"]}),
+    ("Vanilla MLP 4 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0", "lin1"]}),
     (
         "Vanilla MLP 5 VeRA",
         "MLP",
diff --git a/tests/test_vera.py b/tests/test_vera.py
index 9fd3eca71c..6dbaac6bd3 100644
--- a/tests/test_vera.py
+++ b/tests/test_vera.py
@@ -15,7 +15,6 @@
 # This test file is for tests specific to VeRA, since VeRA has some specific challenges due to the shared weights.
 
 import os
-import re
 
 import pytest
 import torch
@@ -265,13 +264,21 @@ def test_vera_lambda_dont_share_memory(self, mlp_same_prng):
             != mlp_same_prng.base_model.model.lin2.vera_lambda_d["other"].data_ptr()
         )
 
-    def test_vera_different_shapes_raises(self, mlp):
-        # It is not possible (currently) to have vera_A and vera_B for different shapes, as they cannot be shared if
-        # their shapes are not identical. lin0 and lin1 have different shapes.
-        config = VeraConfig(target_modules=["lin0", "lin1"], init_weights=False)
-        msg = re.escape(
-            "Multiple target layers with different dimensions were specified. VeRA only supports a single dimension "
-            "size. Expected shape (20, 10), got (20, 20)."
-        )
-        with pytest.raises(ValueError, match=msg):
-            get_peft_model(mlp, config)
+    def test_vera_different_shapes(self, mlp):
+        config = VeraConfig(target_modules=["lin0", "lin3"], init_weights=False)
+        mlp_different_shapes = get_peft_model(mlp, config)
+
+        vera_A = mlp_different_shapes.vera_A["default"]
+        vera_B = mlp_different_shapes.vera_B["default"]
+
+        # sanity check
+        assert mlp.lin0.base_layer.weight.shape != mlp.lin3.base_layer.weight.shape
+
+        # lin0 has the largest output dimension, lin3 has the largest input dimension
+        # vera_A should have the shape of (rank, largest_in), vera_B should have the shape of (largest_out, rank)
+        assert vera_A.shape == (config.r, mlp.lin3.in_features)
+        assert vera_B.shape == (mlp.lin0.out_features, config.r)
+
+        # should not raise
+        input = torch.randn(5, 10)
+        mlp_different_shapes(input)
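
Below is a minimal, standalone sketch (not taken from the patch itself) of the sizing-and-slicing scheme described by the documentation paragraph and implemented by _find_dim and the sliced forward pass. The layer shapes, rank, and the vera_delta helper are illustrative assumptions, and the per-layer lambda_b/lambda_d scalings from the paper are omitted for brevity.

# Standalone illustration (hypothetical shapes): how shared VeRA matrices are sized and sliced.
import torch
import torch.nn.functional as F

# nn.Linear weights are stored as (out_features, in_features); these two shapes match the docs example.
layer_shapes = [(100, 20), (80, 50)]
rank = 4

# Shared projections are sized by the largest dimension across all targeted layers,
# mirroring what _find_dim computes: vera_A -> (rank, max_in), vera_B -> (max_out, rank).
max_out = max(out_f for out_f, _ in layer_shapes)  # 100
max_in = max(in_f for _, in_f in layer_shapes)     # 50
vera_A = torch.randn(rank, max_in)
vera_B = torch.randn(max_out, rank)

def vera_delta(x: torch.Tensor, out_features: int, in_features: int) -> torch.Tensor:
    """Per-layer update using sliced submatrices (lambda_b/lambda_d scalings omitted)."""
    sliced_A = vera_A[:, :in_features]   # (rank, in_features)
    sliced_B = vera_B[:out_features, :]  # (out_features, rank)
    return F.linear(F.linear(x, sliced_A), sliced_B)

# Adapting the (100, 20) layer extracts submatrices of shapes (rank, 20) and (100, rank).
x = torch.randn(5, 20)
print(vera_delta(x, out_features=100, in_features=20).shape)  # torch.Size([5, 100])

Because only the largest A and B are stored, targeting layers of differing shapes does not grow the shared projection parameters; each layer simply reads the submatrix it needs.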
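A second hypothetical sketch mirrors what the new test_vera_different_shapes asserts at the user level; the ToyMLP module, its layer shapes, and r=8 are made-up stand-ins, not the actual test fixture.

# Hypothetical usage mirroring test_vera_different_shapes; ToyMLP and its shapes are illustrative only.
import torch
import torch.nn as nn

from peft import VeraConfig, get_peft_model


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)  # largest output dimension among the targets
        self.relu = nn.ReLU()
        self.lin3 = nn.Linear(20, 2)   # largest input dimension among the targets

    def forward(self, x):
        return self.lin3(self.relu(self.lin0(x)))


config = VeraConfig(target_modules=["lin0", "lin3"], r=8, init_weights=False)
peft_model = get_peft_model(ToyMLP(), config)

# Shared projections take the largest dimensions seen across the targeted layers.
assert peft_model.vera_A["default"].shape == (config.r, 20)  # (rank, largest in_features)
assert peft_model.vera_B["default"].shape == (20, config.r)  # (largest out_features, rank)

# Forward pass through differently shaped adapted layers should not raise.
peft_model(torch.randn(5, 10))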