[FSDP2] precompute scale after optimizer.step for dynamic scaling #266
Changes from 15 commits
@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import amax_to_scale, tensor_to_scale
 from torch._prims_common import suggest_memory_format

@@ -151,6 +151,7 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
     def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
         self._tensor = tensor
         self._mm_config = mm_config
+        self._pre_computed_amax = None

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):

@@ -190,9 +191,20 @@ def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3fn(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._pre_computed_amax is not None:
+            scale = amax_to_scale(
+                self._pre_computed_amax,
+                torch.float8_e4m3fn,
+                self._pre_computed_amax.dtype,
+                clamp_amax=False,
+            )
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3fn(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)

     def fsdp_post_all_gather(
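In the precomputed path above, `fsdp_pre_all_gather` only converts the stored amax into a scale instead of recomputing the amax from the weight on every all-gather. A minimal sketch of that arithmetic (standalone, using plain tensors rather than the library's `Float8Tensor`; `EPS` here is an assumed stand-in for `float8_experimental.float8_utils.EPS`):

import torch

EPS = 1e-12  # assumed stand-in for float8_experimental.float8_utils.EPS

def scale_from_precomputed_amax(amax: torch.Tensor) -> torch.Tensor:
    # The precomputed amax was already clamped to EPS right after optimizer.step,
    # so no clamp is needed here (clamp_amax=False in the diff above).
    return torch.finfo(torch.float8_e4m3fn).max / amax

def scale_on_the_fly(weight: torch.Tensor) -> torch.Tensor:
    # Fallback path: derive the amax from the weight right before casting.
    amax = weight.abs().max().to(torch.float32)
    return torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=EPS)

w = torch.randn(16, 16)
precomputed = torch.clamp(w.abs().max().to(torch.float32), min=EPS)
assert torch.allclose(scale_from_precomputed_amax(precomputed), scale_on_the_fly(w))

Both paths produce the same scale; the win is that the precomputed amax can be produced for all weights in one batched call rather than one abs().max() per weight per all-gather.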
@@ -5,16 +5,22 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import logging
+
+import math
+import warnings
 from enum import auto, Enum
 from typing import Callable, List, Optional, Type

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_dynamic_linear import (
+    Float8DynamicLinear,
+    WeightWithDynamicFloat8CastTensor,
+)
 from float8_experimental.float8_linear import Float8Linear

-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import amax_history_to_scale_stack, EPS
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 log = logging.getLogger(__name__)
@@ -322,3 +328,34 @@ def inner_func():
     for child in fp8_layers:
         # Set a flag to signal amaxes/scales are ready
         child.amax_and_scale_synced = True
+
+
+def precompute_float8_amax(module: nn.Module) -> None:
+    from torch.distributed._tensor import DTensor
+
+    if any(isinstance(m, Float8Linear) for m in module.modules()):
+        raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear")
+    float8_linears: List[Float8DynamicLinear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8DynamicLinear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    def compute_amaxes(weights: List[DTensor]):
+        max_weights = torch._foreach_norm(weights, ord=math.inf)
+        amax_tensor = torch.vstack(max_weights)
+        amax_tensor = torch.clamp(amax_tensor, EPS)  # R
+        amaxes = torch.split(amax_tensor, 1)  # R
+        return amaxes
+
+    if weights:
+        amaxes = compute_amaxes(weights)
+        for amax, float8_linear in zip(amaxes, float8_linears):
+            float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor
+    else:
+        warnings.warn(
+            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
+        )
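A side note on `compute_amaxes`: `torch._foreach_norm(weights, ord=math.inf)` is a batched way of computing max(abs(w)) for every weight in a single call (one reviewer asks for exactly this comment below). An illustrative check, not part of the PR:

import math
import torch

weights = [torch.randn(8, 8), torch.randn(4, 16)]
inf_norms = torch._foreach_norm(weights, ord=math.inf)  # one batched call for all weights
for norm, w in zip(inf_norms, weights):
    # the inf-norm of a tensor is its largest absolute value
    assert torch.allclose(norm, w.abs().max())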
Review thread on `def precompute_float8_amax(module: nn.Module) -> None`:
- Can we put this in [...]? I think the function name should include that this is intended for FSDP2 with float8 all-gather.
- Moving to [...].
- Indicating fsdp by renaming to [...].
- @weifengpy do you plan / want to use compile on this, and are there any gaps around here that you think would be good to prioritize on the compile side? This is mostly just me remembering @awgu mention a while ago that he thought compile added noticeable runtime overhead, and I can't remember if it was for this specific case. If it is, and we think compiling this code would be useful, I can prioritize looking into the runtime overhead.
- Hi @bdhirsh, I plan to polish and land this PR without compile next week to conclude H1, most importantly adding [...]. Reducing runtime overhead from [...]
- If you have a mini repro showing bad runtime overheads with compile, that would be great!
- Hi @bdhirsh, I have created a repro: pytorch/pytorch#129457. I highlighted the extra cpu overhead and gpu time for torch.compile(mode="reduce-overhead").

Review thread on `max_weights = torch._foreach_norm(weights, ord=math.inf)`:
- Maybe add a comment that this is equivalent to max(abs(w))?
- Done.

Review thread on `amax_tensor = torch.clamp(amax_tensor, EPS)`:
- So you are relying on [...]. If this fragments the code, could we just all-reduce the amax tensor and then leave the clamp to [...]?
- Thanks for the suggestions. I can collect feedback from float8 folks if they have a preference.
- Can we just comment with what is going on? I think it's fine as long as the code is easy to understand and there is no magic.
- Agreed.
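For context on how the new helper is meant to be called, here is an illustrative sketch rather than code from the PR: `model`, `optimizer`, and `dataloader` are placeholder names, and the model is assumed to already be converted to `Float8DynamicLinear` and sharded with FSDP2 float8 all-gather.

# precompute_float8_amax is the helper added in the diff above.
for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # The weights just changed, so precompute their amaxes in one batched call;
    # the next forward's fsdp_pre_all_gather reuses them instead of recomputing
    # abs().max() per weight.
    precompute_float8_amax(model)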
@@ -27,17 +27,24 @@

 @torch.no_grad()
 def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+    amax: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    clamp_amax: bool = True,
 ):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
+        clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step
     """
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype in FP8_TYPES:
-        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        if clamp_amax:
+            res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        else:
+            res = torch.finfo(float8_dtype).max / amax
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

Review thread on the new `clamp_amax` docstring entry:
- This is a bit confusing. How about precomputing the scale instead so we don't have to have gotchas like this?
- Good suggestion! I changed the API to precompute the scale, and it shows another 9% speed up in the unit test vs precomputing the amax.

Review comment on the `if clamp_amax:` branch:
- nit: I think if you have this on a separate line it makes the logic a lil easier to follow.
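To make the `clamp_amax` gotcha concrete, here is a small standalone illustration of the two code paths (plain tensors, with an assumed `EPS` mirroring the library's constant); the results only differ when the amax has not already been clamped:

import torch

EPS = 1e-12  # assumed stand-in for float8_experimental.float8_utils.EPS
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

amax = torch.tensor(0.0)                        # e.g. an all-zero weight
clamped = FP8_MAX / torch.clamp(amax, min=EPS)  # clamp_amax=True path: large but finite scale
unclamped = FP8_MAX / amax                      # clamp_amax=False path: inf
print(clamped, unclamped)

# The FSDP precompute path is safe with clamp_amax=False only because
# precompute_float8_amax already applied torch.clamp(amax_tensor, EPS).
pre_amax = torch.clamp(amax, min=EPS)
assert torch.isfinite(FP8_MAX / pre_amax)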
Review thread on the new `_pre_computed_amax` attribute:
- Does this need to be added to `__tensor_flatten__`? Can we add some comments on intended usage of this?
- +1 on adding to flatten/unflatten and comments / intended usage.
- Done.
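For readers unfamiliar with the traceable-subclass protocol the reviewers reference: `__tensor_flatten__` must list every inner tensor attribute so the subclass can be reconstructed by FSDP/compile. A hedged sketch of what including the new attribute could look like; this is a fragment of `WeightWithDynamicFloat8CastTensor`, not the PR's actual follow-up, and it assumes the constructor is extended with an optional precomputed amax argument:

# Fragment only: methods as they might appear on WeightWithDynamicFloat8CastTensor.
def __tensor_flatten__(self):
    if self._pre_computed_amax is None:
        return ["_tensor"], self._mm_config
    # _pre_computed_amax is a real tensor, so it must be listed as an inner tensor
    return ["_tensor", "_pre_computed_amax"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    mm_config = flatten_spec
    return WeightWithDynamicFloat8CastTensor(
        inner_tensors["_tensor"],
        mm_config,
        inner_tensors.get("_pre_computed_amax"),  # assumes an extended constructor
    )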