From c58fb5d6ac768f2213e7f65123cfff779bef9d87 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 19 Jul 2024 15:26:01 -0700 Subject: [PATCH] make all 3 gemms in Float8Linear support configurability, not user facing (#315) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/315 This PR adds some plumbing for how to eventually make all 3 gemms in a linear fwd/bwd configurable: 1. add `LinearMMConfig` to `Float8Tensor` to tie together the three `ScaledMMConfig` objects, one per gemm 2. add `GemmInputRole` to `Float8Tensor` to specify how to pick the right config 3. plumb all of these throughout the codebase Note that none of this is user facing, and there is no logic change. Planned follow-ups: * a future PR will make the per-gemm behavior configurable in a user facing way, which will hook up to the objects introduced in this PR * a future PR will update the naming from x/w/dL_dY to input/weight/grad_output throughout the codebase Reviewed By: drisspg Differential Revision: D59973551 fbshipit-source-id: c667245449628b377e9bb20dda6a76fbf8a5ef3c --- float8_experimental/__init__.py | 9 +- float8_experimental/float8_dynamic_utils.py | 30 +++- float8_experimental/float8_linear.py | 50 ++++-- float8_experimental/float8_ops.py | 103 +++++++---- float8_experimental/float8_tensor.py | 160 ++++++++++++++---- float8_experimental/float8_tensor_parallel.py | 31 ++-- float8_experimental/fsdp_utils.py | 69 +++++--- float8_experimental/inference.py | 28 ++- test/test_base.py | 105 ++++++++---- test/test_compile.py | 8 +- test/test_dtensor.py | 22 ++- 11 files changed, 430 insertions(+), 185 deletions(-) diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 8822796..170d05f 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,11 +5,16 @@ # LICENSE file in the root directory of this source tree. 
# Lets define a few top level things here from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) # Needed to load Float8Tensor with weights_only = True from torch.serialization import add_safe_globals -add_safe_globals([Float8Tensor, ScaledMMConfig]) +add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) __all__ = ["Float8Tensor", "Float8Linear"] diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 215a394..9fe9d17 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -8,7 +8,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, - ScaledMMConfig, + GemmInputRole, + LinearMMConfig, tensor_already_casted_to_fp8, to_fp8_no_autograd, ) @@ -26,9 +27,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function): def forward( ctx, tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, ): - ctx.mm_config = mm_config + ctx.linear_mm_config = linear_mm_config return tensor @staticmethod @@ -37,21 +38,34 @@ def backward(ctx, gradY): return gradY, None gradY_scale = tensor_to_scale(gradY, e5m2_dtype) fp8_tensor = to_fp8_no_autograd( - gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config + gradY, + gradY_scale, + e5m2_dtype, + linear_mm_config=ctx.linear_mm_config, + gemm_input_role=GemmInputRole.DL_DY, ) return fp8_tensor, None def cast_to_float8_e4m3_dynamic( - inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False + inpt_tensor: torch.Tensor, + linear_mm_config: LinearMMConfig, + reduce_amax: bool = False, + gemm_input_role: GemmInputRole = GemmInputRole.X, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) def cast_to_float8_e5m2_dynamic_bw( - gradY: torch.Tensor, mm_config: ScaledMMConfig + gradY: torch.Tensor, linear_mm_config: LinearMMConfig ) -> torch.Tensor: - return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config) + return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 7850738..37cf0d5 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -23,6 +23,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, + GemmInputRole, + LinearMMConfig, ScaledMMConfig, to_fp8_no_autograd, ) @@ -85,12 +87,12 @@ def forward( fp8_scale_dL_dY, scale_fn_name, is_amax_initialized, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, ): ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY) ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized - ctx.mm_config = mm_config + ctx.linear_mm_config = linear_mm_config return tensor @staticmethod @@ -113,7 +115,11 @@ def backward(ctx, go): fp8_amax_dL_dY.fill_(tensor_to_amax(go)) res = to_fp8_no_autograd( - go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config + go, + fp8_scale_dL_dY, + e5m2_dtype, + linear_mm_config=ctx.linear_mm_config, + 
gemm_input_role=GemmInputRole.DL_DY, ) empty_grads = None, None, None, None, None, None return res, *empty_grads @@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs): self.create_buffers() - # Defines the behavior of the matmul in the forward and backward pass - self.forward_config = ScaledMMConfig( - emulate, True if not emulate else False, False, config.pad_inner_dim - ) - self.backward_config = ScaledMMConfig( - emulate, False, False, config.pad_inner_dim + # TODO(future): user level configuration of gemms + self.linear_mm_config = LinearMMConfig( + # x + ScaledMMConfig( + emulate, True if not emulate else False, False, config.pad_inner_dim + ), + # w + ScaledMMConfig( + emulate, True if not emulate else False, False, config.pad_inner_dim + ), + # dL_dY + ScaledMMConfig(emulate, False, False, config.pad_inner_dim), ) # Note: is_amax_initialized is not a buffer to avoid data dependent @@ -308,11 +320,12 @@ def cast_x_to_float8( self.fp8_scale_x, e4m3_dtype, self.fp8_amax_x, - self.forward_config, + linear_mm_config=self.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) else: assert self.scaling_type_x is TensorScalingType.DYNAMIC - x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config) + x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config) return x_fp8 def cast_w_to_float8( @@ -339,14 +352,17 @@ def cast_w_to_float8( self.fp8_scale_w, e4m3_dtype, self.fp8_amax_w, - self.forward_config, + linear_mm_config=self.linear_mm_config, + gemm_input_role=GemmInputRole.W, ) else: assert self.scaling_type_w is TensorScalingType.DYNAMIC if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: - w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config) + w_fp8 = cast_to_float8_e4m3_dynamic( + self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W + ) return w_fp8 def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: @@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: self.fp8_scale_dL_dY, scale_fn_name, self.is_amax_initialized, - self.backward_config, + self.linear_mm_config, ) else: assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC - y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config) + y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config) return y def float8_pre_forward(self, x): @@ -457,7 +473,7 @@ def from_float( new_mod.weight = torch.nn.Parameter( WeightWithDynamicFloat8CastTensor( new_mod.weight, - new_mod.forward_config, + new_mod.linear_mm_config, ) ) else: @@ -468,7 +484,7 @@ def from_float( new_mod.fp8_amax_w, new_mod.fp8_amax_history_w, new_mod.fp8_scale_w, - new_mod.forward_config, + new_mod.linear_mm_config, new_mod.is_amax_initialized, ) ) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 3a50cc8..2a11726 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -8,11 +8,7 @@ import torch from float8_experimental.float8_python_api import addmm_float8_unwrapped -from float8_experimental.float8_tensor import ( - Float8Tensor, - merge_mm_configs, - ScaledMMConfig, -) +from float8_experimental.float8_tensor import choose_scaled_mm_config, Float8Tensor from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul from torch.utils._pytree import tree_map @@ -50,7 +46,11 @@ def decorator(func): def float8_desugar_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( - new_data, args[0]._scale, 
args[0]._orig_dtype, args[0]._mm_config + new_data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, ) @@ -60,7 +60,11 @@ def float8_split(aten_op, args, kwargs=None): def make_float8(data): return Float8Tensor( - data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, ) out = map(make_float8, new_data_tensors) @@ -74,8 +78,9 @@ def float8_cat(aten_op, args, kwargs=None): orig_dtype = chunked_tensors[0]._orig_dtype scale = chunked_tensors[0]._scale - mm_config = chunked_tensors[0]._mm_config + mm_config = chunked_tensors[0]._linear_mm_config fp8_dtype = chunked_tensors[0]._data.dtype + gemm_input_role = chunked_tensors[0]._gemm_input_role chunk_data = [] for chunk in chunked_tensors: assert isinstance( @@ -88,16 +93,19 @@ def float8_cat(aten_op, args, kwargs=None): chunk._scale is scale ), "Expecting all chunks to have thee same scale as a result of a split" assert ( - chunk._mm_config is mm_config + chunk._linear_mm_config is mm_config ), "Expecting all chunks to have thee same mm config as a result of a split" assert ( chunk._data.dtype == fp8_dtype ), "Expecting all chunks to be of the same dtype as a result of a split" + assert ( + chunk._gemm_input_role is gemm_input_role + ), "Expecting all chunks to have the same gemm_input_role as a result of a split" chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) new_data = new_data.view(fp8_dtype) - return Float8Tensor(new_data, scale, orig_dtype, mm_config) + return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) @implements([aten.sum.dim_IntList]) @@ -125,10 +133,14 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): a_scale = a._scale b_data = b._data - if a._mm_config.pad_inner_dim: - assert ( - b._mm_config.pad_inner_dim - ), "Both mm configs must have pad_inner_dim set to True" + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + + if scaled_mm_config.pad_inner_dim: assert a._data.size(1) == b._data.size( 0 ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}" @@ -155,10 +167,13 @@ def float8_mm(aten_op, args, kwargs=None): ) a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype - a_mm_config: ScaledMMConfig = a._mm_config - b_mm_config: ScaledMMConfig = b._mm_config - mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config) - if mm_config.emulate: + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + if scaled_mm_config.emulate: return torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype ) @@ -170,7 +185,7 @@ def float8_mm(aten_op, args, kwargs=None): output_dtype, output_scale=None, bias=None, - use_fast_accum=mm_config.use_fast_accum, + use_fast_accum=scaled_mm_config.use_fast_accum, ) return tensor_out @@ -188,10 +203,13 @@ def float8_addmm(aten_op, args, kwargs=None): a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype assert bias.dtype == output_dtype, "bias dtype must match output dtype" - a_mm_config: ScaledMMConfig = a._mm_config - b_mm_config: ScaledMMConfig = b._mm_config - mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config) - if mm_config.emulate: + 
scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + if scaled_mm_config.emulate: out = torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype ) @@ -204,7 +222,7 @@ def float8_addmm(aten_op, args, kwargs=None): output_dtype, output_scale=None, bias=bias, - use_fast_accum=mm_config.use_fast_accum, + use_fast_accum=scaled_mm_config.use_fast_accum, ) return tensor_out @@ -229,7 +247,11 @@ def autocast_to_copy(aten_op, args, kwargs=None): torch.bfloat16, }, "Only support floating point conversion for autocast w/ Float8Tensor" return Float8Tensor( - args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config + args[0]._data, + args[0]._scale, + kwargs["dtype"], + args[0]._linear_mm_config, + args[0]._gemm_input_role, ) @@ -252,7 +274,11 @@ def allgather_fp8(aten_op, args, kwargs=None): fp8_data = fp8_data.contiguous() fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._linear_mm_config, + fp8_input._gemm_input_role, ) @@ -264,7 +290,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): fp8_data = fp8_input._data fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._linear_mm_config, + fp8_input._gemm_input_role, ) @@ -282,7 +312,11 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values_data = fp8_values._data fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) return Float8Tensor( - fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config + fp8_out, + fp8_self._scale, + fp8_self._orig_dtype, + fp8_self._linear_mm_config, + fp8_self._gemm_input_role, ) @@ -309,12 +343,21 @@ def copy_fp8(aten_op, args, kwargs=None): self._scale == src._scale ), "Expecting both Float8Tensors to have thee same scale" assert ( - self._mm_config == src._mm_config + self._linear_mm_config == src._linear_mm_config ), "Expecting both Float8Tensors to have thee same mm config" assert ( self._data.dtype == src._data.dtype ), "Expecting both Float8Tensors to be of the same dtypet" + assert ( + self._gemm_input_role == src._gemm_input_role + ), "Expecting both Float8Tensors to have the same gemm_input_role" fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) - return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config) + return Float8Tensor( + fp8_out, + self._scale, + self._orig_dtype, + self._linear_mm_config, + self._gemm_input_role, + ) else: raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 26d4688..475a17a 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import enum from collections import namedtuple from typing import Dict, Optional @@ -18,6 +19,31 @@ aten = torch.ops.aten +# +# A note on configuration of float8 logic in a linear +# TODO(future): move all the configs to separate file +# +# There are three gemms in a forward + backward of a Linear layer: +# +# 1. x @ w_t = y (forward pass) +# 2. 
dL_dY @ w = dL_dX (backward pass) +# 3. x_t @ dL_dY = dL_dW (backward pass) +# +# In the formulas above, there are: +# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t). +# - Note that dL_dY_t is implied because of memory format requirements +# of float8 gemms +# B. three output tensors (y, dL_dX, dL_dW) +# +# We want each input tensor, gemm, and output tensor to be configurable. +# The state of this configuration today is: +# +# i. pairs of input tensors (non-t and t variants) have their scaling +# configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear +# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing +# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed +# to configure all three gemms, also not user facing + # ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass. # emulate: whether to emulate the matmuls in fp32 @@ -30,27 +56,58 @@ defaults=[False, False, False, False], ) +# The object below is not user facing and exists for convenience, +# to allow Float8Tensor to use +# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is +# being called. +LinearMMConfig = namedtuple( + "LinearMMConfig", + ["y", "dL_dX", "dL_dW"], + defaults=[ + ScaledMMConfig(False, True, False, False), + ScaledMMConfig(False, False, False, False), + ScaledMMConfig(False, False, False, False), + ], +) -def merge_mm_configs( - a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig -) -> ScaledMMConfig: - """Merges two mm_configs together emulate behavior must match, - However we want to use_fast_accum in forward and not in backward. - We do this by populating the fields of the backproping grad. Same applies for fp8_output. - For both use_fast_accum and fp8_output, if either config is False, the merged config will be False. +class GemmInputRole(enum.Enum): """ - assert ( - a_mm_config.emulate == b_mm_config.emulate - ), "Both mm_configs must have the same emulate value, but got {} and {}".format( - a_mm_config.emulate, b_mm_config.emulate - ) - return ScaledMMConfig( - emulate=a_mm_config.emulate, - use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum, - fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output, - pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim, - ) + Given a Float8Tensor, the enum below describes the expected role of this + tensor in the three gemms present in the fw + bw pass of a Linear layer. + This is used to choose the right config for a float8 gemm when the + gemm is performed. 
+ """ + + X = "x" + W = "w" + DL_DY = "dL_dY" + + +# choose which scaled_mm_config to use based on gemm inputs +def choose_scaled_mm_config( + a_role: GemmInputRole, + a_linear_mm_config: LinearMMConfig, + b_role: GemmInputRole, + b_linear_mm_config: LinearMMConfig, +): + if a_role is GemmInputRole.X and b_role is GemmInputRole.W: + assert ( + a_linear_mm_config.y == b_linear_mm_config.y + ), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}" + return a_linear_mm_config.y + elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W: + assert ( + a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX + ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}" + return a_linear_mm_config.dL_dX + elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X: + assert ( + a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW + ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}" + return a_linear_mm_config.dL_dW + else: + raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}") def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: @@ -72,7 +129,8 @@ def to_fp8_no_autograd( x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, - mm_config: Optional[ScaledMMConfig], + linear_mm_config: Optional[LinearMMConfig], + gemm_input_role: Optional[GemmInputRole], ) -> "Float8Tensor": """Convert a tensor to float8 without autograd This is used in multiple places in the codebase to convert a tensor to float8 @@ -90,7 +148,10 @@ def to_fp8_no_autograd( x: the tensor to convert scale: the scale to use to convert the tensor float8_dtype: the float8 dtype to use - mm_config: Defines the configuration for the scaled_mm + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in + the 3 fwd/bwd gemms of linear """ x_scaled = x * x_scale bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) @@ -104,7 +165,11 @@ def to_fp8_no_autograd( local_bits = bits_fp8.to_local() local_scale = x_scale.to_local() inner_float8_tensor = Float8Tensor( - local_bits, local_scale, x.dtype, mm_config=mm_config + local_bits, + local_scale, + x.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, ) return DTensor.from_local( inner_float8_tensor, @@ -115,7 +180,13 @@ def to_fp8_no_autograd( stride=bits_fp8.stride(), ) - return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config) + return Float8Tensor( + bits_fp8, + x_scale, + x.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) @torch._dynamo.allow_in_graph @@ -133,7 +204,8 @@ def forward( scale: torch.Tensor, float8_dtype=e4m3_dtype, amax_buffer: Optional[torch.Tensor] = None, - mm_config: Optional[ScaledMMConfig] = None, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, ): """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer. 
Args @@ -146,11 +218,17 @@ def forward( if amax_buffer is not None: amax_buffer.fill_(tensor_to_amax(tensor)) - return to_fp8_no_autograd(tensor, scale, float8_dtype, mm_config=mm_config) + return to_fp8_no_autograd( + tensor, + scale, + float8_dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) @staticmethod def backward(ctx, g): - return g, None, None, None, None + return g, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -194,15 +272,16 @@ class Float8Tensor(torch.Tensor): _data: torch.Tensor _scale: torch.Tensor _orig_dtype: torch.dtype - _mm_config: ScaledMMConfig - __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"] + _linear_mm_config: LinearMMConfig + __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"] def __new__( cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype, - mm_config: Optional[ScaledMMConfig], + linear_mm_config: Optional[LinearMMConfig], + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, ): assert ( scale.numel() == 1 @@ -223,17 +302,21 @@ def __new__( self._data = data self._scale = scale self._orig_dtype = orig_dtype - self._mm_config = mm_config if mm_config is not None else ScaledMMConfig() + self._linear_mm_config = ( + linear_mm_config if linear_mm_config is not None else LinearMMConfig() + ) + self._gemm_input_role = gemm_input_role return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { "_orig_dtype": self._orig_dtype, - "_mm_config": self._mm_config, + "_linear_mm_config": self._linear_mm_config, + "_gemm_input_role": self._gemm_input_role, } return ["_data", "_scale"], ctx @@ -244,7 +327,8 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"], - metadata["_mm_config"], + metadata["_linear_mm_config"], + metadata["_gemm_input_role"], ) def to_original_precision(self): @@ -257,7 +341,8 @@ def to_float8( scale: torch.Tensor, float8_dtype: torch.dtype, amax_buffer: Optional[torch.Tensor] = None, - mm_config: Optional[ScaledMMConfig] = None, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, ): """Converts a higher precision tensor to float8 in a differentiable way. 
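For reference, a minimal sketch of how the pieces defined above are intended to fit together, assuming the float8_experimental package with these changes installed: one ScaledMMConfig per gemm, bundled into a LinearMMConfig, with the GemmInputRole of the two operands deciding which of the three configs choose_scaled_mm_config returns.

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    choose_scaled_mm_config,
)

# One ScaledMMConfig per gemm: fast accum for the forward gemm (y),
# conservative settings for the two backward gemms (dL_dX, dL_dW).
config = LinearMMConfig(
    ScaledMMConfig(emulate=False, use_fast_accum=True),   # y     = x @ w_t
    ScaledMMConfig(emulate=False, use_fast_accum=False),  # dL_dX = dL_dY @ w
    ScaledMMConfig(emulate=False, use_fast_accum=False),  # dL_dW = x_t @ dL_dY
)

# The roles of the two gemm operands pick the per-gemm config.
assert choose_scaled_mm_config(GemmInputRole.X, config, GemmInputRole.W, config) == config.y
assert choose_scaled_mm_config(GemmInputRole.DL_DY, config, GemmInputRole.W, config) == config.dL_dX
assert choose_scaled_mm_config(GemmInputRole.DL_DY, config, GemmInputRole.X, config) == config.dL_dW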
@@ -266,13 +351,18 @@ def to_float8( scale: the scale to use to convert the tensor float8_dtype: the float8 dtype to use amax_buffer: a buffer to store the amax value in prior to conversion - mm_config: Defines the configuration for the scaled_mm + linearmm_config: Defines the configuration for 3 gemms in fwd/bwd of linear Returns: Float8Tensor: a float8 tensor """ return ToFloat8ConstrFunc.apply( - tensor, scale, float8_dtype, amax_buffer, mm_config + tensor, + scale, + float8_dtype, + amax_buffer, + linear_mm_config, + gemm_input_role, ) @classmethod diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 7c012f6..4c5297c 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -5,6 +5,7 @@ cast_to_float8_e5m2_dynamic_bw, ) from float8_experimental.float8_linear import TensorScalingType +from float8_experimental.float8_tensor import GemmInputRole from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( @@ -45,7 +46,9 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, mod.forward_config + input_tensor, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) # DTensor(Float8Tensor) # transform the input layouts to the desired layouts of ColwiseParallel @@ -64,7 +67,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -96,7 +99,9 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, mod.forward_config + input_tensor, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) # DTensor(Float8Tensor) if input_layouts != desired_input_layouts: @@ -114,7 +119,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs @@ -169,6 +174,7 @@ def __init__( # fp8 specific fields self.float8_dtype = float8_dtype + self.linear_mm_config = None self.fwd_config_submodule_fqn = fwd_config_submodule_fqn if self.float8_dtype != torch.float8_e4m3fn: @@ -191,7 +197,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): ) dt_inp = cast_to_float8_e4m3_dynamic( - dt_inp, self.fwd_linear_config + dt_inp, + self.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) # DTensor(Float8Tensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -203,22 +211,21 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_linear import Float8Linear - fwd_linear_config = None if self.fwd_config_submodule_fqn is not None: fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) assert 
isinstance(fwd_linear, Float8Linear) - fwd_linear_config = fwd_linear.forward_config + self.linear_mm_config = fwd_linear.linear_mm_config else: # search for ScaledMM configs for all the submodules and make sure they are the same for mod in module.modules(): if isinstance(mod, Float8Linear): - if fwd_linear_config is None: - fwd_linear_config = mod.forward_config + if self.linear_mm_config is None: + self.linear_mm_config = mod.linear_mm_config else: assert ( - fwd_linear_config == mod.forward_config - ), "All the Float8Linear modules should have same forward config!" + self.linear_mm_config == mod.linear_mm_config + ), "All the Float8Linear modules should have same linear_mm_config!" - self.fwd_linear_config = fwd_linear_config + assert self.linear_mm_config is not None super()._apply(module, device_mesh) return module diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index c7eb2c0..04cd797 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -14,8 +14,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, - merge_mm_configs, - ScaledMMConfig, + GemmInputRole, + LinearMMConfig, ) from float8_experimental.float8_utils import e4m3_dtype, EPS @@ -89,7 +89,7 @@ class WeightWithDynamicFloat8CastTensor(torch.Tensor): def __new__( cls, tensor: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, precomputed_scale: Optional[torch.Tensor] = None, ): return torch.Tensor._make_wrapper_subclass( @@ -108,11 +108,11 @@ def __new__( def __init__( self, tensor: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, precomputed_scale: Optional[torch.Tensor] = None, ): self._tensor = tensor - self._mm_config = mm_config + self._linear_mm_config = linear_mm_config # for dynamic scaling # `precompute_float8_dynamic_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step @@ -122,16 +122,16 @@ def __init__( def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithDynamicFloat8CastTensor( - args[0]._tensor, args[0]._mm_config + args[0]._tensor, args[0]._linear_mm_config ) - mm_config: Optional[ScaledMMConfig] = None + mm_config: Optional[LinearMMConfig] = None def unwrap(t): nonlocal mm_config if mm_config is None: - mm_config = t._mm_config + mm_config = t._linear_mm_config else: - mm_config = merge_mm_configs(mm_config, t._mm_config) + assert t._linear_mm_config == mm_config return t._tensor args, kwargs = pytree.tree_map_only( @@ -146,9 +146,9 @@ def unwrap(t): def __tensor_flatten__(self): if self._precomputed_scale: - return ["_tensor", "_precomputed_scale"], self._mm_config + return ["_tensor", "_precomputed_scale"], self._linear_mm_config else: - return ["_tensor"], self._mm_config + return ["_tensor"], self._linear_mm_config @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): @@ -160,7 +160,7 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: @@ -168,11 +168,15 @@ def fsdp_pre_all_gather(self, mesh): self._tensor, self._precomputed_scale, torch.float8_e4m3fn, - mm_config=self._mm_config, + 
linear_mm_config=self._linear_mm_config, + gemm_input_role=GemmInputRole.W, ) else: float8_tensor = cast_to_float8_e4m3_dynamic( - self._tensor, self._mm_config, reduce_amax=True + self._tensor, + self._linear_mm_config, + reduce_amax=True, + gemm_input_role=GemmInputRole.W, ) return (float8_tensor._data,), (float8_tensor._scale,) @@ -190,7 +194,13 @@ def fsdp_post_all_gather( assert isinstance(out, Float8Tensor), f"{type(out)}" out._scale = scale return - return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) + return Float8Tensor( + data, + scale, + param_dtype, + self._linear_mm_config, + gemm_input_role=GemmInputRole.W, + ), (data,) class WeightWithDelayedFloat8CastTensor(torch.Tensor): @@ -201,7 +211,7 @@ def __new__( amax_buffer: torch.Tensor, amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, is_amax_initialized: bool, ): return torch.Tensor._make_wrapper_subclass( @@ -223,14 +233,14 @@ def __init__( amax_buffer: torch.Tensor, amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, is_amax_initialized: bool, ): self._tensor = tensor self._amax_buffer = amax_buffer self._amax_history_buffer = amax_history_buffer self._scale_buffer = scale_buffer - self._mm_config = mm_config + self._linear_mm_config = linear_mm_config # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo @@ -245,10 +255,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): args[0]._amax_buffer, args[0]._amax_history_buffer, args[0]._scale_buffer, - args[0]._mm_config, + args[0]._linear_mm_config, args[0].is_amax_initialized, ) - mm_config: Optional[ScaledMMConfig] = None + mm_config: Optional[LinearMMConfig] = None amax_buffer: Optional[torch.Tensor] = None amax_history_buffer: Optional[torch.Tensor] = None scale_buffer: Optional[torch.Tensor] = None @@ -257,9 +267,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): def unwrap(t): nonlocal mm_config if mm_config is None: - mm_config = t._mm_config + mm_config = t._linear_mm_config else: - mm_config = merge_mm_configs(mm_config, t._mm_config) + assert t._linear_mm_config == mm_config nonlocal amax_buffer if amax_buffer is None: amax_buffer = t._amax_buffer @@ -302,7 +312,7 @@ def __tensor_flatten__(self): "_scale_buffer", ], { - "mm_config": self._mm_config, + "mm_config": self._linear_mm_config, "is_amax_initialized": self.is_amax_initialized, }, ) @@ -319,7 +329,7 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): ) def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})" + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" def fsdp_pre_all_gather(self, mesh): # initialize if needed @@ -351,7 +361,8 @@ def fsdp_pre_all_gather(self, mesh): self._scale_buffer, e4m3_dtype, self._amax_buffer, - self._mm_config, + self._linear_mm_config, + gemm_input_role=GemmInputRole.W, ) return (float8_tensor._data,), (float8_tensor._scale,) @@ -369,4 +380,10 @@ def fsdp_post_all_gather( assert isinstance(out, Float8Tensor), f"{type(out)}" out._scale = scale return - return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) + return Float8Tensor( + data, + scale, 
+ param_dtype, + self._linear_mm_config, + gemm_input_role=GemmInputRole.W, + ), (data,) diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 1c931ee..2bb593b 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -20,6 +20,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, + GemmInputRole, + LinearMMConfig, ScaledMMConfig, tensor_already_casted_to_fp8, to_fp8_no_autograd, @@ -73,7 +75,7 @@ def __init__( self, # FP8 specific arguments quant_config: QuantConfig, - forward_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, # nn.Linear arguments in_features: int, out_features: int, @@ -83,7 +85,7 @@ def __init__( ) -> None: # Construct the superclass this will create dummy weights and biases super().__init__(in_features, out_features, bias, device, dtype) - self.forward_config = forward_config + self.linear_mm_config = linear_mm_config self.activation_casting = quant_config.activation_casting if self.activation_casting == ActivationCasting.STATIC: self.register_buffer( @@ -100,7 +102,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: x_fp8 = cast_to_float8_e4m3_inference( input, - self.forward_config, + self.linear_mm_config, static_quantization_scale=self.static_quantization_scale, ) return torch.nn.functional.linear(x_fp8, self.weight, self.bias) @@ -125,7 +127,8 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: self.weight, scale, dtype, - self.forward_config, + self.linear_mm_config, + gemm_input_role=GemmInputRole.W, ) self.weight = nn.Parameter(quantized_weight) self.weight.requires_grad = False @@ -150,9 +153,12 @@ def from_float( forward_config = ScaledMMConfig( False, use_fast_accum, pad_inner_dim=config.pad_inner_dim ) + linear_mm_config = LinearMMConfig( + forward_config, forward_config, forward_config + ) linear = cls( quant_config, - forward_config, + linear_mm_config, module.in_features, module.out_features, False, @@ -165,7 +171,7 @@ def from_float( def cast_to_float8_e4m3_inference( inpt_tensor: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, reduce_amax: bool = False, static_quantization_scale: Optional[torch.Tensor] = None, ) -> Float8Tensor: @@ -173,7 +179,7 @@ def cast_to_float8_e4m3_inference( Args: inpt_tensor: The input tensor to be cast. - mm_config: Configuration settings for the matrix multiplication + linear_mm_config: Configuration settings for the matrix multiplication reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. 
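For reference, a minimal sketch of the inference-side plumbing shown above, assuming the package with these changes installed: from_float replicates one forward ScaledMMConfig across the three LinearMMConfig slots (only the y gemm runs at inference time), and activations are cast with GemmInputRole.X while the quantized weight carries GemmInputRole.W.

import torch
from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig
from float8_experimental.inference import cast_to_float8_e4m3_inference

# Mirrors Float8InferenceLinear.from_float: one forward config, repeated
# across the (y, dL_dX, dL_dW) slots.
forward_config = ScaledMMConfig(False, True, False, False)  # emulate, use_fast_accum, fp8_output, pad_inner_dim
linear_mm_config = LinearMMConfig(forward_config, forward_config, forward_config)

x = torch.randn(16, 32, dtype=torch.bfloat16)
# Dynamic activation quantization; pass static_quantization_scale for static.
x_fp8 = cast_to_float8_e4m3_inference(x, linear_mm_config)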
@@ -190,7 +196,13 @@ def cast_to_float8_e4m3_inference( if static_quantization_scale is not None else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) ) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=GemmInputRole.X, + ) def quantize_to_float8( diff --git a/test/test_base.py b/test/test_base.py index 2c7c3f4..381fb4e 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -26,7 +26,8 @@ from float8_experimental.float8_python_api import addmm_float8_unwrapped from float8_experimental.float8_tensor import ( Float8Tensor, - merge_mm_configs, + GemmInputRole, + LinearMMConfig, ScaledMMConfig, ) from float8_experimental.float8_utils import ( @@ -121,7 +122,7 @@ def test_copy_(self): torch.empty(16, dtype=torch.float8_e4m3fn), scale_a, torch.bfloat16, - fp8_a._mm_config, + fp8_a._linear_mm_config, ) fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) @@ -438,38 +439,36 @@ def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") fp8_dtype = e4m3_dtype - a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) + linear_config_a = LinearMMConfig( + ScaledMMConfig(False, True, False, False), + ScaledMMConfig(False, False, False, False), + ScaledMMConfig(False, False, False, False), + ) + linear_config_b = LinearMMConfig( + ScaledMMConfig(True, True, False, False), + ScaledMMConfig(True, False, False, False), + ScaledMMConfig(True, False, False, False), + ) + a = Float8Tensor.to_float8( + x_fp32, + x_scale, + fp8_dtype, + linear_mm_config=linear_config_a, + gemm_input_role=GemmInputRole.X, + ) b = Float8Tensor.to_float8( - x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True) + x_fp32, + x_scale, + fp8_dtype, + linear_mm_config=linear_config_b, + gemm_input_role=GemmInputRole.W, ) with pytest.raises( AssertionError, - match="Both mm_configs must have the same emulate value, but got False and True", + match="linear_mm_config.y mismatch", ): a @ b - def test_merge_configs(self): - a = ScaledMMConfig(False, True, True) - b = ScaledMMConfig(True, False, False) - with pytest.raises( - AssertionError, - match="Both mm_configs must have the same emulate value, but got False and True", - ): - merge_mm_configs(a, b) - a = ScaledMMConfig(False, True, True) - b = ScaledMMConfig(False, False, False) - c = merge_mm_configs(a, b) - assert c.emulate is False - assert c.use_fast_accum is False - assert c.fp8_output is False - - a = ScaledMMConfig(False, True, False) - b = ScaledMMConfig(False, True, False) - c = merge_mm_configs(a, b) - assert c.emulate is False - assert c.use_fast_accum is True - assert c.fp8_output is False - @unittest.skipIf( not is_H100, "CUDA not available", @@ -489,8 +488,12 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() - a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype) - b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype) + a_fp8 = Float8Tensor.to_float8( + a, a_scale, input_dtype, gemm_input_role=GemmInputRole.X + ) + b_fp8 = Float8Tensor.to_float8( + b, b_scale, input_dtype, gemm_input_role=GemmInputRole.W + ) with pytest.raises( RuntimeError, @@ -500,19 +503,47 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): ): a_fp8 @ b_fp8 - pad_config = ScaledMMConfig(False, use_fast_accum, False, True) + 
scaled_mm_config = ScaledMMConfig(False, use_fast_accum, False, True) + pad_config = LinearMMConfig( + scaled_mm_config, scaled_mm_config, scaled_mm_config + ) - a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config) - b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config) + a_fp8 = Float8Tensor.to_float8( + a, + a_scale, + input_dtype, + linear_mm_config=pad_config, + gemm_input_role=GemmInputRole.X, + ) + b_fp8 = Float8Tensor.to_float8( + b, + b_scale, + input_dtype, + linear_mm_config=pad_config, + gemm_input_role=GemmInputRole.W, + ) out_padded = a_fp8 @ b_fp8 out_padded.to(compare_type) - emulated_conifg = ScaledMMConfig(True, use_fast_accum, False, False) + emulated_scaled_mm_config = ScaledMMConfig(True, use_fast_accum, False, False) + emulated_config = LinearMMConfig( + emulated_scaled_mm_config, + emulated_scaled_mm_config, + emulated_scaled_mm_config, + ) a_fp8 = Float8Tensor.to_float8( - a, a_scale, input_dtype, mm_config=emulated_conifg + a, + a_scale, + input_dtype, + linear_mm_config=emulated_config, + gemm_input_role=GemmInputRole.X, ) b_fp8 = Float8Tensor.to_float8( - b, b_scale, input_dtype, mm_config=emulated_conifg + b, + b_scale, + input_dtype, + linear_mm_config=emulated_config, + gemm_input_role=GemmInputRole.W, ) out_emualted = a_fp8 @ b_fp8 out_emualted.to(compare_type) @@ -564,8 +595,8 @@ def test_swap_root_linear(self): module = nn.Linear(3, 3) module = swap_linear_with_float8_linear(module, emulate=emulate) self.assertIsInstance(module, Float8Linear) - self.assertEqual(module.forward_config.emulate, emulate) - self.assertEqual(module.backward_config.emulate, emulate) + self.assertEqual(module.linear_mm_config.y.emulate, emulate) + self.assertEqual(module.linear_mm_config.y.emulate, emulate) def test_swap_root_linear_with_children_raises(self): for emulate in [True, False]: diff --git a/test/test_compile.py b/test/test_compile.py index 5a6e003..f72425d 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -19,7 +19,7 @@ swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import Float8Tensor, LinearMMConfig from float8_experimental.float8_utils import e4m3_dtype from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -179,7 +179,7 @@ def forward(self, x): self.fp8_scale_x, e4m3_dtype, self.fp8_amax_x, - ScaledMMConfig(), + LinearMMConfig(), ) if self.graph_break: torch._dynamo.graph_break() @@ -242,9 +242,9 @@ def test_float8_graph_output(self): type(y_compiled._orig_dtype) ) assert isinstance( - y_compiled._mm_config.emulate, bool + y_compiled._linear_mm_config.y.emulate, bool ), "Float8Tensor._emulate should be a bool but got {}".format( - type(y_compiled._mm_config.emulate) + type(y_compiled._linear_mm_config.y.emulate) ) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 8aada4b..1cd14db 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -17,7 +17,11 @@ from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw from float8_experimental.float8_linear import TensorScalingType from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, +) from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, 
Float8RowwiseParallel, @@ -82,8 +86,12 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() - x_fp8 = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) - y_fp8 = Float8Tensor.to_float8(y_fp32, y_scale, fp8_dtype) + x_fp8 = Float8Tensor.to_float8( + x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X + ) + y_fp8 = Float8Tensor.to_float8( + y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.W + ) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False) @@ -155,13 +163,15 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() dist_target = distribute_tensor(target, mesh, [Shard(0)]) - dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) + dist_x_fp8 = Float8Tensor.to_float8( + dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X + ) dist_weight_fp8 = Float8Tensor.to_float8( - dist_wight_fp32, dist_weight_scale, fp8_dtype + dist_wight_fp32, dist_weight_scale, fp8_dtype, gemm_input_role=GemmInputRole.W ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2Bw.apply(out, ScaledMMConfig()) + out = NoopFwToFloat8E5M2Bw.apply(out, LinearMMConfig()) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward()
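For reference, a minimal end-to-end sketch of the new (not user facing) keyword arguments, in the spirit of the updated tests above. It assumes the package with these changes installed and a GPU with float8 support (otherwise set emulate=True in all three ScaledMMConfig entries).

import torch
from float8_experimental.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

config = LinearMMConfig(
    ScaledMMConfig(False, True, False, False),   # y
    ScaledMMConfig(False, False, False, False),  # dL_dX
    ScaledMMConfig(False, False, False, False),  # dL_dW
)

x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
w = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
x_fp8 = Float8Tensor.to_float8(
    x, tensor_to_scale(x, e4m3_dtype).float(), e4m3_dtype,
    linear_mm_config=config, gemm_input_role=GemmInputRole.X,
)
w_fp8 = Float8Tensor.to_float8(
    w, tensor_to_scale(w, e4m3_dtype).float(), e4m3_dtype,
    linear_mm_config=config, gemm_input_role=GemmInputRole.W,
)
# Routed through config.y because the operand roles are (X, W); mismatched
# LinearMMConfigs for the selected gemm raise an AssertionError instead.
y = x_fp8 @ w_fp8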