diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 8822796..170d05f 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,11 +5,16 @@ # LICENSE file in the root directory of this source tree. # Lets define a few top level things here from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) # Needed to load Float8Tensor with weights_only = True from torch.serialization import add_safe_globals -add_safe_globals([Float8Tensor, ScaledMMConfig]) +add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) __all__ = ["Float8Tensor", "Float8Linear"] diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 215a394..9fe9d17 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -8,7 +8,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, - ScaledMMConfig, + GemmInputRole, + LinearMMConfig, tensor_already_casted_to_fp8, to_fp8_no_autograd, ) @@ -26,9 +27,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function): def forward( ctx, tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, ): - ctx.mm_config = mm_config + ctx.linear_mm_config = linear_mm_config return tensor @staticmethod @@ -37,21 +38,34 @@ def backward(ctx, gradY): return gradY, None gradY_scale = tensor_to_scale(gradY, e5m2_dtype) fp8_tensor = to_fp8_no_autograd( - gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config + gradY, + gradY_scale, + e5m2_dtype, + linear_mm_config=ctx.linear_mm_config, + gemm_input_role=GemmInputRole.DL_DY, ) return fp8_tensor, None def cast_to_float8_e4m3_dynamic( - inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False + inpt_tensor: torch.Tensor, + linear_mm_config: LinearMMConfig, + reduce_amax: bool = False, + gemm_input_role: GemmInputRole = GemmInputRole.X, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) def cast_to_float8_e5m2_dynamic_bw( - gradY: torch.Tensor, mm_config: ScaledMMConfig + gradY: torch.Tensor, linear_mm_config: LinearMMConfig ) -> torch.Tensor: - return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config) + return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 7850738..37cf0d5 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -23,6 +23,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, + GemmInputRole, + LinearMMConfig, ScaledMMConfig, to_fp8_no_autograd, ) @@ -85,12 +87,12 @@ def forward( fp8_scale_dL_dY, scale_fn_name, is_amax_initialized, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, ): ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY) ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized - ctx.mm_config = mm_config + ctx.linear_mm_config = linear_mm_config return tensor @staticmethod @@ 
-113,7 +115,11 @@ def backward(ctx, go): fp8_amax_dL_dY.fill_(tensor_to_amax(go)) res = to_fp8_no_autograd( - go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config + go, + fp8_scale_dL_dY, + e5m2_dtype, + linear_mm_config=ctx.linear_mm_config, + gemm_input_role=GemmInputRole.DL_DY, ) empty_grads = None, None, None, None, None, None return res, *empty_grads @@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs): self.create_buffers() - # Defines the behavior of the matmul in the forward and backward pass - self.forward_config = ScaledMMConfig( - emulate, True if not emulate else False, False, config.pad_inner_dim - ) - self.backward_config = ScaledMMConfig( - emulate, False, False, config.pad_inner_dim + # TODO(future): user level configuration of gemms + self.linear_mm_config = LinearMMConfig( + # x + ScaledMMConfig( + emulate, True if not emulate else False, False, config.pad_inner_dim + ), + # w + ScaledMMConfig( + emulate, True if not emulate else False, False, config.pad_inner_dim + ), + # dL_dY + ScaledMMConfig(emulate, False, False, config.pad_inner_dim), ) # Note: is_amax_initialized is not a buffer to avoid data dependent @@ -308,11 +320,12 @@ def cast_x_to_float8( self.fp8_scale_x, e4m3_dtype, self.fp8_amax_x, - self.forward_config, + linear_mm_config=self.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) else: assert self.scaling_type_x is TensorScalingType.DYNAMIC - x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config) + x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config) return x_fp8 def cast_w_to_float8( @@ -339,14 +352,17 @@ def cast_w_to_float8( self.fp8_scale_w, e4m3_dtype, self.fp8_amax_w, - self.forward_config, + linear_mm_config=self.linear_mm_config, + gemm_input_role=GemmInputRole.W, ) else: assert self.scaling_type_w is TensorScalingType.DYNAMIC if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: - w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config) + w_fp8 = cast_to_float8_e4m3_dynamic( + self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W + ) return w_fp8 def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: @@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: self.fp8_scale_dL_dY, scale_fn_name, self.is_amax_initialized, - self.backward_config, + self.linear_mm_config, ) else: assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC - y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config) + y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config) return y def float8_pre_forward(self, x): @@ -457,7 +473,7 @@ def from_float( new_mod.weight = torch.nn.Parameter( WeightWithDynamicFloat8CastTensor( new_mod.weight, - new_mod.forward_config, + new_mod.linear_mm_config, ) ) else: @@ -468,7 +484,7 @@ def from_float( new_mod.fp8_amax_w, new_mod.fp8_amax_history_w, new_mod.fp8_scale_w, - new_mod.forward_config, + new_mod.linear_mm_config, new_mod.is_amax_initialized, ) ) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 3a50cc8..2a11726 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -8,11 +8,7 @@ import torch from float8_experimental.float8_python_api import addmm_float8_unwrapped -from float8_experimental.float8_tensor import ( - Float8Tensor, - merge_mm_configs, - ScaledMMConfig, -) +from float8_experimental.float8_tensor import choose_scaled_mm_config, Float8Tensor from float8_experimental.float8_utils import is_row_major, 
pad_tensor_for_matmul from torch.utils._pytree import tree_map @@ -50,7 +46,11 @@ def decorator(func): def float8_desugar_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( - new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + new_data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, ) @@ -60,7 +60,11 @@ def float8_split(aten_op, args, kwargs=None): def make_float8(data): return Float8Tensor( - data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, ) out = map(make_float8, new_data_tensors) @@ -74,8 +78,9 @@ def float8_cat(aten_op, args, kwargs=None): orig_dtype = chunked_tensors[0]._orig_dtype scale = chunked_tensors[0]._scale - mm_config = chunked_tensors[0]._mm_config + mm_config = chunked_tensors[0]._linear_mm_config fp8_dtype = chunked_tensors[0]._data.dtype + gemm_input_role = chunked_tensors[0]._gemm_input_role chunk_data = [] for chunk in chunked_tensors: assert isinstance( @@ -88,16 +93,19 @@ def float8_cat(aten_op, args, kwargs=None): chunk._scale is scale ), "Expecting all chunks to have thee same scale as a result of a split" assert ( - chunk._mm_config is mm_config + chunk._linear_mm_config is mm_config ), "Expecting all chunks to have thee same mm config as a result of a split" assert ( chunk._data.dtype == fp8_dtype ), "Expecting all chunks to be of the same dtype as a result of a split" + assert ( + chunk._gemm_input_role is gemm_input_role + ), "Expecting all chunks to have the same gemm_input_role as a result of a split" chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) new_data = new_data.view(fp8_dtype) - return Float8Tensor(new_data, scale, orig_dtype, mm_config) + return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) @implements([aten.sum.dim_IntList]) @@ -125,10 +133,14 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): a_scale = a._scale b_data = b._data - if a._mm_config.pad_inner_dim: - assert ( - b._mm_config.pad_inner_dim - ), "Both mm configs must have pad_inner_dim set to True" + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + + if scaled_mm_config.pad_inner_dim: assert a._data.size(1) == b._data.size( 0 ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}" @@ -155,10 +167,13 @@ def float8_mm(aten_op, args, kwargs=None): ) a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype - a_mm_config: ScaledMMConfig = a._mm_config - b_mm_config: ScaledMMConfig = b._mm_config - mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config) - if mm_config.emulate: + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + if scaled_mm_config.emulate: return torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype ) @@ -170,7 +185,7 @@ def float8_mm(aten_op, args, kwargs=None): output_dtype, output_scale=None, bias=None, - use_fast_accum=mm_config.use_fast_accum, + use_fast_accum=scaled_mm_config.use_fast_accum, ) return tensor_out @@ -188,10 +203,13 @@ def float8_addmm(aten_op, args, kwargs=None): a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype assert 
bias.dtype == output_dtype, "bias dtype must match output dtype" - a_mm_config: ScaledMMConfig = a._mm_config - b_mm_config: ScaledMMConfig = b._mm_config - mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config) - if mm_config.emulate: + scaled_mm_config = choose_scaled_mm_config( + a._gemm_input_role, + a._linear_mm_config, + b._gemm_input_role, + b._linear_mm_config, + ) + if scaled_mm_config.emulate: out = torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype ) @@ -204,7 +222,7 @@ def float8_addmm(aten_op, args, kwargs=None): output_dtype, output_scale=None, bias=bias, - use_fast_accum=mm_config.use_fast_accum, + use_fast_accum=scaled_mm_config.use_fast_accum, ) return tensor_out @@ -229,7 +247,11 @@ def autocast_to_copy(aten_op, args, kwargs=None): torch.bfloat16, }, "Only support floating point conversion for autocast w/ Float8Tensor" return Float8Tensor( - args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config + args[0]._data, + args[0]._scale, + kwargs["dtype"], + args[0]._linear_mm_config, + args[0]._gemm_input_role, ) @@ -252,7 +274,11 @@ def allgather_fp8(aten_op, args, kwargs=None): fp8_data = fp8_data.contiguous() fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._linear_mm_config, + fp8_input._gemm_input_role, ) @@ -264,7 +290,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): fp8_data = fp8_input._data fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._linear_mm_config, + fp8_input._gemm_input_role, ) @@ -282,7 +312,11 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values_data = fp8_values._data fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) return Float8Tensor( - fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config + fp8_out, + fp8_self._scale, + fp8_self._orig_dtype, + fp8_self._linear_mm_config, + fp8_self._gemm_input_role, ) @@ -309,12 +343,21 @@ def copy_fp8(aten_op, args, kwargs=None): self._scale == src._scale ), "Expecting both Float8Tensors to have thee same scale" assert ( - self._mm_config == src._mm_config + self._linear_mm_config == src._linear_mm_config ), "Expecting both Float8Tensors to have thee same mm config" assert ( self._data.dtype == src._data.dtype ), "Expecting both Float8Tensors to be of the same dtypet" + assert ( + self._gemm_input_role == src._gemm_input_role + ), "Expecting both Float8Tensors to have the same gemm_input_role" fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) - return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config) + return Float8Tensor( + fp8_out, + self._scale, + self._orig_dtype, + self._linear_mm_config, + self._gemm_input_role, + ) else: raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 26d4688..475a17a 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+import enum from collections import namedtuple from typing import Dict, Optional @@ -18,6 +19,31 @@ aten = torch.ops.aten +# +# A note on configuration of float8 logic in a linear +# TODO(future): move all the configs to separate file +# +# There are three gemms in a forward + backward of a Linear layer: +# +# 1. x @ w_t = y (forward pass) +# 2. dL_dY @ w = dL_dX (backward pass) +# 3. x_t @ dL_dY = dL_dW (backward pass) +# +# In the formulas above, there are: +# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t). +# - Note that dL_dY_t is implied because of memory format requirements +# of float8 gemms +# B. three output tensors (y, dL_dX, dL_dW) +# +# We want each input tensor, gemm, and output tensor to be configurable. +# The state of this configuration today is: +# +# i. pairs of input tensors (non-t and t variants) have their scaling +# configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear +# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing +# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed +# to configure all three gemms, also not user facing + # ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass. # emulate: whether to emulate the matmuls in fp32 @@ -30,27 +56,58 @@ defaults=[False, False, False, False], ) +# The object below is not user facing and exists for convenience, +# to allow Float8Tensor to use +# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is +# being called. +LinearMMConfig = namedtuple( + "LinearMMConfig", + ["y", "dL_dX", "dL_dW"], + defaults=[ + ScaledMMConfig(False, True, False, False), + ScaledMMConfig(False, False, False, False), + ScaledMMConfig(False, False, False, False), + ], +) -def merge_mm_configs( - a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig -) -> ScaledMMConfig: - """Merges two mm_configs together emulate behavior must match, - However we want to use_fast_accum in forward and not in backward. - We do this by populating the fields of the backproping grad. Same applies for fp8_output. - For both use_fast_accum and fp8_output, if either config is False, the merged config will be False. +class GemmInputRole(enum.Enum): """ - assert ( - a_mm_config.emulate == b_mm_config.emulate - ), "Both mm_configs must have the same emulate value, but got {} and {}".format( - a_mm_config.emulate, b_mm_config.emulate - ) - return ScaledMMConfig( - emulate=a_mm_config.emulate, - use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum, - fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output, - pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim, - ) + Given a Float8Tensor, the enum below describes the expected role of this + tensor in the three gemms present in the fw + bw pass of a Linear layer. + This is used to choose the right config for a float8 gemm when the + gemm is performed. 
+ """ + + X = "x" + W = "w" + DL_DY = "dL_dY" + + +# choose which scaled_mm_config to use based on gemm inputs +def choose_scaled_mm_config( + a_role: GemmInputRole, + a_linear_mm_config: LinearMMConfig, + b_role: GemmInputRole, + b_linear_mm_config: LinearMMConfig, +): + if a_role is GemmInputRole.X and b_role is GemmInputRole.W: + assert ( + a_linear_mm_config.y == b_linear_mm_config.y + ), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}" + return a_linear_mm_config.y + elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W: + assert ( + a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX + ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}" + return a_linear_mm_config.dL_dX + elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X: + assert ( + a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW + ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}" + return a_linear_mm_config.dL_dW + else: + raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}") def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: @@ -72,7 +129,8 @@ def to_fp8_no_autograd( x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, - mm_config: Optional[ScaledMMConfig], + linear_mm_config: Optional[LinearMMConfig], + gemm_input_role: Optional[GemmInputRole], ) -> "Float8Tensor": """Convert a tensor to float8 without autograd This is used in multiple places in the codebase to convert a tensor to float8 @@ -90,7 +148,10 @@ def to_fp8_no_autograd( x: the tensor to convert scale: the scale to use to convert the tensor float8_dtype: the float8 dtype to use - mm_config: Defines the configuration for the scaled_mm + linear_mm_config: Defines the configuration for the scaled_mm for + the 3 fwd/bwd gemms of linear + gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in + the 3 fwd/bwd gemms of linear """ x_scaled = x * x_scale bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) @@ -104,7 +165,11 @@ def to_fp8_no_autograd( local_bits = bits_fp8.to_local() local_scale = x_scale.to_local() inner_float8_tensor = Float8Tensor( - local_bits, local_scale, x.dtype, mm_config=mm_config + local_bits, + local_scale, + x.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, ) return DTensor.from_local( inner_float8_tensor, @@ -115,7 +180,13 @@ def to_fp8_no_autograd( stride=bits_fp8.stride(), ) - return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config) + return Float8Tensor( + bits_fp8, + x_scale, + x.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) @torch._dynamo.allow_in_graph @@ -133,7 +204,8 @@ def forward( scale: torch.Tensor, float8_dtype=e4m3_dtype, amax_buffer: Optional[torch.Tensor] = None, - mm_config: Optional[ScaledMMConfig] = None, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, ): """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer. 
Args @@ -146,11 +218,17 @@ def forward( if amax_buffer is not None: amax_buffer.fill_(tensor_to_amax(tensor)) - return to_fp8_no_autograd(tensor, scale, float8_dtype, mm_config=mm_config) + return to_fp8_no_autograd( + tensor, + scale, + float8_dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) @staticmethod def backward(ctx, g): - return g, None, None, None, None + return g, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -194,15 +272,16 @@ class Float8Tensor(torch.Tensor): _data: torch.Tensor _scale: torch.Tensor _orig_dtype: torch.dtype - _mm_config: ScaledMMConfig - __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"] + _linear_mm_config: LinearMMConfig + __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"] def __new__( cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype, - mm_config: Optional[ScaledMMConfig], + linear_mm_config: Optional[LinearMMConfig], + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, ): assert ( scale.numel() == 1 @@ -223,17 +302,21 @@ def __new__( self._data = data self._scale = scale self._orig_dtype = orig_dtype - self._mm_config = mm_config if mm_config is not None else ScaledMMConfig() + self._linear_mm_config = ( + linear_mm_config if linear_mm_config is not None else LinearMMConfig() + ) + self._gemm_input_role = gemm_input_role return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { "_orig_dtype": self._orig_dtype, - "_mm_config": self._mm_config, + "_linear_mm_config": self._linear_mm_config, + "_gemm_input_role": self._gemm_input_role, } return ["_data", "_scale"], ctx @@ -244,7 +327,8 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"], - metadata["_mm_config"], + metadata["_linear_mm_config"], + metadata["_gemm_input_role"], ) def to_original_precision(self): @@ -257,7 +341,8 @@ def to_float8( scale: torch.Tensor, float8_dtype: torch.dtype, amax_buffer: Optional[torch.Tensor] = None, - mm_config: Optional[ScaledMMConfig] = None, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X, ): """Converts a higher precision tensor to float8 in a differentiable way. 
@@ -266,13 +351,18 @@ def to_float8( scale: the scale to use to convert the tensor float8_dtype: the float8 dtype to use amax_buffer: a buffer to store the amax value in prior to conversion - mm_config: Defines the configuration for the scaled_mm + linear_mm_config: Defines the configuration for 3 gemms in fwd/bwd of linear Returns: Float8Tensor: a float8 tensor """ return ToFloat8ConstrFunc.apply( - tensor, scale, float8_dtype, amax_buffer, mm_config + tensor, + scale, + float8_dtype, + amax_buffer, + linear_mm_config, + gemm_input_role, ) @classmethod diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 7c012f6..4c5297c 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -5,6 +5,7 @@ cast_to_float8_e5m2_dynamic_bw, ) from float8_experimental.float8_linear import TensorScalingType +from float8_experimental.float8_tensor import GemmInputRole from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( @@ -45,7 +46,9 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, mod.forward_config + input_tensor, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) # DTensor(Float8Tensor) # transform the input layouts to the desired layouts of ColwiseParallel @@ -64,7 +67,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -96,7 +99,9 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, mod.forward_config + input_tensor, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) # DTensor(Float8Tensor) if input_layouts != desired_input_layouts: @@ -114,7 +119,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs @@ -169,6 +174,7 @@ def __init__( # fp8 specific fields self.float8_dtype = float8_dtype + self.linear_mm_config = None self.fwd_config_submodule_fqn = fwd_config_submodule_fqn if self.float8_dtype != torch.float8_e4m3fn: @@ -191,7 +197,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): ) dt_inp = cast_to_float8_e4m3_dynamic( - dt_inp, self.fwd_linear_config + dt_inp, + self.linear_mm_config, + gemm_input_role=GemmInputRole.X, ) # DTensor(Float8Tensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -203,22 +211,21 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_linear import Float8Linear - fwd_linear_config = None if self.fwd_config_submodule_fqn is not None: fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) assert
isinstance(fwd_linear, Float8Linear) - fwd_linear_config = fwd_linear.forward_config + self.linear_mm_config = fwd_linear.linear_mm_config else: # search for ScaledMM configs for all the submodules and make sure they are the same for mod in module.modules(): if isinstance(mod, Float8Linear): - if fwd_linear_config is None: - fwd_linear_config = mod.forward_config + if self.linear_mm_config is None: + self.linear_mm_config = mod.linear_mm_config else: assert ( - fwd_linear_config == mod.forward_config - ), "All the Float8Linear modules should have same forward config!" + self.linear_mm_config == mod.linear_mm_config + ), "All the Float8Linear modules should have same linear_mm_config!" - self.fwd_linear_config = fwd_linear_config + assert self.linear_mm_config is not None super()._apply(module, device_mesh) return module diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index c7eb2c0..04cd797 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -14,8 +14,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, - merge_mm_configs, - ScaledMMConfig, + GemmInputRole, + LinearMMConfig, ) from float8_experimental.float8_utils import e4m3_dtype, EPS @@ -89,7 +89,7 @@ class WeightWithDynamicFloat8CastTensor(torch.Tensor): def __new__( cls, tensor: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, precomputed_scale: Optional[torch.Tensor] = None, ): return torch.Tensor._make_wrapper_subclass( @@ -108,11 +108,11 @@ def __new__( def __init__( self, tensor: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, precomputed_scale: Optional[torch.Tensor] = None, ): self._tensor = tensor - self._mm_config = mm_config + self._linear_mm_config = linear_mm_config # for dynamic scaling # `precompute_float8_dynamic_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step @@ -122,16 +122,16 @@ def __init__( def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithDynamicFloat8CastTensor( - args[0]._tensor, args[0]._mm_config + args[0]._tensor, args[0]._linear_mm_config ) - mm_config: Optional[ScaledMMConfig] = None + mm_config: Optional[LinearMMConfig] = None def unwrap(t): nonlocal mm_config if mm_config is None: - mm_config = t._mm_config + mm_config = t._linear_mm_config else: - mm_config = merge_mm_configs(mm_config, t._mm_config) + assert t._linear_mm_config == mm_config return t._tensor args, kwargs = pytree.tree_map_only( @@ -146,9 +146,9 @@ def unwrap(t): def __tensor_flatten__(self): if self._precomputed_scale: - return ["_tensor", "_precomputed_scale"], self._mm_config + return ["_tensor", "_precomputed_scale"], self._linear_mm_config else: - return ["_tensor"], self._mm_config + return ["_tensor"], self._linear_mm_config @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): @@ -160,7 +160,7 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: @@ -168,11 +168,15 @@ def fsdp_pre_all_gather(self, mesh): self._tensor, self._precomputed_scale, torch.float8_e4m3fn, - mm_config=self._mm_config, + 
linear_mm_config=self._linear_mm_config, + gemm_input_role=GemmInputRole.W, ) else: float8_tensor = cast_to_float8_e4m3_dynamic( - self._tensor, self._mm_config, reduce_amax=True + self._tensor, + self._linear_mm_config, + reduce_amax=True, + gemm_input_role=GemmInputRole.W, ) return (float8_tensor._data,), (float8_tensor._scale,) @@ -190,7 +194,13 @@ def fsdp_post_all_gather( assert isinstance(out, Float8Tensor), f"{type(out)}" out._scale = scale return - return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) + return Float8Tensor( + data, + scale, + param_dtype, + self._linear_mm_config, + gemm_input_role=GemmInputRole.W, + ), (data,) class WeightWithDelayedFloat8CastTensor(torch.Tensor): @@ -201,7 +211,7 @@ def __new__( amax_buffer: torch.Tensor, amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, is_amax_initialized: bool, ): return torch.Tensor._make_wrapper_subclass( @@ -223,14 +233,14 @@ def __init__( amax_buffer: torch.Tensor, amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, is_amax_initialized: bool, ): self._tensor = tensor self._amax_buffer = amax_buffer self._amax_history_buffer = amax_history_buffer self._scale_buffer = scale_buffer - self._mm_config = mm_config + self._linear_mm_config = linear_mm_config # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo @@ -245,10 +255,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): args[0]._amax_buffer, args[0]._amax_history_buffer, args[0]._scale_buffer, - args[0]._mm_config, + args[0]._linear_mm_config, args[0].is_amax_initialized, ) - mm_config: Optional[ScaledMMConfig] = None + mm_config: Optional[LinearMMConfig] = None amax_buffer: Optional[torch.Tensor] = None amax_history_buffer: Optional[torch.Tensor] = None scale_buffer: Optional[torch.Tensor] = None @@ -257,9 +267,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): def unwrap(t): nonlocal mm_config if mm_config is None: - mm_config = t._mm_config + mm_config = t._linear_mm_config else: - mm_config = merge_mm_configs(mm_config, t._mm_config) + assert t._linear_mm_config == mm_config nonlocal amax_buffer if amax_buffer is None: amax_buffer = t._amax_buffer @@ -302,7 +312,7 @@ def __tensor_flatten__(self): "_scale_buffer", ], { - "mm_config": self._mm_config, + "mm_config": self._linear_mm_config, "is_amax_initialized": self.is_amax_initialized, }, ) @@ -319,7 +329,7 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): ) def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})" + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" def fsdp_pre_all_gather(self, mesh): # initialize if needed @@ -351,7 +361,8 @@ def fsdp_pre_all_gather(self, mesh): self._scale_buffer, e4m3_dtype, self._amax_buffer, - self._mm_config, + self._linear_mm_config, + gemm_input_role=GemmInputRole.W, ) return (float8_tensor._data,), (float8_tensor._scale,) @@ -369,4 +380,10 @@ def fsdp_post_all_gather( assert isinstance(out, Float8Tensor), f"{type(out)}" out._scale = scale return - return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,) + return Float8Tensor( + data, + scale, 
+ param_dtype, + self._linear_mm_config, + gemm_input_role=GemmInputRole.W, + ), (data,) diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 1c931ee..2bb593b 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -20,6 +20,8 @@ from float8_experimental.float8_tensor import ( Float8Tensor, + GemmInputRole, + LinearMMConfig, ScaledMMConfig, tensor_already_casted_to_fp8, to_fp8_no_autograd, @@ -73,7 +75,7 @@ def __init__( self, # FP8 specific arguments quant_config: QuantConfig, - forward_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, # nn.Linear arguments in_features: int, out_features: int, @@ -83,7 +85,7 @@ def __init__( ) -> None: # Construct the superclass this will create dummy weights and biases super().__init__(in_features, out_features, bias, device, dtype) - self.forward_config = forward_config + self.linear_mm_config = linear_mm_config self.activation_casting = quant_config.activation_casting if self.activation_casting == ActivationCasting.STATIC: self.register_buffer( @@ -100,7 +102,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: x_fp8 = cast_to_float8_e4m3_inference( input, - self.forward_config, + self.linear_mm_config, static_quantization_scale=self.static_quantization_scale, ) return torch.nn.functional.linear(x_fp8, self.weight, self.bias) @@ -125,7 +127,8 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: self.weight, scale, dtype, - self.forward_config, + self.linear_mm_config, + gemm_input_role=GemmInputRole.W, ) self.weight = nn.Parameter(quantized_weight) self.weight.requires_grad = False @@ -150,9 +153,12 @@ def from_float( forward_config = ScaledMMConfig( False, use_fast_accum, pad_inner_dim=config.pad_inner_dim ) + linear_mm_config = LinearMMConfig( + forward_config, forward_config, forward_config + ) linear = cls( quant_config, - forward_config, + linear_mm_config, module.in_features, module.out_features, False, @@ -165,7 +171,7 @@ def from_float( def cast_to_float8_e4m3_inference( inpt_tensor: torch.Tensor, - mm_config: ScaledMMConfig, + linear_mm_config: LinearMMConfig, reduce_amax: bool = False, static_quantization_scale: Optional[torch.Tensor] = None, ) -> Float8Tensor: @@ -173,7 +179,7 @@ def cast_to_float8_e4m3_inference( Args: inpt_tensor: The input tensor to be cast. - mm_config: Configuration settings for the matrix multiplication + linear_mm_config: Configuration settings for the matrix multiplication reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. 
@@ -190,7 +196,13 @@ def cast_to_float8_e4m3_inference( if static_quantization_scale is not None else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) ) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=GemmInputRole.X, + ) def quantize_to_float8( diff --git a/test/test_base.py b/test/test_base.py index 2c7c3f4..381fb4e 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -26,7 +26,8 @@ from float8_experimental.float8_python_api import addmm_float8_unwrapped from float8_experimental.float8_tensor import ( Float8Tensor, - merge_mm_configs, + GemmInputRole, + LinearMMConfig, ScaledMMConfig, ) from float8_experimental.float8_utils import ( @@ -121,7 +122,7 @@ def test_copy_(self): torch.empty(16, dtype=torch.float8_e4m3fn), scale_a, torch.bfloat16, - fp8_a._mm_config, + fp8_a._linear_mm_config, ) fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) @@ -438,38 +439,36 @@ def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") fp8_dtype = e4m3_dtype - a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) + linear_config_a = LinearMMConfig( + ScaledMMConfig(False, True, False, False), + ScaledMMConfig(False, False, False, False), + ScaledMMConfig(False, False, False, False), + ) + linear_config_b = LinearMMConfig( + ScaledMMConfig(True, True, False, False), + ScaledMMConfig(True, False, False, False), + ScaledMMConfig(True, False, False, False), + ) + a = Float8Tensor.to_float8( + x_fp32, + x_scale, + fp8_dtype, + linear_mm_config=linear_config_a, + gemm_input_role=GemmInputRole.X, + ) b = Float8Tensor.to_float8( - x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True) + x_fp32, + x_scale, + fp8_dtype, + linear_mm_config=linear_config_b, + gemm_input_role=GemmInputRole.W, ) with pytest.raises( AssertionError, - match="Both mm_configs must have the same emulate value, but got False and True", + match="linear_mm_config.y mismatch", ): a @ b - def test_merge_configs(self): - a = ScaledMMConfig(False, True, True) - b = ScaledMMConfig(True, False, False) - with pytest.raises( - AssertionError, - match="Both mm_configs must have the same emulate value, but got False and True", - ): - merge_mm_configs(a, b) - a = ScaledMMConfig(False, True, True) - b = ScaledMMConfig(False, False, False) - c = merge_mm_configs(a, b) - assert c.emulate is False - assert c.use_fast_accum is False - assert c.fp8_output is False - - a = ScaledMMConfig(False, True, False) - b = ScaledMMConfig(False, True, False) - c = merge_mm_configs(a, b) - assert c.emulate is False - assert c.use_fast_accum is True - assert c.fp8_output is False - @unittest.skipIf( not is_H100, "CUDA not available", @@ -489,8 +488,12 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() - a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype) - b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype) + a_fp8 = Float8Tensor.to_float8( + a, a_scale, input_dtype, gemm_input_role=GemmInputRole.X + ) + b_fp8 = Float8Tensor.to_float8( + b, b_scale, input_dtype, gemm_input_role=GemmInputRole.W + ) with pytest.raises( RuntimeError, @@ -500,19 +503,47 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): ): a_fp8 @ b_fp8 - pad_config = ScaledMMConfig(False, use_fast_accum, False, True) + 
scaled_mm_config = ScaledMMConfig(False, use_fast_accum, False, True) + pad_config = LinearMMConfig( + scaled_mm_config, scaled_mm_config, scaled_mm_config + ) - a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config) - b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config) + a_fp8 = Float8Tensor.to_float8( + a, + a_scale, + input_dtype, + linear_mm_config=pad_config, + gemm_input_role=GemmInputRole.X, + ) + b_fp8 = Float8Tensor.to_float8( + b, + b_scale, + input_dtype, + linear_mm_config=pad_config, + gemm_input_role=GemmInputRole.W, + ) out_padded = a_fp8 @ b_fp8 out_padded.to(compare_type) - emulated_conifg = ScaledMMConfig(True, use_fast_accum, False, False) + emulated_scaled_mm_config = ScaledMMConfig(True, use_fast_accum, False, False) + emulated_config = LinearMMConfig( + emulated_scaled_mm_config, + emulated_scaled_mm_config, + emulated_scaled_mm_config, + ) a_fp8 = Float8Tensor.to_float8( - a, a_scale, input_dtype, mm_config=emulated_conifg + a, + a_scale, + input_dtype, + linear_mm_config=emulated_config, + gemm_input_role=GemmInputRole.X, ) b_fp8 = Float8Tensor.to_float8( - b, b_scale, input_dtype, mm_config=emulated_conifg + b, + b_scale, + input_dtype, + linear_mm_config=emulated_config, + gemm_input_role=GemmInputRole.W, ) out_emualted = a_fp8 @ b_fp8 out_emualted.to(compare_type) @@ -564,8 +595,8 @@ def test_swap_root_linear(self): module = nn.Linear(3, 3) module = swap_linear_with_float8_linear(module, emulate=emulate) self.assertIsInstance(module, Float8Linear) - self.assertEqual(module.forward_config.emulate, emulate) - self.assertEqual(module.backward_config.emulate, emulate) + self.assertEqual(module.linear_mm_config.y.emulate, emulate) + self.assertEqual(module.linear_mm_config.dL_dX.emulate, emulate) def test_swap_root_linear_with_children_raises(self): for emulate in [True, False]: diff --git a/test/test_compile.py b/test/test_compile.py index 5a6e003..f72425d 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -19,7 +19,7 @@ swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import Float8Tensor, LinearMMConfig from float8_experimental.float8_utils import e4m3_dtype from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -179,7 +179,7 @@ def forward(self, x): self.fp8_scale_x, e4m3_dtype, self.fp8_amax_x, - ScaledMMConfig(), + LinearMMConfig(), ) if self.graph_break: torch._dynamo.graph_break() @@ -242,9 +242,9 @@ def test_float8_graph_output(self): type(y_compiled._orig_dtype) ) assert isinstance( - y_compiled._mm_config.emulate, bool + y_compiled._linear_mm_config.y.emulate, bool ), "Float8Tensor._emulate should be a bool but got {}".format( - type(y_compiled._mm_config.emulate) + type(y_compiled._linear_mm_config.y.emulate) ) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 8aada4b..1cd14db 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -17,7 +17,11 @@ from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw from float8_experimental.float8_linear import TensorScalingType from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, +) from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel,
Float8RowwiseParallel, @@ -82,8 +86,12 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() - x_fp8 = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) - y_fp8 = Float8Tensor.to_float8(y_fp32, y_scale, fp8_dtype) + x_fp8 = Float8Tensor.to_float8( + x_fp32, x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X + ) + y_fp8 = Float8Tensor.to_float8( + y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.W + ) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False) @@ -155,13 +163,15 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() dist_target = distribute_tensor(target, mesh, [Shard(0)]) - dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) + dist_x_fp8 = Float8Tensor.to_float8( + dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.X + ) dist_weight_fp8 = Float8Tensor.to_float8( - dist_wight_fp32, dist_weight_scale, fp8_dtype + dist_wight_fp32, dist_weight_scale, fp8_dtype, gemm_input_role=GemmInputRole.W ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2Bw.apply(out, ScaledMMConfig()) + out = NoopFwToFloat8E5M2Bw.apply(out, LinearMMConfig()) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward()
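
Example usage (not part of the patch above): a minimal sketch of how the pieces introduced here fit together, assuming a CUDA GPU with native float8 gemm support (flip the first `ScaledMMConfig` field, `emulate`, to True otherwise) and tensor shapes chosen only for illustration. Each `Float8Tensor` now carries one `LinearMMConfig` covering all three gemms of a linear layer plus its own `GemmInputRole`, and `choose_scaled_mm_config` picks the per-gemm `ScaledMMConfig` from the pair of roles when a matmul is dispatched.

import torch

from float8_experimental.float8_tensor import (
    choose_scaled_mm_config,
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

# One ScaledMMConfig per gemm of the linear layer: y (forward), dL_dX and
# dL_dW (backward). In this sketch only the forward gemm uses fast accumulation.
linear_mm_config = LinearMMConfig(
    ScaledMMConfig(False, True, False, False),   # y
    ScaledMMConfig(False, False, False, False),  # dL_dX
    ScaledMMConfig(False, False, False, False),  # dL_dW
)

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
w = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)

# Tag each cast with its role in the three gemms.
x_fp8 = Float8Tensor.to_float8(
    x,
    tensor_to_scale(x, e4m3_dtype).float(),
    e4m3_dtype,
    linear_mm_config=linear_mm_config,
    gemm_input_role=GemmInputRole.X,
)
w_fp8 = Float8Tensor.to_float8(
    w,
    tensor_to_scale(w, e4m3_dtype).float(),
    e4m3_dtype,
    linear_mm_config=linear_mm_config,
    gemm_input_role=GemmInputRole.W,
)

# The (X, W) role pair selects the `y` entry of the shared LinearMMConfig,
# which is the same selection float8_mm / float8_addmm perform internally.
cfg = choose_scaled_mm_config(
    x_fp8._gemm_input_role,
    x_fp8._linear_mm_config,
    w_fp8._gemm_input_role,
    w_fp8._linear_mm_config,
)
assert cfg == linear_mm_config.y

# Dispatches through the float8 mm override using the config chosen above.
y = torch.nn.functional.linear(x_fp8, w_fp8)

Compared to the removed merge_mm_configs, which intersected the two operands' configs at matmul time, the role pair now selects one of the three per-gemm configs directly, so settings such as use_fast_accum can differ between the forward and backward gemms without any merging logic.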