fix add_weighted_adapter method (#1169)
* fix `add_weighted_adapter` method

Co-Authored-By: Benjamin Bossan <[email protected]>
Co-Authored-By: jihuishan <[email protected]>

* Update testing_common.py

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: jihuishan <[email protected]>
3 people authored Nov 22, 2023
1 parent b4ac2d8 commit 0432385
Showing 2 changed files with 13 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import operator
 import re
 import warnings
@@ -517,8 +518,8 @@ def add_weighted_adapter(
                             current_adapter_lora_B = target.lora_embedding_B[adapter]
                         else:
                             continue
-                        target_lora_A.data += current_adapter_lora_A.data * weight * target.scaling[adapter]
-                        target_lora_B.data += current_adapter_lora_B.data
+                        target_lora_A.data += current_adapter_lora_A.data * math.sqrt(weight) * target.scaling[adapter]
+                        target_lora_B.data += current_adapter_lora_B.data * math.sqrt(weight)
                 elif combination_type == "cat":
                     loras_A, loras_B = [], []
                     for adapter, weight in zip(adapters, weights):
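Why the square root: the update a LoRA adapter contributes is scaling * (lora_B @ lora_A), so scaling both factors by sqrt(weight) makes their product carry the full weight. The previous code applied the weight only to lora_A and summed the lora_B factors unweighted, so even an adapter given weight 0 still leaked its lora_B into the merge. The sketch below checks this arithmetic on random tensors; the shapes and variable names are assumptions made for illustration, not code from this commit.

import math
import torch

r, d_in, d_out = 4, 16, 16
lora_A = torch.randn(r, d_in)   # down-projection factor
lora_B = torch.randn(d_out, r)  # up-projection factor
scaling, weight = 2.0, 0.3

# Effective update contributed by one adapter.
delta = scaling * (lora_B @ lora_A)

# Weighting both factors by sqrt(weight) makes the product carry the full weight.
merged_A = lora_A * math.sqrt(weight) * scaling
merged_B = lora_B * math.sqrt(weight)
assert torch.allclose(merged_B @ merged_A, weight * delta, atol=1e-5)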
10 changes: 10 additions & 0 deletions tests/testing_common.py
@@ -163,6 +163,7 @@ class PeftCommonTester:
         transformers_class (`transformers.PreTrainedModel`):
             The transformers class that is being tested.
     """
+
     torch_device = infer_device()
     transformers_class = None

@@ -1021,6 +1022,14 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
             adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear"
         )

+        # test linear re-weighting with multiple adapters with only first adapter having non zero weight
+        model.add_weighted_adapter(
+            adapter_list[:2],
+            [weight_list[0], 0],
+            "multi_adapter_linear_reweighting_single_enabled",
+            combination_type="linear",
+        )
+
         with self.assertRaises(ValueError):
             model.add_weighted_adapter(
                 adapter_list[1:],
@@ -1034,6 +1043,7 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
             "multi_adapter_svd_reweighting",
             "multi_adapter_cat_reweighting",
             "multi_adapter_linear_reweighting",
+            "multi_adapter_linear_reweighting_single_enabled",
         ]
         for new_adapter in new_adapters:
             self.assertTrue(new_adapter in model.peft_config)
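For context, a minimal usage sketch of the scenario the new test exercises: a linear merge of two LoRA adapters where only the first has a non-zero weight. The base model, target modules, and adapter names below are placeholders chosen for illustration and are not taken from the test suite.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
config = LoraConfig(r=8, target_modules=["c_attn"])

# Two LoRA adapters of the same rank (required for combination_type="linear").
model = get_peft_model(base, config, adapter_name="adapter_0")
model.add_adapter("adapter_1", config)

# Only the first adapter gets a non-zero weight; with the sqrt(weight) fix,
# adapter_1 contributes nothing to the merged lora_A and lora_B factors.
model.add_weighted_adapter(
    adapters=["adapter_0", "adapter_1"],
    weights=[0.7, 0.0],
    adapter_name="merged_single_enabled",
    combination_type="linear",
)
model.set_adapter("merged_single_enabled")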
