
fix float8 all-gather in 2d
Summary:

In fsdp_post_all_gather, handle the 2D (FSDP2 + TP) case where the preallocated out is a DTensor whose local tensor is a Float8Tensor: write the all-gathered scale onto out._local_tensor instead of asserting that out is a plain Float8Tensor. Also add torch.ops.aten.split.Tensor and torch.ops.aten.clone.default to the op set in fsdp_utils.py.

Test Plan:

New Test2DFloat8MultiProcess.test_fsdp_tp in test/test_fsdp2/test_fsdp2.py: on a ("dp", "tp") device mesh, distribute_tensor over the "tp" mesh keeps WeightWithDynamicFloat8CastTensor as the DTensor local tensor (requires at least 4 GPUs).
Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Jul 24, 2024
1 parent f475c40 commit cc763ce
Showing 2 changed files with 67 additions and 4 deletions.
18 changes: 15 additions & 3 deletions float8_experimental/fsdp_utils.py
@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     torch.ops.aten.as_strided.default,
     torch.ops.aten._to_copy.default,
     torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
 }


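Note (illustrative, not library code): the op set above appears to control which ATen ops keep the float8 weight's wrapper subclass when FSDP2/DTensor split or clone it; an op missing from the set would come back as a plain tensor and drop the subclass. A self-contained sketch of that general pattern follows, with ToyWrapper and _OPS_TO_PRESERVE made up for illustration.

import torch
from torch.utils._pytree import tree_map_only

# Toy stand-in for the library's op set; these names are made up for illustration.
_OPS_TO_PRESERVE = {
    torch.ops.aten.clone.default,
    torch.ops.aten.split.Tensor,
}


class ToyWrapper(torch.Tensor):
    """Minimal wrapper-tensor subclass that survives split/clone."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap to plain tensors, run the op, then re-wrap only the ops we
        # chose to preserve; everything else falls back to plain tensors.
        plain_args = tree_map_only(ToyWrapper, lambda t: t._data, args)
        plain_kwargs = tree_map_only(ToyWrapper, lambda t: t._data, kwargs)
        out = func(*plain_args, **plain_kwargs)
        if func in _OPS_TO_PRESERVE:
            return tree_map_only(torch.Tensor, ToyWrapper, out)
        return out


w = ToyWrapper(torch.randn(8, 4))
print(type(torch.split(w, 4)[0]).__name__)  # ToyWrapper: aten.split.Tensor is preserved
print(type(w.clone()).__name__)             # ToyWrapper: aten.clone.default is preserved
print(type(w.mul(2)).__name__)              # Tensor: aten.mul.Tensor is not in the set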
@@ -188,12 +190,22 @@ def fsdp_post_all_gather(
         *,
         out: Optional[torch.Tensor] = None,
     ):
+        from torch.distributed._tensor import DTensor
+
         (data,) = all_gather_outputs
         (scale,) = metadata
         if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
+            if isinstance(out, Float8Tensor):
+                out._scale = scale
+            elif isinstance(out, DTensor) and isinstance(
+                out._local_tensor, Float8Tensor
+            ):
+                out._local_tensor._scale = scale
+            else:
+                raise RuntimeError(
+                    f"out must be a Float8Tensor or DTensor with Float8Tensor local tensor, but got {type(out)}"
+                )
+            return out
         return Float8Tensor(
             data,
             scale,
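For context (not part of this commit): the new DTensor branch in fsdp_post_all_gather is reached in 2D setups where tensor parallelism is applied before FSDP2, so the weights FSDP2 all-gathers are DTensors whose local tensors are float8 subclasses. A minimal wiring sketch follows; the float8_experimental import paths, the helper name apply_float8_2d, the 2x2 mesh, and the model layout (a transformer with layers[*].attention.wq/wo, mirroring the test below) are assumptions, and in practice the float8-aware TP plans (Float8ColwiseParallel / Float8RowwiseParallel, referenced in the test's comments) would replace the plain plans used here.

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Import paths below are assumptions; they are not shown in this diff.
from float8_experimental.config import set_enable_fsdp_fp8_all_gather
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def apply_float8_2d(model: nn.Module) -> nn.Module:
    # 2x2 mesh (4 GPUs assumed): "dp" feeds FSDP2's all-gather, "tp" holds the TP shards.
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    with set_enable_fsdp_fp8_all_gather(True):
        swap_linear_with_float8_linear(model, scaling_type_w=TensorScalingType.DYNAMIC)
    for block in model.layers:
        # Plain TP plans keep this sketch short; see the note above about the
        # float8-aware plans used in practice.
        parallelize_module(
            block,
            mesh["tp"],
            {"attention.wq": ColwiseParallel(), "attention.wo": RowwiseParallel()},
        )
        # Shard each block over "dp"; the float8 weights are now DTensors, which is
        # the case the fsdp_post_all_gather change above handles.
        fully_shard(block, mesh=mesh["dp"])
    fully_shard(model, mesh=mesh["dp"])
    return model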
53 changes: 52 additions & 1 deletion test/test_fsdp2/test_fsdp2.py
@@ -17,7 +17,13 @@
     set_enable_fsdp_fp8_all_gather,
 )
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import (
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
@@ -516,5 +522,50 @@ def test_delayed_scaling_inplace_update(self):
         self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item())


+class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common):
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 4)
+
+    def init_global_mesh(self) -> DeviceMesh:
+        dp_size = 2 if self.world_size > 2 else 1
+        return init_device_mesh(
+            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+        )
+
+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp(
+        self,
+    ):
+        enable_fsdp_fp8_all_gather = True
+        scaling_type_w = TensorScalingType.DYNAMIC
+        global_mesh = self.init_global_mesh()
+        _, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        module = self.init_transformer(weight_tying=False).cuda()
+        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+            swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
+
+        # "attention.wq": Float8ColwiseParallel
+        colwise_param = distribute_tensor(
+            module.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
+        )
+        self.assertTrue(
+            isinstance(colwise_param, DTensor)
+            and isinstance(
+                colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+        # "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
+        rowwise_param = distribute_tensor(
+            module.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
+        )
+        self.assertTrue(
+            isinstance(rowwise_param, DTensor)
+            and isinstance(
+                rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
+            )
+        )
+
+
 if __name__ == "__main__":
     run_tests()
