From 7b1a0906a48210ebe575275742053cf756bb795a Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 21 Mar 2024 14:06:54 +0000
Subject: [PATCH] update tests for dtype checks

---
 .../transformers/compression/test_bitmask.py |  9 ++++---
 .../compression/test_sparse_auto.py          | 24 ++++++++++++-------
 2 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/tests/sparseml/transformers/compression/test_bitmask.py b/tests/sparseml/transformers/compression/test_bitmask.py
index 84b5cc6eb8c..40d683cb468 100644
--- a/tests/sparseml/transformers/compression/test_bitmask.py
+++ b/tests/sparseml/transformers/compression/test_bitmask.py
@@ -73,7 +73,7 @@ def test_match(shape, sparsity, dtype):
     mask = (test_tensor1.abs() < (1 - sparsity)).int()
     test_tensor1 *= mask
 
-    test_tensor2 = torch.rand(shape, dtype=torch.float32)
+    test_tensor2 = torch.rand(shape, dtype=dtype)
     mask = (test_tensor2.abs() < (1 - sparsity)).int()
     test_tensor2 *= mask
 
@@ -82,7 +82,9 @@ def test_match(shape, sparsity, dtype):
     for key in dense_state_dict.keys():
         dense_tensor = dense_state_dict[key]
         sparse_tensor = BitmaskTensor.from_dense(dense_tensor)
-        assert torch.equal(dense_tensor, sparse_tensor.decompress())
+        decompressed = sparse_tensor.decompress()
+        assert decompressed.dtype == dense_tensor.dtype == dtype
+        assert torch.equal(dense_tensor, decompressed)
 
 
 @pytest.mark.parametrize(
@@ -99,7 +101,7 @@ def test_reload_match(sparsity, dtype, tmp_path):
     mask = (test_tensor1.abs() < (1 - sparsity)).int()
     test_tensor1 *= mask
 
-    test_tensor2 = torch.rand((360, 720), dtype=torch.float32)
+    test_tensor2 = torch.rand((360, 720), dtype=dtype)
     mask = (test_tensor2.abs() < (1 - sparsity)).int()
     test_tensor2 *= mask
 
@@ -114,6 +116,7 @@
 
     for key, reconstructed_tensor in reconstructed_dense:
         dense_tensor = dense_state_dict[key]
+        assert dense_tensor.dtype == reconstructed_tensor.dtype == dtype
         assert torch.equal(dense_tensor, reconstructed_tensor)
 
     shutil.rmtree(tmp_path)
diff --git a/tests/sparseml/transformers/compression/test_sparse_auto.py b/tests/sparseml/transformers/compression/test_sparse_auto.py
index 48b72a16a6c..b88881e0dff 100644
--- a/tests/sparseml/transformers/compression/test_sparse_auto.py
+++ b/tests/sparseml/transformers/compression/test_sparse_auto.py
@@ -30,15 +30,15 @@
 
 
 @pytest.mark.parametrize(
-    "compressed,config",
+    "compressed,config,dtype",
     [
-        [True, None],
-        [False, DenseSparsityConfig()],
-        [True, BitmaskConfig()],
-        [False, BitmaskConfig()],
+        [True, None, torch.float32],
+        [False, DenseSparsityConfig(), torch.float16],
+        [True, BitmaskConfig(), torch.bfloat16],
+        [False, BitmaskConfig(), torch.float32],
     ],
 )
-def test_sparse_model_reload(compressed, config, tmp_path):
+def test_sparse_model_reload(compressed, config, dtype, tmp_path):
     recipe_str = "tests/sparseml/transformers/obcq/test_tiny2.yaml"
     model_path = "Xenova/llama2.c-stories15M"
     device = "cuda:0"
@@ -60,9 +60,12 @@
         concatenate_data=concatenate_data,
         splits=splits,
         oneshot_device=device,
+        precision=dtype,
     )
 
-    model = SparseAutoModelForCausalLM.from_pretrained(tmp_path / "oneshot_out")
+    model = SparseAutoModelForCausalLM.from_pretrained(
+        tmp_path / "oneshot_out", torch_dtype=dtype
+    )
 
     inferred_global_sparsity = CompressionConfig.infer_global_sparsity(model)
     assert math.isclose(inferred_global_sparsity, 19.6562, rel_tol=1e-3)
@@ -85,7 +88,9 @@
     assert sparsity_config["global_sparsity"] == inferred_global_sparsity
     assert sparsity_config["sparsity_structure"] == inferred_structure
 
-    dense_model = SparseAutoModelForCausalLM.from_pretrained(tmp_path / "compress_out")
+    dense_model = SparseAutoModelForCausalLM.from_pretrained(
+        tmp_path / "compress_out", torch_dtype="auto"
+    )
 
     og_state_dict = model.state_dict()
     reconstructed_state_dict = dense_model.state_dict()
@@ -93,7 +98,8 @@
     for key in og_state_dict.keys():
         dense_tensor = og_state_dict[key]
         reconstructed_tensor = reconstructed_state_dict[key]
-        assert torch.equal(dense_tensor.cpu(), reconstructed_tensor.cpu())
+        assert dense_tensor.dtype == reconstructed_tensor.dtype == dtype
+        assert torch.equal(dense_tensor, reconstructed_tensor)
 
     shutil.rmtree(tmp_path)
 