Skip to content

Commit

Permalink
change mesh shape to fit the model in memory
Browse files: browse the repository at this point in the history
  • Loading branch information
sychen52 committed Jan 7, 2025
1 parent 0cc22e0 commit f6ba43b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -341,7 +341,7 @@ def get_trainer_kwargs(
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
- mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
+ mesh_shape=mesh_shape_from_axes(data=1, fsdp=512)
),
RematSpecModifier.default_config().set(
remat_policies={
Expand All @@ -359,7 +359,7 @@ def get_trainer_kwargs(
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
- mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
+ mesh_shape=mesh_shape_from_axes(data=1, fsdp=512, model=2)
),
RematSpecModifier.default_config().set(
remat_policies={
Expand All @@ -371,7 +371,7 @@ def get_trainer_kwargs(
],
),
),
- ("tpu-v5p-.*", mesh_shape_from_axes(data=-1, fsdp=8)),
+ ("tpu-v5p-.*", mesh_shape_from_axes(data=1, fsdp=512, model=2)),
(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
Expand Down

0 comments on commit f6ba43b

Please sign in to comment.