
Commit

added fp8 to fp8 training flow tests
drisspg committed Jun 22, 2024
1 parent 7983b78 commit 0bd148f
Showing 2 changed files with 85 additions and 2 deletions.
15 changes: 13 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -74,9 +74,20 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y

-    def quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
-        """Used to perform static_quantization, useful for inference where weights are not updated."""
+    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
+        """This function converts the weight to a Float8Tensor and sets its requires_grad to False.
+
+        Args:
+            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.
+
+        Note:
+            This function is typically called during inference to quantize the weight once since
+            the weight is not updated during inference.
+        """
+        assert not isinstance(
+            self.weight, Float8Tensor
+        ), "Weight has already been quantized, cannot quantize again."
         scale = tensor_to_scale(self.weight, dtype)
         quantized_weight = to_fp8_no_autograd(
             self.weight,
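The new docstring spells out the intended inference-time flow: swap nn.Linear layers for Float8DynamicLinear, then call quantize_weight once so the weight is stored as a Float8Tensor while activations keep being cast dynamically per forward pass. A minimal sketch of that flow outside the test harness follows; the toy model and the import path for the swap utility are assumptions for illustration, not part of this commit.

    import torch
    import torch.nn as nn

    from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
    # Import path assumed; the tests below use swap_linear_with_float8_linear directly.
    from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

    # Toy model for illustration; any module containing nn.Linear layers would do.
    model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
    model = model.to("cuda", dtype=torch.bfloat16)

    # Replace nn.Linear layers with Float8DynamicLinear (dynamic activation scaling).
    swap_linear_with_float8_linear(model, Float8DynamicLinear)

    # Quantize each weight once for inference; it is not updated afterwards.
    model.eval()
    for module in model.modules():
        if isinstance(module, Float8DynamicLinear):
            module.quantize_weight()

    with torch.no_grad():
        out = model(torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16))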
72 changes: 72 additions & 0 deletions test/test_inference_flows.py
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import io
import random
import unittest

@@ -123,5 +124,76 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
)


class TestFP8TrainToFP8:
    def train(self, model: nn.Module, dtype: torch.dtype):
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.MSELoss()
        # Target matches the model's output shape for the (4, 1024, 4096) inputs below.
        target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
        for _ in range(10):
            input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
            optimizer.zero_grad()
            output = model(input_tensor)
            loss = criterion(output, target_tensor)
            loss.backward()
            optimizer.step()
        model.eval()
        return model

@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA not available or on non H100 machine",
)
    def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
        # Initialize FP8 model
        fp8_mlp = FeedForward().to("cuda", dtype=dtype)
        fp8_mlp.reset_parameters()
        swap_linear_with_float8_linear(
            fp8_mlp,
            Float8DynamicLinear,
        )

        # Train the model
        self.train(fp8_mlp, dtype)

        # Generate input tensor and original out
        input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
        og_out = fp8_mlp(input_tensor)

        # Save model state dict
        buffer = io.BytesIO()
        torch.save(fp8_mlp.state_dict(), buffer)

        # Reset buffer position to the beginning
        buffer.seek(0)

        # Later on, load the model; the new copy is built with Float8DynamicLinear on the meta device
        with torch.device("meta"):
            new_fp8_mlp = FeedForward().to(dtype=dtype)
            swap_linear_with_float8_linear(
                new_fp8_mlp,
                Float8DynamicLinear,
            )

        # Load the actual data
        new_fp8_mlp.load_state_dict(torch.load(buffer), strict=True, assign=True)

        # Dynamic Activations + Quantized Weights
        def quantize_dynamic_linear(x: nn.Module):
            if isinstance(x, Float8DynamicLinear):
                x.quantize_weight()
            return x

        new_fp8_mlp.apply(quantize_dynamic_linear)

        for module in new_fp8_mlp.modules():
            if isinstance(module, Float8DynamicLinear):
                assert isinstance(module.weight, Float8Tensor)
                assert module.weight.requires_grad is False

        new_out = new_fp8_mlp(input_tensor)

        # Assert exact equality
        assert torch.all(og_out == new_out).item()


if __name__ == "__main__":
    pytest.main([__file__])
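A small usage note, not part of the diff: because the module ends with a pytest.main entry point, the new train-to-inference tests can be run in isolation with a keyword filter (assuming an H100 machine, per the skip condition above):

    pytest.main([__file__, "-k", "TestFP8TrainToFP8"])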
