diff --git a/src/sparseml/transformers/utils/transformations.py b/src/sparseml/transformers/utils/transformations.py index aceff72346d..a614ee52bf0 100644 --- a/src/sparseml/transformers/utils/transformations.py +++ b/src/sparseml/transformers/utils/transformations.py @@ -234,20 +234,3 @@ def remove_unwanted_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: if any(key.endswith(suffix) for suffix in to_delete): del state_dict[key] return state_dict - - -def check_dicts(actual, expected): - assert len(actual) == len( - expected - ), "The number of tensors in the actual and expected state dicts do not match" - - for key, value in actual.items(): - assert ( - key in expected - ), f"The key {key} is not present in the expected state dict" - assert ( - value.shape == expected[key].shape - ), f"The shape of the tensor {key} in the actual state dict does not match the shape of the tensor in the expected state dict, expected {expected[key].shape} but got {value.shape}" - assert ( - value.dtype == expected[key].dtype - ), f"The dtype of the tensor {key} in the actual state dict does not match the dtype of the tensor in the expected state dict, expected {expected[key].dtype} but got {value.dtype}"