[FSDP2] precompute scale after optimizer.step for dynamic scaling #266
Changes from 25 commits
@@ -9,10 +9,7 @@
 from typing import Any, Optional, Tuple

-import float8_experimental.config as config
-
 import torch
 import torch.nn as nn
 import torch.utils._pytree as pytree

 from float8_experimental.float8_tensor import (

@@ -22,7 +19,12 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    amax_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_scale,
+)
 from torch._prims_common import suggest_memory_format
@@ -85,7 +87,12 @@ def cast_to_float8_e5m2_dynamic_bw(

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        amax: Optional[torch.Tensor] = None,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -99,9 +106,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )

-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        amax: Optional[torch.Tensor] = None,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_amax_for_fsdp` calculates amax
+        # for all float8 parameters after optimizer step
+        self._precomputed_amax = amax

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -130,20 +146,38 @@ def unwrap(t):
         )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_amax:
+            return ["_tensor", "_precomputed_amax"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            inner_tensors.get("_precomputed_amax", None),
+        )

     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._precomputed_amax is not None:
+            scale = amax_to_scale(
+                self._precomputed_amax,
+                torch.float8_e4m3fn,
+                self._precomputed_amax.dtype,
+                clamp_amax=False,
+            )
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)

     def fsdp_post_all_gather(
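
To make the two branches of `fsdp_pre_all_gather` concrete, here is a toy, eager-mode sketch of the two casting paths. The helper names and the epsilon value are illustrative assumptions; the real implementation goes through `amax_to_scale`, `Float8Tensor.to_float8`, and `cast_to_float8_e4m3_dynamic`, which also handle saturation and autograd.

```python
import torch

def cast_with_precomputed_amax(w: torch.Tensor, amax: torch.Tensor):
    # amax was already clamped during the post-optimizer precompute,
    # so the scale is just a division (the clamp_amax=False path above)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (w * scale).to(torch.float8_e4m3fn), scale

def cast_dynamic(w: torch.Tensor):
    # fallback path: amax is computed on the fly at all-gather time, which costs
    # an extra reduction per parameter when the weight is sharded
    amax = torch.clamp(w.abs().max().float(), min=1e-12)  # 1e-12 stands in for EPS
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (w * scale).to(torch.float8_e4m3fn), scale

w = torch.randn(16, 16)
data_pre, scale_pre = cast_with_precomputed_amax(w, w.abs().max().float())
data_dyn, scale_dyn = cast_dynamic(w)
assert torch.allclose(scale_pre, scale_dyn)  # same scale, but no per-parameter reduction at all-gather time
```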
@@ -34,17 +34,22 @@

 @torch.no_grad()
 def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+    amax: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    clamp_amax: bool = True,
 ):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
+        clamp_amax: default is True. False for FSDP fp8 all-gather, since FSDP applies `torch.clamp` during the pre-compute after optimizer.step
     """
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype in FP8_TYPES:
-        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
+        res = torch.finfo(float8_dtype).max / amax
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

Review thread on the `clamp_amax` docstring: this is a bit confusing. How about precomputing the scale instead, so we don't have to have gotchas like this?
Reply: good suggestion! I changed the API to precompute scale, and it shows another 9% speed-up in the unit test vs. precomputing amax.
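
A minimal standalone sketch of the new `clamp_amax` behavior (the epsilon value and function name here are assumptions, not the library's exact code): when the caller has already clamped the amax, as the FSDP precompute path does, the clamp inside `amax_to_scale` can be skipped without changing the result.

```python
import torch

EPS = 1e-12  # stand-in for the library's EPS constant

def amax_to_scale_sketch(amax, float8_dtype=torch.float8_e4m3fn, clamp_amax=True):
    if clamp_amax:
        # guards against division by zero for all-zero tensors
        amax = torch.clamp(amax, min=EPS)
    return torch.finfo(float8_dtype).max / amax

amax = torch.tensor([0.0, 3.5])
print(amax_to_scale_sketch(amax))                                          # clamped inside: finite scales
print(amax_to_scale_sketch(torch.clamp(amax, min=EPS), clamp_amax=False))  # caller clamps instead
```

Skipping the clamp here matters because the FSDP precompute already clamps a stacked tensor of amaxes for all fp8 weights at once, as shown in the new `fsdp_utils` file below.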
@@ -0,0 +1,58 @@
+import math
+import warnings
+from typing import List
+
+import torch
+import torch.nn as nn
+from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear_utils import linear_requires_sync
+from float8_experimental.float8_utils import EPS
+
+
+def precompute_float8_amax_for_fsdp(module: nn.Module) -> None:
+    """
+    Calculate amax for all float8 parameters after the optimizer step.
+    It performs a single all-reduce instead of one all-reduce per parameter.
+    Example usage:
+        model(input).sum().backward()
+        optim.step()
+        precompute_float8_amax_for_fsdp(model)
+    """
+    from torch.distributed._tensor import DTensor
+
+    if any(
+        isinstance(m, Float8Linear)
+        and linear_requires_sync(
+            m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY
+        )
+        for m in module.modules()
+    ):
+        raise NotImplementedError("Only supports dynamic scaling")
+    float8_linears: List[Float8Linear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8Linear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    def compute_amaxes(weights: List[DTensor]):
+        # inf-norm is equivalent to max(abs(w))
+        max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+        amax_tensor = torch.vstack(max_weights)  # Partial
+        # clamp is dispatched through DTensor
+        # it will issue a single all-reduce
+        amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+        amaxes = torch.split(amax_tensor, 1)  # Replicate
+        return amaxes
+
+    if weights:
+        amaxes = compute_amaxes(weights)
+        for amax, float8_linear in zip(amaxes, float8_linears):
+            float8_linear.weight._local_tensor._precomputed_amax = amax._local_tensor
+    else:
+        warnings.warn(
+            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
+        )

Review thread on the docstring: improve docstring with example API usage. / nice! can we add this to the README? / just added API usage to README.
Review thread on the module scan: is this expensive for real models? if yes, maybe we can offer option to precompute this?
Review thread on the warning: function name in the warning needs to be updated. I am okay with not including this warning, by the way. / got you. I am removing the warnings for simplicity.
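
The batching trick in `compute_amaxes` is what turns per-parameter reductions into a single collective: stacking every weight's inf-norm into one tensor means the clamp (dispatched through DTensor) triggers one all-reduce for all fp8 weights; the review reply above also notes the API later moved to precomputing the scale itself, i.e. folding the amax-to-scale division into this batched step. Below is a plain-tensor sketch of the same computation, with no DTensor and an assumed EPS value; it requires a PyTorch build where `torch._foreach_norm` accepts `ord=math.inf`, as the PR's own code does.

```python
import math
import torch

EPS = 1e-12  # stand-in for the library's EPS constant

# stand-ins for the local shards of two fp8 linear weights
weights = [torch.randn(16, 16), torch.randn(32, 16)]

max_weights = torch._foreach_norm(weights, ord=math.inf)  # one inf-norm (= max(abs(w))) per weight
amax_tensor = torch.vstack(max_weights)                   # stack into a single tensor
amax_tensor = torch.clamp(amax_tensor, EPS)               # under DTensor, this is the single all-reduce
amaxes = torch.split(amax_tensor, 1)                      # back to one amax per weight
print([a.item() for a in amaxes])
```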
@@ -1,12 +1,12 @@
 import contextlib
-from typing import List, Type
+from typing import List

 import float8_experimental.config as config

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp

@@ -16,6 +16,7 @@ def check_parity_no_mp(
     fsdp_model: nn.Module,
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
+    precompute: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []

@@ -29,6 +30,8 @@ def check_parity_no_mp(
                 param.grad.div_(dist.get_world_size())
             # TODO(future): add amax syncing once delayed scaling is supported
             optim.step()
+            if model is fsdp_model and precompute:
+                precompute_float8_amax_for_fsdp(model)
         test_cls.assertEqual(losses[0], losses[1])
Review comment on the import changes: pre-commit hook triggers linter, and cleaned unused import.
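
The test change above exercises the intended usage: call the precompute once per iteration, right after `optim.step()`, on the FSDP model only. A minimal sketch of that pattern follows; the model here is a plain placeholder, whereas in the real tests the linears are swapped to `Float8Linear` with fp8 all-gather enabled and wrapped with FSDP2, so on this placeholder the call simply hits the no-op warning branch.

```python
import torch
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))  # placeholder for an FSDP2 + Float8Linear model
optim = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    optim.zero_grad()
    model(torch.randn(8, 32)).sum().backward()
    optim.step()
    # refresh amax for every fp8 weight with one all-reduce,
    # so the next forward's fp8 all-gather can reuse it
    precompute_float8_amax_for_fsdp(model)
```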