This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

add precompute scale in README
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Jul 10, 2024
1 parent fa2f08a commit ba085e5
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion README.md
@@ -51,7 +51,18 @@ model = FSDP(model, use_orig_params=True)
# optional: enable torch.compile for improved performance
m = torch.compile(m)

# train/finetune (not shown)
# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

    # specific to fsdp2 + float8 with dynamic scaling
    # this method is optional but is highly recommended for performance
    # it calculates scales for all parameters in a single all-reduce
    precompute_float8_scale_for_fsdp(model)

```
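
The comment above says the scales for all parameters are computed in a single all-reduce. A minimal pure-Python sketch of that idea follows. It is not the library's implementation: `local_amaxes`, `all_reduce_max`, and `precompute_scales` are hypothetical names, the all-reduce is simulated with plain lists, and the formula `scale = E4M3_MAX / amax` is an assumption about how a dynamic scale is derived from a maximum absolute value.

```python
# Illustrative sketch only; NOT the float8_experimental implementation.
# All function names here are hypothetical.
from typing import Dict, List

E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn
EPS = 1e-12       # avoids division by zero when a parameter is all zeros


def local_amaxes(shards: Dict[str, List[float]]) -> List[float]:
    """Max absolute value of each parameter shard held by one rank."""
    return [max((abs(v) for v in values), default=0.0)
            for values in shards.values()]


def all_reduce_max(per_rank: List[List[float]]) -> List[float]:
    """Stand-in for ONE MAX all-reduce over a packed buffer of amaxes."""
    return [max(column) for column in zip(*per_rank)]


def precompute_scales(per_rank_shards: List[Dict[str, List[float]]]) -> Dict[str, float]:
    """Pack every parameter's amax into one buffer, reduce once, derive scales."""
    names = list(per_rank_shards[0].keys())
    packed = [local_amaxes(shards) for shards in per_rank_shards]
    global_amax = all_reduce_max(packed)  # one collective for all parameters
    return {name: E4M3_MAX / max(amax, EPS)
            for name, amax in zip(names, global_amax)}


# Two simulated ranks, each holding a shard of parameters "w1" and "w2".
rank0 = {"w1": [0.5, -2.0], "w2": [0.1]}
rank1 = {"w1": [1.0, 4.0], "w2": [-0.25]}
scales = precompute_scales([rank0, rank1])
print(scales)  # {'w1': 112.0, 'w2': 1792.0}
```

Packing the per-parameter maxima into one buffer means the number of collectives per step stays at one regardless of how many float8 parameters the model has, which is presumably why the README recommends the call for performance.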

## float8 linear with delayed scaling
