disable reshard after forward (#56)
Co-authored-by: Srini Iyer <[email protected]>
sriniiyer and Srini Iyer authored Feb 13, 2025
1 parent 48e4ad0 commit 9d907fe
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions bytelatent/transformer.py
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))

         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))

     return group_plan
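For context: with FSDP, reshard_after_forward=True frees each wrapped submodule's gathered parameters as soon as its forward pass completes, which saves memory but forces an extra all-gather during the backward pass; setting it to False, as this commit does for every entry in the grouping plan, keeps the unsharded parameters resident between forward and backward, trading memory for less communication. Below is a minimal sketch of how a (name, flag) plan like this could be consumed with PyTorch's composable fully_shard API; apply_grouping_plan is a hypothetical helper for illustration, not the repository's actual wiring, and assumes torch.distributed is already initialized with a device mesh.

    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard  # PyTorch >= 2.4

    def apply_grouping_plan(model: nn.Module, plan: list[tuple[str, bool]]) -> None:
        # Hypothetical consumer: shard each submodule named in the plan, using
        # the plan's boolean as reshard_after_forward. False keeps the unsharded
        # weights resident after forward, skipping the backward all-gather at
        # the cost of extra memory.
        for path, reshard in plan:
            fully_shard(model.get_submodule(path), reshard_after_forward=reshard)
        # Wrap the root module last, as the composable API expects.
        fully_shard(model)

    # Usage (hypothetical): apply_grouping_plan(model, build_fsdp_grouping_plan(args))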

