diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index 9dc77434..81e8ee20 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -26,6 +26,7 @@ from torch.distributed.fsdp import ( FullStateDictConfig, FullyShardedDataParallel as FSDP, + MixedPrecision, StateDictType, ) @@ -35,6 +36,16 @@ lr = 0.01 N_ITER = 1 +bfSixteen = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + # Currently buffer_dtype must be float32 because we + # want the float8 scales to stay in high precision. + # TODO(later): figure out importance of allowing more granular + # buffer precision control in FSDP. + buffer_dtype=torch.float32, +) + def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" @@ -86,7 +97,7 @@ def fsdp_main(rank, world_size, args): ) # To compile FSDP, we need use_orig_params to True - model = FSDP(model, use_orig_params=True) + model = FSDP(model, use_orig_params=True, mixed_precision=bfSixteen) optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) input_local = torch.randn(B, M, K, N, device="cuda")