From 8e9623ac8173f83b6ab0ee0d86546730c70211bf Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 5 Jul 2024 10:50:32 -0700 Subject: [PATCH] delete Float8DynamicLinear (#304) Summary: Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/304 We are standardizing on `Float8Linear` as the only float8 linear object: 1. the stack ending with https://github.com/pytorch-labs/float8_experimental/pull/300 moved all of the functionality of `Float8DynamicLinear` to `Float8Linear`. The default settings of `Float8Linear` are to use dynamic scaling. 2. this PR deletes `Float8DynamicLinear` from the codebase and patches the relevant callsites in fbsource. Reviewed By: drisspg Differential Revision: D59342767 fbshipit-source-id: cfb09dd5f6517cfbf41d8b46eb6d7d6a5266006a --- benchmarks/bench_linear_float8.py | 65 ++------ benchmarks/bench_multi_gpu.py | 10 +- benchmarks/profile_linear_float8.py | 49 +++--- ...amic_linear.py => float8_dynamic_utils.py} | 58 ------- float8_experimental/float8_linear.py | 6 +- float8_experimental/float8_linear_utils.py | 63 +------ float8_experimental/float8_tensor_parallel.py | 22 +-- test/test_base.py | 154 +++++++----------- test/test_compile.py | 58 +------ test/test_dtensor.py | 66 ++------ test/test_fsdp.py | 3 - test/test_fsdp2/test_fsdp2_common.py | 9 +- test/test_fsdp2/test_fsdp2_eager.py | 123 ++++---------- test/test_fsdp_compile.py | 1 - test/test_inference_flows.py | 2 - test/test_numerics_integration.py | 29 +--- 16 files changed, 182 insertions(+), 536 deletions(-) rename float8_experimental/{float8_dynamic_linear.py => float8_dynamic_utils.py} (72%) diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index 8020ccc..5f8e4f9 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -14,11 +14,9 @@ import torch import torch.utils.benchmark as benchmark -from float8_experimental.float8_linear import TensorScalingType +from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( - get_float8_linear, linear_requires_sync, - LinearType, sync_float8_amax_and_scale_history, ) from float8_experimental.float8_tensor import ScaledMMConfig @@ -69,7 +67,6 @@ class Experiment: dtype: torch.dtype compiled: bool use_fast_accum: bool - linear_type: str scaling_repr: str # 3 Times since we are calculating forward backward @@ -98,7 +95,6 @@ def main( n_limit: Optional[int] = None, fast_accum_filter: Optional[bool] = None, shape_name_filter: Optional[str] = None, - linear_type_filter: Optional[str] = None, scaling_type_x: str = "delayed", scaling_type_w: str = "delayed", scaling_type_dL_dY: str = "delayed", @@ -123,44 +119,28 @@ def main( use_fast_accum = [fast_accum_filter] else: use_fast_accum = [True, False] - if linear_type_filter is not None: - linear_types = [linear_type_filter] - else: - linear_types = ["delayed", "dynamic"] if shape_name_filter is not None: k = shape_name_filter name_to_shapes_70b = {k: name_to_shapes_70b[k]} experiment_list: List[Experiment] = [] dtype = torch.bfloat16 - for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate( - tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types))) + for idx, (fast_accum, (name, (K, N))) in enumerate( + tqdm(list(product(use_fast_accum, name_to_shapes_70b.items()))) ): if n_limit is not None and idx >= n_limit: break linear_ref = torch.nn.Linear(K, N, bias=input_bias).to( device=device, dtype=dtype ) - linear_type_enum = ( - LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC - ) - if linear_type == "delayed": - linear_float8 = get_float8_linear( - linear_type_enum, - copy.deepcopy(linear_ref), - emulate=False, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, - ) - scaling_repr = linear_float8.scaling_repr() - else: - linear_float8 = get_float8_linear( - linear_type_enum, - copy.deepcopy(linear_ref), - emulate=False, - ) - scaling_repr = None + linear_float8 = Float8Linear.from_float( + copy.deepcopy(linear_ref), + emulate=False, + scaling_type_x=scaling_type_x, + scaling_type_w=scaling_type_w, + scaling_type_dL_dY=scaling_type_dL_dY, + ) + scaling_repr = linear_float8.scaling_repr() if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -172,19 +152,10 @@ def main( input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() - if linear_type_enum == LinearType.DELAYED: - - def float8_forw_backward(): - if linear_requires_sync( - linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY - ): - sync_float8_amax_and_scale_history(linear_float8) - linear_float8(input_tensor).sum().backward() - - else: - - def float8_forw_backward(): - linear_float8(input_tensor).sum().backward() + def float8_forw_backward(): + if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + sync_float8_amax_and_scale_history(linear_float8) + linear_float8(input_tensor).sum().backward() def n_times(n, fn, *args, **kwargs): def wrapper(*args, **kwargs): @@ -224,7 +195,6 @@ def wrapper(*args, **kwargs): dtype, compile, use_fast_accum=fast_accum, - linear_type=linear_type, scaling_repr=scaling_repr, ) print(experiment) @@ -237,7 +207,6 @@ def wrapper(*args, **kwargs): "M", "K", "N", - "linear_type", "scaling_repr", "ref_dtype", "compiled", @@ -257,7 +226,6 @@ def wrapper(*args, **kwargs): experiment.shape[0], experiment.shape[1], experiment.shape[2], - experiment.linear_type, experiment.scaling_repr, experiment.dtype, experiment.compiled, @@ -287,7 +255,6 @@ def wrapper(*args, **kwargs): [ "name", "shape", - "linear_type", "scaling_repr", "compiled", "use_fast_accum", @@ -311,7 +278,6 @@ def invoke_main() -> None: parser.add_argument("-n", "--n_limit", type=int, required=False) parser.add_argument("--fast_accum_filter", type=bool, required=False) parser.add_argument("--shape_name_filter", type=str, required=False) - parser.add_argument("--linear_type_filter", type=str, required=False) parser.add_argument("--scaling_type_x", type=str, required=False) parser.add_argument("--scaling_type_w", type=str, required=False) parser.add_argument("--scaling_type_dL_dY", type=str, required=False) @@ -330,7 +296,6 @@ def invoke_main() -> None: args.n_limit, args.fast_accum_filter, args.shape_name_filter, - args.linear_type_filter, **kwargs, ) diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py index 8c54b3d..12a1ddb 100644 --- a/benchmarks/bench_multi_gpu.py +++ b/benchmarks/bench_multi_gpu.py @@ -14,7 +14,7 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.utils.benchmark as benchmark -from float8_experimental.float8_linear import Float8Linear +from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, @@ -65,7 +65,13 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32): modules.append(nn.ReLU()) m = nn.Sequential(*modules) if is_fp8: - swap_linear_with_float8_linear(m, Float8Linear, emulate=False) + swap_linear_with_float8_linear( + m, + emulate=False, + scaling_type_x=TensorScalingType.DELAYED, + scaling_type_w=TensorScalingType.DELAYED, + scaling_type_dL_dY=TensorScalingType.DELAYED, + ) return m diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index c5d7c44..1ef5478 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -18,11 +18,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( linear_requires_sync, - LinearType, swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) @@ -206,19 +204,25 @@ def profile_function( def main( profile_path_prefix: Path, compile: bool = True, - linear_type: str = "dynamic", - scaling_type_x: str = "delayed", - scaling_type_w: str = "delayed", - scaling_type_dL_dY: str = "delayed", + scaling_type_x: str = "dynamic", + scaling_type_w: str = "dynamic", + scaling_type_dL_dY: str = "dynamic", model_type: str = "linear", dtype_filter: str = "both", ): assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") - print(f"Compile is set to | {compile}") - print(f"Using Linear type: | {linear_type}") - print(f"model_type is set to | {model_type}") + scaling_type_x = TensorScalingType(scaling_type_x) + scaling_type_w = TensorScalingType(scaling_type_w) + scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY) + scaling_repr = "_".join( + [s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)] + ) + + print(f"Compile is set to | {compile}") + print(f"model_type is set to | {model_type}") + print(f"scaling_repr is set to | {scaling_repr}") device = "cuda" ref_dtype = torch.bfloat16 @@ -249,21 +253,14 @@ def main( m_ref = m_ref.to(device).to(ref_dtype) - linear_type = LinearType[linear_type.upper()] - linear_cls = ( - Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear - ) - extra_kwargs = {} - scaling_type_x = TensorScalingType(scaling_type_x) - scaling_type_w = TensorScalingType(scaling_type_w) - scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY) - if linear_type is LinearType.DELAYED: - extra_kwargs["scaling_type_x"] = scaling_type_x - extra_kwargs["scaling_type_w"] = scaling_type_w - extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY + extra_kwargs = { + "scaling_type_x": scaling_type_x, + "scaling_type_w": scaling_type_w, + "scaling_type_dL_dY": scaling_type_dL_dY, + } m_float8 = copy.deepcopy(m_ref) - swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs) + swap_linear_with_float8_linear(m_float8, **extra_kwargs) def ref_forw_backward(x): out = m_ref(x) @@ -281,9 +278,7 @@ def float8_forw_backward_wrapper(x): # inspection of the fw+bw torch.compile without the scale # syncing code # TODO(future): make this better - if linear_requires_sync( - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY - ): + if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): with record_function("scale_amax_and_scales"): sync_amax_history(m_float8) out = float8_forw(x) @@ -345,7 +340,9 @@ def float8_forw_backward_wrapper(x): if dtype_filter != "bfloat16": # Profile Float8 Model print("profiling float8") - float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json" + float8_suffix = ( + f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" + ) float8_path = profile_path_prefix + float8_suffix profile_config = ProfileConfig( float8_path, diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_utils.py similarity index 72% rename from float8_experimental/float8_dynamic_linear.py rename to float8_experimental/float8_dynamic_utils.py index 763a521..f48424c 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -53,64 +53,6 @@ def backward(ctx, gradY): return fp8_tensor, None -class Float8DynamicLinear(torch.nn.Linear): - """ - A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly - conversion to fp8 of the input and weight tensors. - """ - - def __init__(self, **super_kwargs): - super().__init__(**super_kwargs) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config) - if isinstance(self.weight, Float8Tensor): # cast by FSDP - w_fp8 = self.weight - else: - w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config) - y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) - y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config) - return y - - @classmethod - def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": - """ - Create an nn.Linear with fp8 compute from a regular nn.Linear - - Args: - mod (torch.nn.Linear): nn.Linear to convert - emulate (bool): whether to emulate fp8 matmul logic in float32 - """ - with torch.device("meta"): - super_kwargs = { - "in_features": mod.in_features, - "out_features": mod.out_features, - "bias": False, - } - new_mod = cls(**super_kwargs) - - new_mod.forward_config = ScaledMMConfig( - emulate=emulate, - use_fast_accum=not bool(emulate), - fp8_output=False, - pad_inner_dim=config.pad_inner_dim, - ) - new_mod.backward_config = ScaledMMConfig( - emulate=emulate, - use_fast_accum=False, - fp8_output=False, - pad_inner_dim=config.pad_inner_dim, - ) - if config.enable_fsdp_fp8_all_gather: - new_mod.weight = nn.Parameter( - WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config) - ) - else: - new_mod.weight = mod.weight - new_mod.bias = mod.bias - return new_mod - - def cast_to_float8_e4m3_dynamic( inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False ) -> Float8Tensor: diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 90c207f..664a03a 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -16,7 +16,7 @@ import torch -from float8_experimental.float8_dynamic_linear import ( +from float8_experimental.float8_dynamic_utils import ( cast_to_float8_e4m3_dynamic, cast_to_float8_e5m2_dynamic_bw, WeightWithDynamicFloat8CastTensor, @@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: def scaling_repr(self): # add scaling settings without using too many characters - # example: "x:del,w:del,dldy:dyn" - return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}" + # example: "x_del_w_del_dldy_dyn" + return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}" def extra_repr(self): s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index b1a17e4..945f7a6 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -11,7 +11,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_utils import ( @@ -25,51 +24,13 @@ log.addHandler(logging.NullHandler()) -class LinearType(Enum): - DELAYED = auto() - DYNAMIC = auto() - - -def get_float8_linear( - linear_type: LinearType, - linear_ref: torch.nn.Linear, - emulate: bool = False, - scaling_type_x: TensorScalingType = TensorScalingType.DELAYED, - scaling_type_w: TensorScalingType = TensorScalingType.DELAYED, - scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED, -): - """Returns a Float8Linear module of the given type, initialized from linear_ref. - Args: - linear_type: The type of Float8Linear to return. - linear_ref: The linear module to initialize from. - emulate: Whether to emulate the fp8 matmul logic in float32. - scaling_type_x: delayed vs dynamic scaling for `x`. - scaling_type_w: delayed vs dynamic scaling for `w`. - scaling_type_dL_dY: delayed vs dynamic scaling for `dL_dY`. - """ - if linear_type is LinearType.DYNAMIC: - return Float8DynamicLinear.from_float( - copy.deepcopy(linear_ref), emulate=emulate - ) - else: - assert linear_type is LinearType.DELAYED - return Float8Linear.from_float( - copy.deepcopy(linear_ref), - emulate=emulate, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, - ) - - def linear_requires_sync( - linear_type: LinearType, scaling_type_x: TensorScalingType = TensorScalingType.DELAYED, scaling_type_w: TensorScalingType = TensorScalingType.DELAYED, scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED, ): """Returns whether the given linear_type requires sync before forward.""" - return linear_type is LinearType.DELAYED and any( + return any( [ scaling_type_x is TensorScalingType.DELAYED, scaling_type_w is TensorScalingType.DELAYED, @@ -186,7 +147,6 @@ def post_order_traversal( def swap_linear_with_float8_linear( module: nn.Module, - module_cls: Union[Type[Float8Linear], Type[Float8DynamicLinear]], *, skip_fqn_list: Optional[List[str]] = None, emulate: bool = False, @@ -196,12 +156,10 @@ def swap_linear_with_float8_linear( scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC, ) -> Optional[nn.Module]: """ - Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`. + Swaps `torch.nn.Linear` in `module` with `Float8Linear`. Args: module: Module to modify. - module_cls: `Float8Linear` or `Float8DynamicLinear`. - from_float_func: Function that accepts a linear layer and returns a new type of linear layer. skip_fqn_list: If specified, a list of module FQNs to skip. emulate: If True, emulation is used instead of hardware accelerated gemm linear_layer_filter: If specified, only the linear layers @@ -213,16 +171,13 @@ def swap_linear_with_float8_linear( Returns: nn.Module: The modified module with swapped linear layers. """ - if module_cls is Float8DynamicLinear: - from_float = lambda m: module_cls.from_float(m, emulate=emulate) - else: - from_float = lambda m: module_cls.from_float( - m, - emulate=emulate, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, - ) + from_float = lambda m: Float8Linear.from_float( + m, + emulate=emulate, + scaling_type_x=scaling_type_x, + scaling_type_w=scaling_type_w, + scaling_type_dL_dY=scaling_type_dL_dY, + ) return swap_linear_layers( module, from_float, diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index fac0201..7c012f6 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from float8_experimental.float8_dynamic_linear import ( +from float8_experimental.float8_dynamic_utils import ( cast_to_float8_e4m3_dynamic, cast_to_float8_e5m2_dynamic_bw, ) @@ -21,7 +21,6 @@ # creating the DTensor. # NOTE: This only works and tested with the dynamic scaling -# (Float8DynamicLinear and Float8Linear with dynamic scaling for all tensors) def _float8_linear_supports_float8_allgather(m): @@ -71,12 +70,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me return outputs.to_local() if use_local_output else outputs def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear - if not isinstance(module, (Float8DynamicLinear, Float8Linear)): + if not isinstance(module, Float8Linear): raise ValueError( - f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}" + f"Expecting module to be Float8Linear but found {type(module)}" ) elif isinstance( module, Float8Linear @@ -122,12 +120,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me return outputs.to_local() if use_local_output else outputs def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear - if not isinstance(module, (Float8DynamicLinear, Float8Linear)): + if not isinstance(module, Float8Linear): raise ValueError( - f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}" + f"Expecting module to be Float8Linear but found {type(module)}" ) elif isinstance( module, Float8Linear @@ -147,7 +144,7 @@ class PrepareFloat8ModuleInput(PrepareModuleInput): # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used - # for the float8 cast. If not specified, we will search for the Float8DynamicLinear in the submodules + # for the float8 cast. If not specified, we will search for the Float8Linear in the submodules # and use the forward config from that module, in this case all module's forward config must be # the same. @@ -204,24 +201,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): return input def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear fwd_linear_config = None if self.fwd_config_submodule_fqn is not None: fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) - assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear)) + assert isinstance(fwd_linear, Float8Linear) fwd_linear_config = fwd_linear.forward_config else: # search for ScaledMM configs for all the submodules and make sure they are the same for mod in module.modules(): - if isinstance(mod, (Float8DynamicLinear, Float8Linear)): + if isinstance(mod, Float8Linear): if fwd_linear_config is None: fwd_linear_config = mod.forward_config else: assert ( fwd_linear_config == mod.forward_config - ), "All the Float8DynamicLinear and Float8Linear modules should have same forward config!" + ), "All the Float8Linear modules should have same forward config!" self.fwd_linear_config = fwd_linear_config super()._apply(module, device_mesh) diff --git a/test/test_base.py b/test/test_base.py index 754e656..b2ee071 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import copy import io import itertools import random @@ -15,13 +16,10 @@ import torch import torch.nn as nn -from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( filter_out_small_unaligned_layers, - get_float8_linear, linear_requires_sync, - LinearType, swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) @@ -149,24 +147,20 @@ def _test_linear_impl( self, x, m_ref, - linear_type: LinearType, emulate: bool, scaling_type_x: TensorScalingType = TensorScalingType.DELAYED, scaling_type_w: TensorScalingType = TensorScalingType.DELAYED, scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED, ): - m_fp8 = get_float8_linear( - linear_type, - m_ref, + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), emulate, scaling_type_x, scaling_type_w, scaling_type_dL_dY, ) for _ in range(2): - if linear_requires_sync( - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY - ): + if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): sync_float8_amax_and_scale_history(m_fp8) y_fp8 = m_fp8(x) y_fp8.sum().backward() @@ -184,9 +178,7 @@ def _test_linear_impl( torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) # verify all of the amax buffers got updated - if linear_requires_sync( - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY - ): + if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): # only check buffers that are actually used, based on per-tensor # scaling settings amax_buffer_names = [] @@ -231,7 +223,6 @@ def _test_linear_impl( @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @@ -245,7 +236,6 @@ def _test_linear_impl( def test_linear_nobias( self, x_shape, - linear_type: LinearType, emulate: bool, scaling_type_x: TensorScalingType, scaling_type_w: TensorScalingType, @@ -260,25 +250,11 @@ def test_linear_nobias( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) pytest.skip() - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() - x = torch.randn(*x_shape, device="cuda") m_ref = nn.Linear(16, 32, bias=False, device="cuda") self._test_linear_impl( x, m_ref, - linear_type, emulate, scaling_type_x, scaling_type_w, @@ -287,7 +263,6 @@ def test_linear_nobias( @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @@ -304,7 +279,6 @@ def test_linear_nobias( def test_linear_bias( self, x_shape, - linear_type: LinearType, scaling_type_x: TensorScalingType, scaling_type_w: TensorScalingType, scaling_type_dL_dY: TensorScalingType, @@ -320,25 +294,11 @@ def test_linear_bias( f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" ) pytest.skip() - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) self._test_linear_impl( x, m_ref, - linear_type, emulate, scaling_type_x, scaling_type_w, @@ -346,14 +306,12 @@ def test_linear_bias( ) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_autocast_outputs( self, - linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype, ): @@ -368,49 +326,56 @@ def test_autocast_outputs( pytest.skip() m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = get_float8_linear(linear_type, m_ref, emulate) + kwargs = { + "scaling_type_x": TensorScalingType.DELAYED, + "scaling_type_w": TensorScalingType.DELAYED, + "scaling_type_dL_dY": TensorScalingType.DELAYED, + } + m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate, **kwargs) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - if linear_requires_sync(linear_type): + if linear_requires_sync(**kwargs): sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - if linear_requires_sync(linear_type): + if linear_requires_sync(**kwargs): sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): - if linear_requires_sync(linear_type): + if linear_requires_sync(**kwargs): sync_float8_amax_and_scale_history(m) y = m(x) assert ( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" - @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_type_cast( - self, linear_type: LinearType, linear_dtype: torch.dtype, emulate: bool - ): + def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): emulate = ( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0) ) m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = get_float8_linear(linear_type, m, emulate) + kwargs = { + "scaling_type_x": TensorScalingType.DYNAMIC, + "scaling_type_w": TensorScalingType.DYNAMIC, + "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + } + m = Float8Linear.from_float(copy.deepcopy(m), emulate, **kwargs) # Cast the module to dtype m = m.to(dtype=linear_dtype) - if linear_requires_sync(linear_type): + if linear_requires_sync(**kwargs): # Check amax buffer types for key in [ "fp8_amax_x", @@ -429,18 +394,21 @@ def test_type_cast( # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(**kwargs): + sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(**kwargs): + sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): - sync_float8_amax_and_scale_history(m) + if linear_requires_sync(**kwargs): + sync_float8_amax_and_scale_history(m) y = m(x) assert ( y.dtype == torch.bfloat16 @@ -448,9 +416,8 @@ def test_type_cast( def test_repr(self): m = nn.Linear(32, 16) - m = get_float8_linear( - LinearType.DELAYED, - m, + m = Float8Linear.from_float( + copy.deepcopy(m), emulate=True, scaling_type_x=TensorScalingType.DYNAMIC, scaling_type_w=TensorScalingType.DELAYED, @@ -633,26 +600,22 @@ def test_small_amax_float16(self, float8_dtype): class TestFloat8LinearUtils(unittest.TestCase): def test_swap_root_linear(self): - for module_cls, emulate in itertools.product( - [Float8Linear, Float8DynamicLinear], [True, False] - ): + for emulate in [True, False]: module = nn.Linear(3, 3) - module = swap_linear_with_float8_linear(module, module_cls, emulate=emulate) - self.assertIsInstance(module, module_cls) + module = swap_linear_with_float8_linear(module, emulate=emulate) + self.assertIsInstance(module, Float8Linear) self.assertEqual(module.forward_config.emulate, emulate) self.assertEqual(module.backward_config.emulate, emulate) def test_swap_root_linear_with_children_raises(self): - for module_cls, emulate in itertools.product( - [Float8Linear, Float8DynamicLinear], [True, False] - ): + for emulate in [True, False]: module = nn.Linear(3, 3) module.child = nn.Sequential(nn.Linear(3, 3)) with self.assertRaisesRegex( AssertionError, "Does not support a root nn.Linear with children", ): - swap_linear_with_float8_linear(module, module_cls, emulate=emulate) + swap_linear_with_float8_linear(module, emulate=emulate) def test_swap_submodule_linears(self): class MLP(nn.Module): @@ -661,16 +624,14 @@ def __init__(self, dim: int): self.lin1 = nn.Linear(dim, 4 * dim) self.lin2 = nn.Linear(4 * dim, dim) - for module_cls, emulate in itertools.product( - [Float8Linear, Float8DynamicLinear], [True, False] - ): + for emulate in [True, False]: model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3)) - model = swap_linear_with_float8_linear(model, module_cls, emulate=emulate) - self.assertIsInstance(model[0].lin1, module_cls) - self.assertIsInstance(model[0].lin2, module_cls) - self.assertIsInstance(model[1], module_cls) - self.assertIsInstance(model[2].lin1, module_cls) - self.assertIsInstance(model[2].lin2, module_cls) + model = swap_linear_with_float8_linear(model, emulate=emulate) + self.assertIsInstance(model[0].lin1, Float8Linear) + self.assertIsInstance(model[0].lin2, Float8Linear) + self.assertIsInstance(model[1], Float8Linear) + self.assertIsInstance(model[2].lin1, Float8Linear) + self.assertIsInstance(model[2].lin2, Float8Linear) def test_swap_linears_with_filters(self): class MLP(nn.Module): @@ -679,27 +640,24 @@ def __init__(self, dim: int): self.lin1 = nn.Linear(dim, 4 * dim) self.lin2 = nn.Linear(4 * dim, 4 * dim) - for module_cls, emulate in itertools.product( - [Float8Linear, Float8DynamicLinear], [True, False] - ): + for emulate in [True, False]: model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40)) # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16. model = swap_linear_with_float8_linear( model, - module_cls, emulate=emulate, linear_layer_filter=filter_out_small_unaligned_layers(32), ) # in_features=8, out_features=32, 8 is less than 32. - self.assertNotIsInstance(model[0].lin1, module_cls) + self.assertNotIsInstance(model[0].lin1, Float8Linear) # in_features=32, out_features=32, - self.assertIsInstance(model[0].lin2, module_cls) + self.assertIsInstance(model[0].lin2, Float8Linear) # in_features=32, out_features=32, - self.assertIsInstance(model[1], module_cls) + self.assertIsInstance(model[1], Float8Linear) # in_features=40, out_features=160, 40 is not divisible by 16. - self.assertNotIsInstance(model[2].lin1, module_cls) + self.assertNotIsInstance(model[2].lin1, Float8Linear) # in_features=160, out_features=160, - self.assertIsInstance(model[2].lin2, module_cls) + self.assertIsInstance(model[2].lin2, Float8Linear) def test_swap_submodule_linears_with_skip(self): class MLP(nn.Module): @@ -708,20 +666,18 @@ def __init__(self, dim: int): self.lin1 = nn.Linear(dim, 4 * dim) self.lin2 = nn.Linear(4 * dim, dim) - for module_cls, emulate in itertools.product( - [Float8Linear, Float8DynamicLinear], [True, False] - ): + for emulate in [True, False]: model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3)) skip_fqn_list = ["2", "0.lin2"] model = swap_linear_with_float8_linear( - model, module_cls, emulate=emulate, skip_fqn_list=skip_fqn_list + model, emulate=emulate, skip_fqn_list=skip_fqn_list ) - self.assertIsInstance(model[0].lin1, module_cls) - self.assertNotIsInstance(model[0].lin2, module_cls) + self.assertIsInstance(model[0].lin1, Float8Linear) + self.assertNotIsInstance(model[0].lin2, Float8Linear) self.assertIsInstance(model[0].lin2, nn.Linear) - self.assertIsInstance(model[1], module_cls) - self.assertNotIsInstance(model[2].lin2, module_cls) - self.assertNotIsInstance(model[2].lin2, module_cls) + self.assertIsInstance(model[1], Float8Linear) + self.assertNotIsInstance(model[2].lin2, Float8Linear) + self.assertNotIsInstance(model[2].lin2, Float8Linear) self.assertIsInstance(model[2].lin1, nn.Linear) self.assertIsInstance(model[2].lin2, nn.Linear) diff --git a/test/test_compile.py b/test/test_compile.py index 834d126..5a6e003 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -16,8 +16,6 @@ from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( get_float8_layers, - get_float8_linear, - LinearType, swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) @@ -34,7 +32,6 @@ def _test_compile_base( backend: str, fullgraph: bool, emulate: bool, - linear_type: LinearType, scaling_type_x, scaling_type_w, scaling_type_dL_dY, @@ -48,8 +45,12 @@ def _test_compile_base( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - m_fp8 = get_float8_linear( - linear_type, m_ref, emulate, scaling_type_x, scaling_type_w, scaling_type_dL_dY + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + emulate, + scaling_type_x, + scaling_type_w, + scaling_type_dL_dY, ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) @@ -66,7 +67,6 @@ def _test_compile_base( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @@ -82,30 +82,16 @@ def _test_compile_base( def test_eager_only( fullgraph, emulate: bool, - linear_type: bool, scaling_type_x: TensorScalingType, scaling_type_w: TensorScalingType, scaling_type_dL_dY: TensorScalingType, dtype: torch.dtype, ): - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() torch._dynamo.reset() _test_compile_base( "eager", fullgraph, emulate, - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY, @@ -115,7 +101,6 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) -@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @@ -130,30 +115,16 @@ def test_eager_only( def test_aot_eager( fullgraph, emulate: bool, - linear_type: bool, scaling_type_x: TensorScalingType, scaling_type_w: TensorScalingType, scaling_type_dL_dY: TensorScalingType, dtype: torch.dtype, ): - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() torch._dynamo.reset() _test_compile_base( "aot_eager", fullgraph, emulate, - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY, @@ -163,7 +134,6 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) -@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @@ -178,30 +148,16 @@ def test_aot_eager( def test_inductor( fullgraph, emulate: bool, - linear_type: bool, scaling_type_x: TensorScalingType, scaling_type_w: TensorScalingType, scaling_type_dL_dY: TensorScalingType, dtype: torch.dtype, ): - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() torch._dynamo.reset() _test_compile_base( "inductor", fullgraph, emulate, - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY, @@ -301,7 +257,6 @@ def test_sync_amax_func(): ) float8_mod = swap_linear_with_float8_linear( module, - Float8Linear, scaling_type_x=TensorScalingType.DELAYED, scaling_type_w=TensorScalingType.DELAYED, scaling_type_dL_dY=TensorScalingType.DELAYED, @@ -337,7 +292,6 @@ def test_sync_amax_func_cuda_graph_success(): ).to("cuda") swap_linear_with_float8_linear( my_module, - Float8Linear, scaling_type_x=TensorScalingType.DELAYED, scaling_type_w=TensorScalingType.DELAYED, scaling_type_dL_dY=TensorScalingType.DELAYED, diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 24a5e58..6506ee7 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -14,10 +14,7 @@ import torch.nn as nn import torch.nn.functional as F -from float8_experimental.float8_dynamic_linear import ( - Float8DynamicLinear, - NoopFwToFloat8E5M2Bw, -) +from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig @@ -171,37 +168,28 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): def _test_fp8_mlp_tensor_parallelism_base( - mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False + mesh: DeviceMesh, size=16, compile: bool = False ): device = mesh.device_type - # TODO(future): delete Float8DynamicLinear from this test once all the - # code is unified - float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear - extra_kwargs = {} - if use_float8_linear: - # For now, just use Float8Linear with dynamic scaling, which is the - # same behavior as Float8Linear. - # TODO(future): add support for float8 all-gather with delayed scaling - # for activations and gradients. - extra_kwargs = { - "scaling_type_x": TensorScalingType.DYNAMIC, - "scaling_type_w": TensorScalingType.DYNAMIC, - "scaling_type_dL_dY": TensorScalingType.DYNAMIC, - } + # For now, just use Float8Linear with dynamic scaling, which is the + # same behavior as Float8Linear. + # TODO(future): add support for float8 all-gather with delayed scaling + # for activations and gradients. + extra_kwargs = { + "scaling_type_x": TensorScalingType.DYNAMIC, + "scaling_type_w": TensorScalingType.DYNAMIC, + "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + } toy_model = ToyModel().to(device) toy_model_fp8 = swap_linear_with_float8_linear( - toy_model, float8_cls, emulate=True, **extra_kwargs + toy_model, emulate=True, **extra_kwargs ) tp_model = copy.deepcopy(toy_model) - tp_model = swap_linear_with_float8_linear( - tp_model, float8_cls, emulate=True, **extra_kwargs - ) + tp_model = swap_linear_with_float8_linear(tp_model, emulate=True, **extra_kwargs) sp_model = copy.deepcopy(toy_model) - sp_model = swap_linear_with_float8_linear( - sp_model, float8_cls, emulate=True, **extra_kwargs - ) + sp_model = swap_linear_with_float8_linear(sp_model, emulate=True, **extra_kwargs) # vanilla TP tp_model = parallelize_module( @@ -232,9 +220,7 @@ def _test_fp8_mlp_tensor_parallelism_base( # PrepareFloat8ModuleInput with specific submodule fqn sp_model2 = copy.deepcopy(toy_model) - sp_model2 = swap_linear_with_float8_linear( - sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs - ) + sp_model2 = swap_linear_with_float8_linear(sp_model2, emulate=True, **extra_kwargs) sp_model2 = parallelize_module( sp_model2, @@ -287,27 +273,11 @@ def _test_fp8_mlp_tensor_parallelism_base( def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base( - mesh, size, compile=False, use_float8_linear=False - ) - - -def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base( - mesh, size, compile=False, use_float8_linear=True - ) + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base( - mesh, size, compile=True, use_float8_linear=False - ) - - -def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base( - mesh, size, compile=True, use_float8_linear=True - ) + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) if __name__ == "__main__": @@ -321,9 +291,7 @@ def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size test_dtensor_cast_to_fp8, test_dtensor_fp8_autograd, test_fp8_mlp_tensor_parallelism_eager, - test_fp8_mlp_tensor_parallelism_eager_float8_linear, test_fp8_mlp_tensor_parallelism_compile, - test_fp8_mlp_tensor_parallelism_compile_float8_linear, ] for test in tqdm(tests, desc="Running tests"): diff --git a/test/test_fsdp.py b/test/test_fsdp.py index 031b40d..79bba19 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -24,7 +24,6 @@ from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( linear_requires_sync, - LinearType, swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) @@ -84,7 +83,6 @@ def fsdp_main(rank, world_size, args): # with weights. swap_linear_with_float8_linear( model_fp8, - Float8Linear, emulate=False, scaling_type_w=scaling_type_w, ) @@ -133,7 +131,6 @@ def forward_backward(model, optim, is_fp8, i): y_local = model(ref_input_local[i]) y_local.backward(ref_grad_local[i]) if is_fp8 and linear_requires_sync( - LinearType.DELAYED, TensorScalingType.DYNAMIC, scaling_type_w, TensorScalingType.DYNAMIC, diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 4f20fd5..c20e8cc 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -7,7 +7,6 @@ import torch.distributed as dist import torch.nn as nn from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history def check_parity_no_mp( @@ -17,7 +16,6 @@ def check_parity_no_mp( fsdp_model: nn.Module, fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, - module_cls: Type, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -29,8 +27,7 @@ def check_parity_no_mp( for param in model.parameters(): dist.all_reduce(param.grad) param.grad.div_(dist.get_world_size()) - if module_cls is Float8Linear: - sync_float8_amax_and_scale_history(model) + # TODO(future): add amax syncing once delayed scaling is supported optim.step() test_cls.assertEqual(losses[0], losses[1]) @@ -43,7 +40,6 @@ def check_parity_bf16_mp( fsdp_model: nn.Module, fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, - module_cls: Type, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -62,8 +58,7 @@ def check_parity_bf16_mp( param_bf16.grad.div_(dist.get_world_size()) param_fp32.grad = param_bf16.grad.float() param_bf16.grad = None - if module_cls is Float8Linear: - sync_float8_amax_and_scale_history(model) + # TODO(future): add amax syncing once delayed scaling is supported optim.step() for param_fp32, param_bf16 in zip( ref_model.parameters(), ref_model_bf16.parameters() diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 5e4dc8f..57123cd 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -8,10 +8,7 @@ import torch._dynamo.testing import torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_dynamic_linear import ( - Float8DynamicLinear, - WeightWithDynamicFloat8CastTensor, -) +from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear from test_fsdp2_common import ( @@ -76,16 +73,11 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32): dist.broadcast(global_inp, src=0) return global_inp.view(self.world_size, -1)[self.rank].view(16, 16) - def swap_linear_with_dynamic( - self, module: nn.Module, use_float8_linear=False, **kwargs: Any - ) -> nn.Module: - if use_float8_linear: - kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC - kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC - kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC - return swap_linear_with_float8_linear(module, Float8Linear, **kwargs) - else: - return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs) + def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module: + kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC + kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC + kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC + return swap_linear_with_float8_linear(module, **kwargs) class TestFloat8MultiProcess(FSDPTest, TestFloat8Common): @@ -95,16 +87,10 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_transformer_parity_dynamic(self): - for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product( - [False, True], [False, True] - ): - self._test_transformer_parity_dynamic( - enable_fsdp_fp8_all_gather, use_float8_linear - ) + for enable_fsdp_fp8_all_gather in [False, True]: + self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather) - def _test_transformer_parity_dynamic( - self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool - ): + def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to @@ -112,9 +98,9 @@ def _test_transformer_parity_dynamic( weight_tying = not enable_fsdp_fp8_all_gather module = self.init_transformer(weight_tying=weight_tying) ref_module = copy.deepcopy(module) - ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda() + ref_module = self.swap_linear_with_dynamic(ref_module).cuda() with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - module = self.swap_linear_with_dynamic(module, use_float8_linear) + module = self.swap_linear_with_dynamic(module) for submodule in module.modules(): if isinstance(submodule, TransformerBlock): fully_shard(submodule) @@ -124,24 +110,15 @@ def _test_transformer_parity_dynamic( local_inp = torch.randint( 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" ) - # TODO(future): change Float8DynamicLinear to module_cls below, and - # ensure there is no amax syncing for all-dynamic - check_parity_no_mp( - self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear - ) + check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp) @skip_if_lt_x_gpu(2) def test_transformer_memory(self): """Tests peak active memory in the forward and backward passes.""" - # for enable_fsdp_fp8_all_gather in [False, True]: - for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product( - [False, True], [False, True] - ): - self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear) - - def _test_transformer_memory( - self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool - ): + for enable_fsdp_fp8_all_gather in [False, True]: + self._test_transformer_memory(enable_fsdp_fp8_all_gather) + + def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool): torch.manual_seed(42) # Pre-run a linear forward (gemm and bias) and backward (gemm) to # allocate the cuBLAS workspaces before measuring the memory usage @@ -164,9 +141,7 @@ def _test_transformer_memory( # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility # requirement to use a smaller activation size with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - model = self.swap_linear_with_dynamic( - model, emulate=True, use_float8_linear=use_float8_linear - ) + model = self.swap_linear_with_dynamic(model, emulate=True) model_unsharded_numel = sum(p.numel() for p in model.parameters()) model_sharded_numel = (model_unsharded_numel + 1) // 2 block_lin_weight_numel = 0 @@ -267,21 +242,19 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common): def world_size(self) -> int: return 2 - def _test_weight_subclass_dynamic(self, use_float8_linear): - float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear - extra_kwargs = {} - if use_float8_linear: - extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC - extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC - extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC - pass + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_weight_subclass_dynamic(self): + extra_kwargs = { + "scaling_type_x": TensorScalingType.DYNAMIC, + "scaling_type_w": TensorScalingType.DYNAMIC, + "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + } tensor_cls = WeightWithDynamicFloat8CastTensor # Check for a single FSDP paramter group module_fp32 = self.init_single_module() with set_enable_fsdp_fp8_all_gather(True): module = swap_linear_with_float8_linear( module_fp32, - float8_cls, emulate=True, **extra_kwargs, ) @@ -297,7 +270,6 @@ def _test_weight_subclass_dynamic(self, use_float8_linear): with set_enable_fsdp_fp8_all_gather(True): module = swap_linear_with_float8_linear( module, - float8_cls, emulate=True, **extra_kwargs, ) @@ -313,14 +285,7 @@ def _test_weight_subclass_dynamic(self, use_float8_linear): self.assertIsInstance(param.to_local(), tensor_cls) @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_weight_subclass_float8_dynamic_linear(self): - self._test_weight_subclass_dynamic(use_float8_linear=False) - - @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_weight_subclass_float8_linear(self): - self._test_weight_subclass_dynamic(use_float8_linear=True) - - def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear): + def test_fp8_fp32_all_gather_dynamic_comm_size(self): """ Tests that fp8 all-gather with dynamic scaling communicates the expected number of bytes. @@ -354,7 +319,7 @@ def get_expected_all_gather_size(module: nn.Module): module_fp32 = self.init_single_module() ref_module = copy.deepcopy(module_fp32) with set_enable_fsdp_fp8_all_gather(True): - module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear) + module = self.swap_linear_with_dynamic(module_fp32) fully_shard(module) local_inp = self.get_local_inp() expected_all_gather_size = get_expected_all_gather_size(ref_module) @@ -398,30 +363,18 @@ def get_expected_all_gather_size(module: nn.Module): [s for s in expected_all_gather_sizes for _ in range(self.world_size)], ) - @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self): - self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False) - - @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_fp8_fp32_all_gather_float8_linear_comm_size(self): - self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True) - @unittest.skipIf(not TEST_CUDA, "no cuda") def test_fp32_fp8_single_module_parity(self): """ Tests numeric parity for fp32 parameters with fp8 computation with a single module/FSDP communication group. """ - for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product( - [False, True], [False, True] - ): + for enable_fsdp_fp8_all_gather in [False, True]: module_fp32 = self.init_single_module() - ref_module = self.swap_linear_with_dynamic( - copy.deepcopy(module_fp32), use_float8_linear - ) + ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32)) ref_module = ref_module.cuda() with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear) + module = self.swap_linear_with_dynamic(module_fp32) fully_shard(module) ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) @@ -433,7 +386,6 @@ def test_fp32_fp8_single_module_parity(self): module, optim, local_inp, - Float8DynamicLinear, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -442,16 +394,12 @@ def test_fp32_fp8_multi_module_parity(self): Tests numeric parity for fp32 parameters with fp8 computation with multiple modules/FSDP communication groups. """ - for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product( - [False, True], [False, True] - ): + for enable_fsdp_fp8_all_gather in [False, True]: module = self.init_multi_module() ref_module = copy.deepcopy(module) - ref_module = self.swap_linear_with_dynamic( - ref_module, use_float8_linear - ).cuda() + ref_module = self.swap_linear_with_dynamic(ref_module).cuda() with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - module = self.swap_linear_with_dynamic(module, use_float8_linear) + module = self.swap_linear_with_dynamic(module) for submodule in module: fully_shard(submodule) fully_shard(module) @@ -465,7 +413,6 @@ def test_fp32_fp8_multi_module_parity(self): module, optim, local_inp, - Float8DynamicLinear, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -482,13 +429,10 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16) ref_module_bf16 = swap_linear_with_float8_linear( ref_module_bf16, - Float8DynamicLinear, emulate=True, ) ref_module_fp32 = copy.deepcopy(module).cuda() - module = swap_linear_with_float8_linear( - module, Float8DynamicLinear, emulate=True - ) + module = swap_linear_with_float8_linear(module, emulate=True) mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) for mlp in module: fully_shard(mlp, mp_policy=mp_policy) @@ -501,7 +445,6 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): module, torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True), self.get_local_inp(torch.bfloat16), - Float8DynamicLinear, ) diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index cc44934..715db29 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -51,7 +51,6 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): ) swap_linear_with_float8_linear( m, - Float8Linear, emulate=emulate, scaling_type_x=TensorScalingType.DELAYED, scaling_type_w=TensorScalingType.DELAYED, diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py index c92b752..1dd09d9 100644 --- a/test/test_inference_flows.py +++ b/test/test_inference_flows.py @@ -193,7 +193,6 @@ def test_fp8_save_and_load(self, dtype: torch.dtype): fp8_mlp.reset_parameters() swap_linear_with_float8_linear( fp8_mlp, - Float8Linear, scaling_type_x=TensorScalingType.DYNAMIC, scaling_type_w=TensorScalingType.DYNAMIC, scaling_type_dL_dY=TensorScalingType.DYNAMIC, @@ -218,7 +217,6 @@ def test_fp8_save_and_load(self, dtype: torch.dtype): new_fp8_mlp = FeedForward().to(dtype=dtype) swap_linear_with_float8_linear( new_fp8_mlp, - Float8Linear, scaling_type_x=TensorScalingType.DYNAMIC, scaling_type_w=TensorScalingType.DYNAMIC, scaling_type_dL_dY=TensorScalingType.DYNAMIC, diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py index 1d571de..401d0fd 100644 --- a/test/test_numerics_integration.py +++ b/test/test_numerics_integration.py @@ -14,11 +14,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import ( linear_requires_sync, - LinearType, swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) @@ -85,32 +83,14 @@ class TestFloat8NumericsIntegrationTest: @pytest.mark.parametrize( "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) - @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear]) @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( self, - linear_cls, scaling_type_x: TensorScalingType, scaling_type_w: TensorScalingType, scaling_type_dL_dY: TensorScalingType, ): - linear_type = ( - LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC - ) - if linear_type is LinearType.DYNAMIC: - # Only test one combination of scaling types, as they are a no-op - # for Float8DynamicLinear. It would be cleaner to split into two - # tests, but IMO not worth it since Float8DynamicLinear will be - # deleted soon - is_all_dynamic = ( - scaling_type_x is TensorScalingType.DYNAMIC - and scaling_type_w is TensorScalingType.DYNAMIC - and scaling_type_dL_dY is TensorScalingType.DYNAMIC - ) - if not is_all_dynamic: - pytest.skip() - # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 @@ -130,7 +110,6 @@ def test_encoder_fw_bw( model_fp8 = copy.deepcopy(model_ref) swap_linear_with_float8_linear( model_fp8, - linear_cls, emulate=False, scaling_type_x=scaling_type_x, scaling_type_w=scaling_type_w, @@ -156,17 +135,13 @@ def test_encoder_fw_bw( model_ref_out = model_ref(data2) model_ref_out.sum().backward() - if linear_requires_sync( - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY - ): + if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): sync_float8_amax_and_scale_history(model_fp8) model_fp8(data1).sum().backward() # zero out grads without stepping, since we just want to compare grads # of the second datum optim_fp8.zero_grad() - if linear_requires_sync( - linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY - ): + if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): sync_float8_amax_and_scale_history(model_fp8) model_fp8_out = model_fp8(data2) model_fp8_out.sum().backward()