precompute scale after optimizer.step for dynamic scaling (#266)
Summary:

Goal: improve float8 all-gather performance in FSDP2 by precomputing the scales for all float8 params with a single all-reduce.

Updated README for API usage: call `precompute_float8_dynamic_scale_for_fsdp` inside the training loop, after the optimizer step:

```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

Unit test: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

**FSDP pre-forward**: shortened from 3ms to 1.8ms by issuing 1 all-reduce instead of N small all-reduces.
(Screenshots: profiler traces of the FSDP pre-forward, before and after the change.)

**Pre-computing amax**: shortened from 5ms to 1.7ms by switching from `torch._foreach_abs` + `torch.max(a)` to `torch._foreach_norm(weights, ord=math.inf)`.
(Screenshots: profiler traces of the amax computation, before and after the change.)

Pull Request resolved: #266

Reviewed By: vkuzo

Differential Revision: D59562409

Pulled By: weifengpy

fbshipit-source-id: 683c4719e20f6b30f39ca9109ee29e53981a2aec
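As a quick aside on the amax change: the following is a minimal sketch (plain local tensors, no DTensor or FSDP; the shapes and tensor count are made up for illustration) showing that the inf-norm path computes the same per-tensor amax as the old abs+max path, which is why the swap is a pure performance change.

```python
# Sketch: per-tensor amax via the old abs+max path vs. the new foreach inf-norm
# path used in this commit. Plain local tensors only; no DTensor/FSDP here.
import math
import torch

weights = [torch.randn(64, 64) for _ in range(4)]  # stand-in for float8 weights

# old path: one abs + one max reduction per weight
amax_old = [w.abs().max() for w in weights]

# new path: a single batched foreach call over all weights
amax_new = torch._foreach_norm(weights, ord=math.inf)

for old, new in zip(amax_old, amax_new):
    torch.testing.assert_close(old, new)  # inf-norm == max(abs(w))
```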
1 parent 73fd168 · commit 6cba2ae · Showing 5 changed files with 122 additions and 12 deletions.
@@ -0,0 +1,52 @@ (new file)

```python
import math
from typing import List

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_utils import EPS


@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    """
    Calculate scale dynamically for all float8 parameters.
    This should be run after the optimizer step. It performs a single all-reduce to compute the
    scales for all float8 weights.
    Example usage:
        model(input).sum().backward()
        optim.step()
        precompute_float8_dynamic_scale_for_fsdp(model)
    """
    from torch.distributed._tensor import DTensor

    if any(
        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
        for m in module.modules()
    ):
        raise NotImplementedError("Only supports dynamic scaling")
    float8_linears: List[Float8Linear] = [
        m
        for m in module.modules()
        if isinstance(m, Float8Linear)
        and isinstance(m.weight, DTensor)
        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
    ]
    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

    if not weights:
        return

    # inf-norm is equivalent to max(abs(w))
    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
    amax_tensor = torch.vstack(max_weights)  # Partial
    # clamp is dispatched through DTensor
    # it will issue a single all-reduce
    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
    if amax_tensor.dtype is torch.float16:
        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    scales = torch.split(scale_tensor, 1)  # Replicate
    for scale, float8_linear in zip(scales, float8_linears):
        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
```
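For orientation, here is a minimal end-to-end sketch of how this function is intended to be used with FSDP2. The `swap_linear_with_float8_linear` helper, the `fully_shard` import path, and the toy model/optimizer setup are assumptions for illustration and may not match this commit exactly; only `precompute_float8_dynamic_scale_for_fsdp` is defined in the file above.

```python
# Hedged usage sketch (not part of this diff): float8 dynamic scaling + FSDP2.
# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# and that swap_linear_with_float8_linear / fully_shard are available under
# these names; both are assumptions for illustration.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
swap_linear_with_float8_linear(model)  # nn.Linear -> Float8Linear (dynamic scaling)
fully_shard(model)                     # FSDP2: weights become sharded DTensors
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(10):
    batch = torch.randn(16, 1024, device="cuda")
    model(batch).sum().backward()
    optim.step()
    optim.zero_grad()
    # refresh all precomputed scales with a single all-reduce, instead of
    # N per-parameter all-reduces during the next FSDP pre-forward
    precompute_float8_dynamic_scale_for_fsdp(model)
```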