diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 786bc618e..16ab85b76 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -341,7 +341,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=1, fsdp=512) ), RematSpecModifier.default_config().set( remat_policies={ @@ -359,7 +359,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=1, fsdp=512, model=2) ), RematSpecModifier.default_config().set( remat_policies={ @@ -371,7 +371,7 @@ def get_trainer_kwargs( ], ), ), - ("tpu-v5p-.*", mesh_shape_from_axes(data=-1, fsdp=8)), + ("tpu-v5p-.*", mesh_shape_from_axes(data=1, fsdp=512, model=2)), ( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8),