From 56b9ceed88242e2691923a5dbd20262f219c78ab Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 21 Jan 2025 15:08:57 +0100 Subject: [PATCH] remove class from tests --- tests/test_modeling_common.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cf259fabe302..3914669a6c69 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4628,33 +4628,13 @@ def test_flash_attn_2_from_config(self): dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) - fa2_correctly_converted = False - - for _, module in fa2_model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - fa2_correctly_converted = True - break - - self.assertTrue(fa2_correctly_converted) - _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) with tempfile.TemporaryDirectory() as tmpdirname: fa2_model.save_pretrained(tmpdirname) - model_from_pretrained = model_class.from_pretrained(tmpdirname) - self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2") - fa2_correctly_converted = False - - for _, module in model_from_pretrained.named_modules(): - if "FlashAttention" in module.__class__.__name__: - fa2_correctly_converted = True - break - - self.assertFalse(fa2_correctly_converted) - def _get_custom_4d_mask_test_data(self): # Sequence in which all but the last token is the same input_ids = torch.tensor(