diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 04cd797..876ae93 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     torch.ops.aten.as_strided.default,
     torch.ops.aten._to_copy.default,
     torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
 }
 
 
@@ -188,12 +190,22 @@ def fsdp_post_all_gather(
         *,
         out: Optional[torch.Tensor] = None,
     ):
+        from torch.distributed._tensor import DTensor
+
        (data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
+            if isinstance(out, Float8Tensor):
+                out._scale = scale
+            elif isinstance(out, DTensor) and isinstance(
+                out._local_tensor, Float8Tensor
+            ):
+                out._local_tensor._scale = scale
+            else:
+                raise RuntimeError(
+                    f"out must be a Float8Tensor or DTensor with Float8Tensor local tensor, but got {type(out)}"
+                )
+            return out
        return Float8Tensor(
            data,
            scale,
diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py
index 1cbec77..b846fae 100644
--- a/test/test_fsdp2/test_fsdp2.py
+++ b/test/test_fsdp2/test_fsdp2.py
@@ -17,7 +17,13 @@
     set_enable_fsdp_fp8_all_gather,
 )
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import (
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -516,5 +522,50 @@ def test_delayed_scaling_inplace_update(self):
         self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item())
 
 
+class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 4)
+
+    def init_global_mesh(self) -> DeviceMesh:
+        dp_size = 2 if self.world_size > 2 else 1
+        return init_device_mesh(
+            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+        )
+
+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp(
+        self,
+    ):
+        enable_fsdp_fp8_all_gather = True
+        scaling_type_w = TensorScalingType.DYNAMIC
+        global_mesh = self.init_global_mesh()
+        _, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        module = self.init_transformer(weight_tying=False).cuda()
+        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+            swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
+
+        # "attention.wq": Float8ColwiseParallel
+        colwise_param = distribute_tensor(
+            module.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
+        )
+        self.assertTrue(
+            isinstance(colwise_param, DTensor)
+            and isinstance(
+                colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+        # "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
+        rowwise_param = distribute_tensor(
+            module.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
+        )
+        self.assertTrue(
+            isinstance(rowwise_param, DTensor)
+            and isinstance(
+                rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+
+
 if __name__ == "__main__":
     run_tests()