From 5463e6822a6dd466338fac1b54207908ee2f67e7 Mon Sep 17 00:00:00 2001
From: JianyuWangV <134360816+JianyuWangV@users.noreply.github.com>
Date: Wed, 8 Jan 2025 14:05:53 -0800
Subject: [PATCH 1/2] Update lora input linear adapter output dim.

---
 axlearn/common/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/common/lora.py b/axlearn/common/lora.py
index 199cef603..fbd0689cb 100644
--- a/axlearn/common/lora.py
+++ b/axlearn/common/lora.py
@@ -501,7 +501,7 @@ def __init__(self, cfg: Config, *, parent: Module):
             "adapter",
             cfg.adapter.set(
                 input_dim=cfg.query_dim,
-                output_dim=cfg.query_dim,
+                output_dim=cfg.num_heads * cfg.per_head_dim,
                 num_heads=cfg.num_heads,
             ),
         )

From cfcb8fef9d8658c0b22211641372932b3a781d65 Mon Sep 17 00:00:00 2001
From: JianyuWangV <134360816+JianyuWangV@users.noreply.github.com>
Date: Tue, 14 Jan 2025 11:26:25 -0800
Subject: [PATCH 2/2] Update unit tests.

---
 axlearn/common/lora_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/axlearn/common/lora_test.py b/axlearn/common/lora_test.py
index 02cf95847..febf547f0 100644
--- a/axlearn/common/lora_test.py
+++ b/axlearn/common/lora_test.py
@@ -137,7 +137,7 @@ def test_alpha_is_zero(self):
 
 class LoraFusedQKVLinearTest(TestCase):
     def test_forward(self):
-        model_dim = 6
+        model_dim = 16
         num_heads = 2
         per_head_dim = 3
         seq_len = 4
@@ -197,7 +197,7 @@ def test_forward(self):
         ),
     )
     def test_extend_step(self, layer):
-        model_dim = 8
+        model_dim = 16
         num_heads = 2
         per_head_dim = 4  # change this to 4 to adapt the need of RoPE.
         seq_len = 4
@@ -267,7 +267,7 @@ def test_extend_step(self, layer):
     )
 
     def test_prefill_states(self):
-        model_dim = 6
+        model_dim = 16
         num_heads = 2
         per_head_dim = 3
         seq_len = 4
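
For context (not part of the patch): sizing the adapter's output to cfg.query_dim only happens to work when query_dim equals num_heads * per_head_dim, which is why the updated tests set model_dim = 16 against per-head output dims of 6 and 8 so that the two no longer coincide. Below is a minimal shape-arithmetic sketch, written in plain jax.numpy with made-up names (w_base, lora_a, lora_b) rather than axlearn's API, to illustrate why the LoRA up-projection must emit num_heads * per_head_dim features.

# Toy sketch (not axlearn code): why the LoRA adapter's output_dim must be
# num_heads * per_head_dim rather than query_dim. All names here are
# illustrative assumptions, not the library's API.
import jax
import jax.numpy as jnp

query_dim = 16            # model/input dim (matches the updated tests' model_dim = 16)
num_heads, per_head_dim = 2, 3
rank = 2                  # LoRA rank

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

# Base (frozen) projection: [query_dim] -> [num_heads, per_head_dim].
w_base = jax.random.normal(k1, (query_dim, num_heads, per_head_dim))

# LoRA factors: the down-projection reads query_dim features, and the
# up-projection must emit num_heads * per_head_dim features so the delta can
# be added to the base output. Sizing it to query_dim would only match when
# query_dim == num_heads * per_head_dim, which the new test dims rule out.
lora_a = jax.random.normal(k2, (query_dim, rank))
lora_b = jax.random.normal(k3, (rank, num_heads * per_head_dim))

x = jax.random.normal(k4, (4, query_dim))  # [seq_len, query_dim]

base_out = jnp.einsum("sd,dnh->snh", x, w_base)
delta = (x @ lora_a @ lora_b).reshape(-1, num_heads, per_head_dim)

assert base_out.shape == delta.shape == (4, num_heads, per_head_dim)
print((base_out + delta).shape)  # (4, 2, 3)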