From e3ad38ecc8c0e424d583ae79b19d6a633cd9e196 Mon Sep 17 00:00:00 2001 From: cadentj Date: Sun, 29 Dec 2024 19:25:33 -0500 Subject: [PATCH] Fixed batch tokenization for 2+ invokes --- src/nnsight/modeling/language.py | 2 +- tests/test_lm.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/nnsight/modeling/language.py b/src/nnsight/modeling/language.py index b082bfd7..3bb37ebf 100755 --- a/src/nnsight/modeling/language.py +++ b/src/nnsight/modeling/language.py @@ -276,7 +276,7 @@ def _batch( batched_labels = torch.cat((batched_labels, labels)) - batched_inputs["attention_mask"][:1, : attention_mask.shape[1]] = attention_mask + batched_inputs["attention_mask"][:-1, : attention_mask.shape[1]] = attention_mask return ((batched_inputs,), {"labels": batched_labels}) diff --git a/tests/test_lm.py b/tests/test_lm.py index 347e349b..5d406ba8 100755 --- a/tests/test_lm.py +++ b/tests/test_lm.py @@ -72,6 +72,19 @@ def test_generation(gpt2: nnsight.LanguageModel, MSG_prompt: str): ) +@torch.no_grad() +def test_invoke(gpt2: nnsight.LanguageModel, MSG_prompt: str): + hidden_states = [] + with gpt2.trace(validate=True, backend=AssertSavedLenBackend(1)) as tracer: + for _ in range(3): + with tracer.invoke(MSG_prompt, scan=True) as invoker: + hs = gpt2.transformer.h[-1].output[0].save() + hidden_states.append(hs) + _test_serialize(tracer) + + assert any(hs is not None for hs in hidden_states) + + @torch.no_grad() def test_save(gpt2: nnsight.LanguageModel): with gpt2.generate("Hello world", validate=True, scan=True, backend=AssertSavedLenBackend(2)) as tracer: