update tests for dtype checks
Sara Adkins committed Mar 21, 2024
1 parent c7c8b92 commit 7b1a090
Showing 2 changed files with 21 additions and 12 deletions.
tests/sparseml/transformers/compression/test_bitmask.py (6 additions, 3 deletions)
@@ -73,7 +73,7 @@ def test_match(shape, sparsity, dtype):
     mask = (test_tensor1.abs() < (1 - sparsity)).int()
     test_tensor1 *= mask

-    test_tensor2 = torch.rand(shape, dtype=torch.float32)
+    test_tensor2 = torch.rand(shape, dtype=dtype)
     mask = (test_tensor2.abs() < (1 - sparsity)).int()
     test_tensor2 *= mask

@@ -82,7 +82,9 @@ def test_match(shape, sparsity, dtype):
     for key in dense_state_dict.keys():
         dense_tensor = dense_state_dict[key]
         sparse_tensor = BitmaskTensor.from_dense(dense_tensor)
-        assert torch.equal(dense_tensor, sparse_tensor.decompress())
+        decompressed = sparse_tensor.decompress()
+        assert decompressed.dtype == dense_tensor.dtype == dtype
+        assert torch.equal(dense_tensor, decompressed)


 @pytest.mark.parametrize(
@@ -99,7 +101,7 @@ def test_reload_match(sparsity, dtype, tmp_path):
     mask = (test_tensor1.abs() < (1 - sparsity)).int()
     test_tensor1 *= mask

-    test_tensor2 = torch.rand((360, 720), dtype=torch.float32)
+    test_tensor2 = torch.rand((360, 720), dtype=dtype)
     mask = (test_tensor2.abs() < (1 - sparsity)).int()
     test_tensor2 *= mask

@@ -114,6 +116,7 @@ def test_reload_match(sparsity, dtype, tmp_path):

     for key, reconstructed_tensor in reconstructed_dense:
         dense_tensor = dense_state_dict[key]
+        assert dense_tensor.dtype == reconstructed_tensor.dtype == dtype
         assert torch.equal(dense_tensor, reconstructed_tensor)

     shutil.rmtree(tmp_path)
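
For context, the updated assertions above amount to checking that a bitmask compress/decompress round trip preserves both the values and the dtype of the dense tensor. A minimal standalone sketch of that pattern, assuming the BitmaskTensor API exercised in the test (the import path, shape, 0.5 sparsity cutoff, and float16 dtype are illustrative):

    import torch

    # Import path assumed; match it to the test module's own imports.
    from sparseml.transformers.compression import BitmaskTensor

    dtype = torch.float16
    dense = torch.rand((64, 64), dtype=dtype)
    dense *= (dense.abs() < 0.5).int()  # zero out roughly half the entries

    compressed = BitmaskTensor.from_dense(dense)
    decompressed = compressed.decompress()

    # The round trip should return the exact same values in the same dtype.
    assert decompressed.dtype == dense.dtype == dtype
    assert torch.equal(dense, decompressed)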
tests/sparseml/transformers/compression/test_sparse_auto.py (15 additions, 9 deletions)
@@ -30,15 +30,15 @@


 @pytest.mark.parametrize(
-    "compressed,config",
+    "compressed,config,dtype",
     [
-        [True, None],
-        [False, DenseSparsityConfig()],
-        [True, BitmaskConfig()],
-        [False, BitmaskConfig()],
+        [True, None, torch.float32],
+        [False, DenseSparsityConfig(), torch.float16],
+        [True, BitmaskConfig(), torch.bfloat16],
+        [False, BitmaskConfig(), torch.float32],
     ],
 )
-def test_sparse_model_reload(compressed, config, tmp_path):
+def test_sparse_model_reload(compressed, config, dtype, tmp_path):
     recipe_str = "tests/sparseml/transformers/obcq/test_tiny2.yaml"
     model_path = "Xenova/llama2.c-stories15M"
     device = "cuda:0"
@@ -60,9 +60,12 @@ def test_sparse_model_reload(compressed, config, tmp_path):
         concatenate_data=concatenate_data,
         splits=splits,
         oneshot_device=device,
+        precision=dtype,
     )

-    model = SparseAutoModelForCausalLM.from_pretrained(tmp_path / "oneshot_out")
+    model = SparseAutoModelForCausalLM.from_pretrained(
+        tmp_path / "oneshot_out", torch_dtype=dtype
+    )

     inferred_global_sparsity = CompressionConfig.infer_global_sparsity(model)
     assert math.isclose(inferred_global_sparsity, 19.6562, rel_tol=1e-3)
@@ -85,15 +88,18 @@ def test_sparse_model_reload(compressed, config, tmp_path):
     assert sparsity_config["global_sparsity"] == inferred_global_sparsity
     assert sparsity_config["sparsity_structure"] == inferred_structure

-    dense_model = SparseAutoModelForCausalLM.from_pretrained(tmp_path / "compress_out")
+    dense_model = SparseAutoModelForCausalLM.from_pretrained(
+        tmp_path / "compress_out", torch_dtype="auto"
+    )

     og_state_dict = model.state_dict()
     reconstructed_state_dict = dense_model.state_dict()
     assert len(og_state_dict) == len(reconstructed_state_dict)
     for key in og_state_dict.keys():
         dense_tensor = og_state_dict[key]
         reconstructed_tensor = reconstructed_state_dict[key]
-        assert torch.equal(dense_tensor.cpu(), reconstructed_tensor.cpu())
+        assert dense_tensor.dtype == reconstructed_tensor.dtype == dtype
+        assert torch.equal(dense_tensor, reconstructed_tensor)

     shutil.rmtree(tmp_path)

