diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 80b0cc9..77b0bb3 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -74,9 +74,20 @@
     def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
-    def quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
-        """Used to perform static_quantization, useful for inference where weights are not updated."""
+    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
+        """This function converts the weight to a Float8Tensor and sets its requires_grad to False.
+
+        Args:
+            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.
+
+        Note:
+            This function is typically called during inference to quantize the weight once,
+            since the weight is not updated during inference.
+        """
+        assert not isinstance(
+            self.weight, Float8Tensor
+        ), "Weight has already been quantized, cannot quantize again."
         scale = tensor_to_scale(self.weight, dtype)
         quantized_weight = to_fp8_no_autograd(
             self.weight,
diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py
index 56856ea..42f5799 100644
--- a/test/test_inference_flows.py
+++ b/test/test_inference_flows.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import io
 import random
 import unittest
 
@@ -123,5 +124,80 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
         )
 
 
+class TestFP8TrainToFP8:
+    def train(self, model: nn.Module, dtype: torch.dtype):
+        model.train()
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        criterion = nn.MSELoss()
+        target_tensor = torch.randn(4, 4096, device="cuda", dtype=dtype)
+        for _ in range(10):
+            input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+            optimizer.zero_grad()
+            output = model(input_tensor)
+            loss = criterion(output, target_tensor)
+            loss.backward()
+            optimizer.step()
+        model.eval()
+        return model
+
+    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA not available or on non H100 machine",
+    )
+    def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
+        # Initialize FP8 model
+        fp8_mlp = FeedForward().to("cuda", dtype=torch.float32)
+        fp8_mlp.reset_parameters()
+        swap_linear_with_float8_linear(
+            fp8_mlp,
+            Float8DynamicLinear,
+        )
+
+        # Train the model
+        self.train(fp8_mlp, dtype)
+
+        # Generate input tensor and the original output
+        input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+        og_out = fp8_mlp(input_tensor)
+
+        # Save the model state dict
+        buffer = io.BytesIO()
+        torch.save(fp8_mlp.state_dict(), buffer)
+
+        # Reset buffer position to the beginning
+        buffer.seek(0)
+
+        # Later on, load the model; it is built with Float8DynamicLinear on the meta device
+        with torch.device("meta"):
+            new_fp8_mlp = FeedForward().to(dtype=dtype)
+            swap_linear_with_float8_linear(
+                new_fp8_mlp,
+                Float8DynamicLinear,
+            )
+
+        # Load the actual data
+        new_fp8_mlp.load_state_dict(torch.load(buffer), strict=True, assign=True)
+
+        # Dynamic activations + quantized weights
+        def quantize_dynamic_linear(x: nn.Module):
+            if isinstance(x, Float8DynamicLinear):
+                x.quantize_weight()
+            return x
+
+        new_fp8_mlp.apply(quantize_dynamic_linear)
+
+        for module in new_fp8_mlp.modules():
+            if isinstance(module, Float8DynamicLinear):
+                assert isinstance(module.weight, Float8Tensor)
+                assert module.weight.requires_grad is False
+
+        new_out = new_fp8_mlp(input_tensor)
+
+        # Assert exact equality
+        assert torch.all(og_out == new_out).item()
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
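Not part of the patch: a minimal usage sketch of the inference flow the new quantize_weight API enables. Only Float8DynamicLinear, swap_linear_with_float8_linear, and quantize_weight come from this repo/diff; the import paths, stand-in model, and shapes are assumptions for illustration.

import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Stand-in model; any module tree containing nn.Linear layers works.
# Requires a GPU with float8 support (e.g., H100).
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).to("cuda")

# Swap nn.Linear -> Float8DynamicLinear (dynamic activation scaling).
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# ... training, or loading a trained state dict, would happen here ...

# Quantize the weights once for inference; each weight becomes a Float8Tensor
# with requires_grad=False, while activations are still quantized dynamically
# on every forward pass.
model.eval()
for module in model.modules():
    if isinstance(module, Float8DynamicLinear):
        module.quantize_weight()  # defaults to e4m3_dtype per the patch

with torch.no_grad():
    out = model(torch.randn(4, 1024, device="cuda"))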