From 1ae193ff7081e7631032b45469b9d116992ba37d Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 18 Jan 2024 15:25:17 -0800
Subject: [PATCH] add test for FSDP mixed precision

Summary:

The only restriction in this test is that buffers must be in float32
precision, so that the float8 scales stay in high precision.

Test Plan:

```
./test/test_fsdp_compile.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/test_fsdp_compile.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index 9dc77434..81e8ee20 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -26,6 +26,7 @@
 from torch.distributed.fsdp import (
     FullStateDictConfig,
     FullyShardedDataParallel as FSDP,
+    MixedPrecision,
     StateDictType,
 )
 
@@ -35,6 +36,16 @@
 lr = 0.01
 N_ITER = 1
 
+bfSixteen = MixedPrecision(
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.bfloat16,
+    # Currently buffer_dtype must be float32 because we
+    # want the float8 scales to stay in high precision.
+    # TODO(later): figure out importance of allowing more granular
+    # buffer precision control in FSDP.
+    buffer_dtype=torch.float32,
+)
+
 
 def setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
@@ -86,7 +97,7 @@ def fsdp_main(rank, world_size, args):
     )
 
     # To compile FSDP, we need use_orig_params to True
-    model = FSDP(model, use_orig_params=True)
+    model = FSDP(model, use_orig_params=True, mixed_precision=bfSixteen)
 
     optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
     input_local = torch.randn(B, M, K, N, device="cuda")
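
For readers who want to apply the same policy outside of this test harness, a minimal sketch is below. `wrap_with_bf16_policy` is a hypothetical helper (not part of this patch), and it assumes a process group has already been initialized (for example via torchrun) and that a CUDA device is available, rather than the `mp.spawn` setup used by the test.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)


def wrap_with_bf16_policy(module: nn.Module) -> FSDP:
    # Hypothetical helper mirroring the policy added in this patch:
    # params and reduce-scatter comms run in bfloat16, while buffers
    # (where the float8 scales live) stay in float32.
    policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.float32,
    )
    # use_orig_params=True is required to compile FSDP, as noted in the test.
    return FSDP(module, use_orig_params=True, mixed_precision=policy)
```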