From 6da07d13c19d200393490372fc5bc4e155b47354 Mon Sep 17 00:00:00 2001 From: calpt Date: Thu, 12 Oct 2023 11:19:02 +0200 Subject: [PATCH] Fix bottleneck average composition computation (#590) --- src/adapters/layer.py | 4 ++-- tests_adapters/composition/test_adapter_composition.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/adapters/layer.py b/src/adapters/layer.py index 32b57915a5..99b2e151ac 100644 --- a/src/adapters/layer.py +++ b/src/adapters/layer.py @@ -631,8 +631,8 @@ def adapter_average_output(self, adapter_setup: Average, hidden_states, input_te ) # Case X: No adapter which is part of this module -> ignore - weights = torch.tensor(adapter_setup.weights).unsqueeze(1).unsqueeze(1).to(hidden_states.device) - hidden_states = torch.mean(torch.cat(children_hidden, 0) * weights, 0) + weights = torch.tensor(adapter_setup.weights)[:, None, None, None].to(hidden_states.device) + hidden_states = torch.mean(torch.stack(children_hidden, 0) * weights, 0) return hidden_states diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index d43a840552..ff30bd8a33 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -224,9 +224,9 @@ def test_average(self): model.set_active_adapters(Average("a", "b", "c", "d")) inputs = {} - inputs["input_ids"] = ids_tensor((1, 128), 1000) + inputs["input_ids"] = ids_tensor((2, 128), 1000) logits = model(**inputs).logits - self.assertEqual(logits.shape, (1, 2)) + self.assertEqual(logits.shape, (2, 2)) class PrefixTuningCompositionTest(AdapterCompositionTest):