diff --git a/benchmarks/bench_padding.py b/benchmarks/bench_padding.py index e02d2d5..16463d5 100644 --- a/benchmarks/bench_padding.py +++ b/benchmarks/bench_padding.py @@ -62,7 +62,6 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype): A, scale_a, fp8_dtype, - None, # amax_buffer a_config, GemmInputRole.INPUT, ) @@ -70,7 +69,6 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype): B, scale_b, fp8_dtype, - None, # amax_buffer b_config, GemmInputRole.WEIGHT, ) diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py index 81910e2..c1a58fb 100644 --- a/float8_experimental/float8_scaling_utils.py +++ b/float8_experimental/float8_scaling_utils.py @@ -44,7 +44,6 @@ def cast_to_float8_e4m3_dynamic( inpt_tensor, scale, e4m3_dtype, - None, # amax_buffer linear_mm_config, gemm_input_role, ) @@ -59,11 +58,11 @@ def cast_to_float8_delayed( linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): + amax_buffer.fill_(tensor_to_amax(tensor)) return ToFloat8ConstrFunc.apply( tensor, scale, float8_dtype, - amax_buffer, linear_mm_config, gemm_input_role, ) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index fd37482..29a965a 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -207,7 +207,7 @@ def forward( tensor: torch.Tensor, scale: torch.Tensor, float8_dtype=e4m3_dtype, - amax_buffer: Optional[torch.Tensor] = None, + # amax_buffer: Optional[torch.Tensor] = None, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): @@ -219,8 +219,8 @@ def forward( amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion emulate: whether to emulate the matmuls in fp32 """ - if amax_buffer is not None: - amax_buffer.fill_(tensor_to_amax(tensor)) + # if amax_buffer is not None: + # amax_buffer.fill_(tensor_to_amax(tensor)) return to_fp8_no_autograd( tensor, diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index ca0812c..1cb4788 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -10,7 +10,10 @@ import torch import torch.nn as nn import torch.utils._pytree as pytree -from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic +from float8_experimental.float8_scaling_utils import ( + cast_to_float8_delayed, + cast_to_float8_e4m3_dynamic, +) from float8_experimental.float8_tensor import ( Float8Tensor, @@ -168,7 +171,6 @@ def fsdp_pre_all_gather(self, mesh): self._tensor, self._precomputed_scale, torch.float8_e4m3fn, - None, # amax_buffer self._linear_mm_config, GemmInputRole.WEIGHT, ) @@ -352,12 +354,7 @@ def fsdp_pre_all_gather(self, mesh): ) self.is_amax_initialized = True - # this will: - # 1. cast the tensor to float8 using `_scale_buffer` - # 2. populate `_amax_buffer` inplace - # TODO(future PR): clean up all the casting functions and clearly - # separate dynamic vs delayed, tech debt has accumulated - float8_tensor = ToFloat8ConstrFunc.apply( + float8_tensor = cast_to_float8_delayed( self._tensor, self._scale_buffer, e4m3_dtype, diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 4807ac6..717695f 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -131,7 +131,6 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: self.weight, scale, dtype, - None, # amax_buffer self.linear_mm_config, GemmInputRole.WEIGHT, ) @@ -205,7 +204,6 @@ def cast_to_float8_e4m3_inference( inpt_tensor, scale, e4m3_dtype, - None, # amax_buffer linear_mm_config, GemmInputRole.INPUT, ) diff --git a/test/test_base.py b/test/test_base.py index b1f4781..94baec3 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -451,7 +451,6 @@ def test_different_configs_error(self): x_fp32, x_scale, fp8_dtype, - None, # amax_buffer linear_config_a, GemmInputRole.INPUT, ) @@ -459,7 +458,6 @@ def test_different_configs_error(self): x_fp32, x_scale, fp8_dtype, - None, # amax_buffer linear_config_b, GemmInputRole.WEIGHT, ) @@ -489,10 +487,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): b_scale = tensor_to_scale(b, input_dtype).float() a_fp8 = ToFloat8ConstrFunc.apply( - a, a_scale, input_dtype, None, None, GemmInputRole.INPUT + a, a_scale, input_dtype, None, GemmInputRole.INPUT ) b_fp8 = ToFloat8ConstrFunc.apply( - b, b_scale, input_dtype, None, None, GemmInputRole.WEIGHT + b, b_scale, input_dtype, None, GemmInputRole.WEIGHT ) with pytest.raises( @@ -512,7 +510,6 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a, a_scale, input_dtype, - None, # amax_buffer pad_config, GemmInputRole.INPUT, ) @@ -520,7 +517,6 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): b, b_scale, input_dtype, - None, # amax_buffer pad_config, GemmInputRole.WEIGHT, ) @@ -537,7 +533,6 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a, a_scale, input_dtype, - None, # amax_buffer emulated_config, GemmInputRole.INPUT, ) @@ -545,7 +540,6 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): b, b_scale, input_dtype, - None, # amax_buffer emulated_config, GemmInputRole.WEIGHT, ) diff --git a/test/test_compile.py b/test/test_compile.py index 7f6ab68..05bb57a 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -20,8 +20,9 @@ get_float8_layers, sync_float8_amax_and_scale_history, ) -from float8_experimental.float8_tensor import LinearMMConfig, ToFloat8ConstrFunc +from float8_experimental.float8_tensor import LinearMMConfig from float8_experimental.float8_utils import e4m3_dtype +from float8_experimental.float8_scaling_utils import cast_to_float8_delayed from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend @@ -178,7 +179,7 @@ def __init__(self, graph_break: bool): self.graph_break = graph_break def forward(self, x): - x_fp8 = ToFloat8ConstrFunc.apply( + x_fp8 = cast_to_float8_delayed( x, self.fp8_scale_x, e4m3_dtype, diff --git a/test/test_dtensor.py b/test/test_dtensor.py index cfa0445..4d56a0d 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -88,10 +88,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() x_fp8 = ToFloat8ConstrFunc.apply( - x_fp32, x_scale, fp8_dtype, None, None, GemmInputRole.INPUT + x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT ) y_fp8 = ToFloat8ConstrFunc.apply( - y_fp32, y_scale, fp8_dtype, None, None, GemmInputRole.WEIGHT + y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT ) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) @@ -169,7 +169,6 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_x_scale, fp8_dtype, None, - None, GemmInputRole.INPUT, ) dist_weight_fp8 = ToFloat8ConstrFunc.apply( @@ -177,7 +176,6 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_weight_scale, fp8_dtype, None, - None, GemmInputRole.WEIGHT, )