Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
[2/x] clean up casting functions: delayed scaling
Browse files Browse the repository at this point in the history
Summary:

Removes delayed scaling from `float8_tensor.py`. After this PR, the
invariant is that everything in `float8_tensor.py` requires the scale to
be calculated elsewhere. This moves the codebase towards separation of
concerns for calculating the scale (via various scaling strategies),
separated from creating an instance of `Float8Tensor`.

Note that stateful delayed scaling is the reason we need this separation.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ad12f968339876f7e692a7a3ee84bfb765ccc463
Pull Request resolved: #340
  • Loading branch information
vkuzo committed Jul 26, 2024
1 parent 96162b3 commit 7d40bc6
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 31 deletions.
2 changes: 0 additions & 2 deletions benchmarks/bench_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,13 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
A,
scale_a,
fp8_dtype,
None, # amax_buffer
a_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
B,
scale_b,
fp8_dtype,
None, # amax_buffer
b_config,
GemmInputRole.WEIGHT,
)
Expand Down
3 changes: 1 addition & 2 deletions float8_experimental/float8_scaling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def cast_to_float8_e4m3_dynamic(
inpt_tensor,
scale,
e4m3_dtype,
None, # amax_buffer
linear_mm_config,
gemm_input_role,
)
Expand All @@ -59,11 +58,11 @@ def cast_to_float8_delayed(
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
amax_buffer.fill_(tensor_to_amax(tensor))
return ToFloat8ConstrFunc.apply(
tensor,
scale,
float8_dtype,
amax_buffer,
linear_mm_config,
gemm_input_role,
)
Expand Down
6 changes: 3 additions & 3 deletions float8_experimental/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def forward(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype=e4m3_dtype,
amax_buffer: Optional[torch.Tensor] = None,
# amax_buffer: Optional[torch.Tensor] = None,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
Expand All @@ -219,8 +219,8 @@ def forward(
amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
emulate: whether to emulate the matmuls in fp32
"""
if amax_buffer is not None:
amax_buffer.fill_(tensor_to_amax(tensor))
# if amax_buffer is not None:
# amax_buffer.fill_(tensor_to_amax(tensor))

return to_fp8_no_autograd(
tensor,
Expand Down
13 changes: 5 additions & 8 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import torch
import torch.nn as nn
import torch.utils._pytree as pytree
from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_scaling_utils import (
cast_to_float8_delayed,
cast_to_float8_e4m3_dynamic,
)

from float8_experimental.float8_tensor import (
Float8Tensor,
Expand Down Expand Up @@ -168,7 +171,6 @@ def fsdp_pre_all_gather(self, mesh):
self._tensor,
self._precomputed_scale,
torch.float8_e4m3fn,
None, # amax_buffer
self._linear_mm_config,
GemmInputRole.WEIGHT,
)
Expand Down Expand Up @@ -352,12 +354,7 @@ def fsdp_pre_all_gather(self, mesh):
)
self.is_amax_initialized = True

# this will:
# 1. cast the tensor to float8 using `_scale_buffer`
# 2. populate `_amax_buffer` inplace
# TODO(future PR): clean up all the casting functions and clearly
# separate dynamic vs delayed, tech debt has accumulated
float8_tensor = ToFloat8ConstrFunc.apply(
float8_tensor = cast_to_float8_delayed(
self._tensor,
self._scale_buffer,
e4m3_dtype,
Expand Down
2 changes: 0 additions & 2 deletions float8_experimental/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
self.weight,
scale,
dtype,
None, # amax_buffer
self.linear_mm_config,
GemmInputRole.WEIGHT,
)
Expand Down Expand Up @@ -205,7 +204,6 @@ def cast_to_float8_e4m3_inference(
inpt_tensor,
scale,
e4m3_dtype,
None, # amax_buffer
linear_mm_config,
GemmInputRole.INPUT,
)
Expand Down
10 changes: 2 additions & 8 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,15 +451,13 @@ def test_different_configs_error(self):
x_fp32,
x_scale,
fp8_dtype,
None, # amax_buffer
linear_config_a,
GemmInputRole.INPUT,
)
b = ToFloat8ConstrFunc.apply(
x_fp32,
x_scale,
fp8_dtype,
None, # amax_buffer
linear_config_b,
GemmInputRole.WEIGHT,
)
Expand Down Expand Up @@ -489,10 +487,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = ToFloat8ConstrFunc.apply(
a, a_scale, input_dtype, None, None, GemmInputRole.INPUT
a, a_scale, input_dtype, None, GemmInputRole.INPUT
)
b_fp8 = ToFloat8ConstrFunc.apply(
b, b_scale, input_dtype, None, None, GemmInputRole.WEIGHT
b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
)

with pytest.raises(
Expand All @@ -512,15 +510,13 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a,
a_scale,
input_dtype,
None, # amax_buffer
pad_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b,
b_scale,
input_dtype,
None, # amax_buffer
pad_config,
GemmInputRole.WEIGHT,
)
Expand All @@ -537,15 +533,13 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a,
a_scale,
input_dtype,
None, # amax_buffer
emulated_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b,
b_scale,
input_dtype,
None, # amax_buffer
emulated_config,
GemmInputRole.WEIGHT,
)
Expand Down
5 changes: 3 additions & 2 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
get_float8_layers,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import LinearMMConfig, ToFloat8ConstrFunc
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype
from float8_experimental.float8_scaling_utils import cast_to_float8_delayed

from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend
Expand Down Expand Up @@ -178,7 +179,7 @@ def __init__(self, graph_break: bool):
self.graph_break = graph_break

def forward(self, x):
x_fp8 = ToFloat8ConstrFunc.apply(
x_fp8 = cast_to_float8_delayed(
x,
self.fp8_scale_x,
e4m3_dtype,
Expand Down
6 changes: 2 additions & 4 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
y_scale = tensor_to_scale(y_fp32, fp8_dtype).float()

x_fp8 = ToFloat8ConstrFunc.apply(
x_fp32, x_scale, fp8_dtype, None, None, GemmInputRole.INPUT
x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT
)
y_fp8 = ToFloat8ConstrFunc.apply(
y_fp32, y_scale, fp8_dtype, None, None, GemmInputRole.WEIGHT
y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT
)

dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
Expand Down Expand Up @@ -169,15 +169,13 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
dist_x_scale,
fp8_dtype,
None,
None,
GemmInputRole.INPUT,
)
dist_weight_fp8 = ToFloat8ConstrFunc.apply(
dist_wight_fp32,
dist_weight_scale,
fp8_dtype,
None,
None,
GemmInputRole.WEIGHT,
)

Expand Down

0 comments on commit 7d40bc6

Please sign in to comment.