This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[FSDP2] precompute scale after optimizer.step for dynamic scaling #266

Closed
weifengpy wants to merge 32 commits into main from fsdp2
Changes from 25 commits

Commits (32)
9d5595c
[FSDP2] set vocab_size=32 to avoid must be divisible by 16 error
weifengpy May 21, 2024
e7005c2
precast after optimizer.step and dump profiler traces
weifengpy May 21, 2024
e41d589
Merge branch 'main' into fsdp2
weifengpy May 21, 2024
e0bee10
precast and preamax unit test
weifengpy May 24, 2024
c0ba5a2
remove duplicate vocab
weifengpy May 24, 2024
8da238e
fused amax
weifengpy May 30, 2024
ffff5ed
Merge branch 'main' into fsdp2
weifengpy Jun 6, 2024
aefa21b
use FP8_TYPES and max
weifengpy Jun 6, 2024
d4a1db7
commit all changes before cleaning
weifengpy Jun 6, 2024
d36e79b
pre_compute and flatten / unflatten
weifengpy Jun 6, 2024
6f244a2
remove unused constant
weifengpy Jun 6, 2024
dc5eab0
torch.compile works
weifengpy Jun 6, 2024
546e979
eager ready
weifengpy Jun 6, 2024
229ede6
linter
weifengpy Jun 6, 2024
d5b3ff6
linter
weifengpy Jun 6, 2024
4f05e04
flatten tensor
weifengpy Jun 25, 2024
3de59af
commit all changes for review before rebasing
weifengpy Jul 8, 2024
ffcd197
rebase on unified float8linear
weifengpy Jul 9, 2024
6b18947
Merge branch 'pytorch-labs:main' into fsdp2
weifengpy Jul 9, 2024
562424c
move precompute to fsdp_utils.py
weifengpy Jul 9, 2024
75e0e45
simplify amax calc
weifengpy Jul 9, 2024
fe95f8b
explain _pre_computed_amax
weifengpy Jul 9, 2024
1cbaa13
fix linter
weifengpy Jul 9, 2024
fe2e0a0
document precompute_float8_amax_for_fsdp
weifengpy Jul 9, 2024
e4eaa2a
rename pre_compute to precompute
weifengpy Jul 9, 2024
e4245e4
Merge branch 'main' into fsdp2
weifengpy Jul 10, 2024
e12c973
remove clamp_amax=True/False
weifengpy Jul 10, 2024
9ef67fb
precompute scale
weifengpy Jul 10, 2024
fa2f08a
unit test for precomputing scales
weifengpy Jul 10, 2024
ba085e5
add precompute scale in README
weifengpy Jul 10, 2024
ac0afb0
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
8e56dfc
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
56 changes: 45 additions & 11 deletions float8_experimental/float8_dynamic_utils.py
@@ -9,10 +9,7 @@

from typing import Any, Optional, Tuple

import float8_experimental.config as config

import torch
import torch.nn as nn
Contributor Author @weifengpy, Jul 9, 2024:
pre-commit hook triggers linter, and cleaned unused import

import torch.utils._pytree as pytree

from float8_experimental.float8_tensor import (
@@ -22,7 +19,12 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
from float8_experimental.float8_utils import (
amax_to_scale,
e4m3_dtype,
e5m2_dtype,
tensor_to_scale,
)
from torch._prims_common import suggest_memory_format


@@ -85,7 +87,12 @@ def cast_to_float8_e5m2_dynamic_bw(

class WeightWithDynamicFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __new__(
cls,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
amax: Optional[torch.Tensor] = None,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
@@ -99,9 +106,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(
self,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
amax: Optional[torch.Tensor] = None,
):
self._tensor = tensor
self._mm_config = mm_config
# for dynamic scaling
# `precompute_float8_amax_for_fsdp` calculates amax
# for all float8 parameters after optimizer step
self._precomputed_amax = amax

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -130,20 +146,38 @@ def unwrap(t):
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config
if self._precomputed_amax:
return ["_tensor", "_precomputed_amax"], self._mm_config
Contributor:
does having Optional[torch.Tensor] as a subclass field work with torch.compile? Or do we not care about torch.compile in this code path?

Contributor Author @weifengpy, Jul 10, 2024:
torch.compile assumes every tensor from __tensor_flatten__ is not None. I added if-else to make torch.compile work. I verified it in pytorch/pytorch#129457

else:
return ["_tensor"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
mm_config = flatten_spec
return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
return WeightWithDynamicFloat8CastTensor(
inner_tensors["_tensor"],
mm_config,
getattr(inner_tensors, "_precomputed_amax", None),
)

def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
if self._precomputed_amax is not None:
scale = amax_to_scale(
self._precomputed_amax,
torch.float8_e4m3fn,
self._precomputed_amax.dtype,
clamp_amax=False,
)
float8_tensor = Float8Tensor.to_float8(
self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
)
else:
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
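The review thread on `amax_to_scale` further down converges on storing a precomputed scale rather than a precomputed amax. As a rough sketch of that simplification (assuming a hypothetical `_precomputed_scale` field set once per optimizer step; this illustrates the idea from the discussion, not the exact code of the later commits), `fsdp_pre_all_gather` could reduce to:

def fsdp_pre_all_gather(self, mesh):
    if self._precomputed_scale is not None:
        # scale was computed for all float8 weights after optimizer.step,
        # so no per-parameter amax computation or clamp is needed here
        float8_tensor = Float8Tensor.to_float8(
            self._tensor,
            self._precomputed_scale,
            torch.float8_e4m3fn,
            mm_config=self._mm_config,
        )
    else:
        float8_tensor = cast_to_float8_e4m3_dynamic(
            self._tensor, self._mm_config, reduce_amax=True
        )
    return (float8_tensor._data,), (float8_tensor._scale,)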
5 changes: 1 addition & 4 deletions float8_experimental/float8_linear_utils.py
@@ -3,16 +3,13 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from enum import auto, Enum
from typing import Callable, List, Optional, Type, Union
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

from float8_experimental.float8_utils import (
amax_history_to_scale_stack,
e4m3_dtype,
9 changes: 7 additions & 2 deletions float8_experimental/float8_utils.py
@@ -34,17 +34,22 @@

@torch.no_grad()
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
amax: torch.Tensor,
float8_dtype: torch.dtype,
orig_dtype: torch.dtype,
clamp_amax: bool = True,
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
clamp_amax: default is True. Use False for FSDP fp8 all-gather, since FSDP already applies `torch.clamp` during precompute after optimizer.step
Contributor:
this is a bit confusing. How about precomputing the scale instead so we don't have to have gotchas like this?

Contributor Author:
good suggestion! I changed the API to precompute scale, and it shows another 9% speed-up in the unit test vs precomputing amax. fsdp_pre_all_gather is also greatly simplified because of using self._precomputed_scale.

"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
res = torch.finfo(float8_dtype).max / amax
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

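A self-contained sketch (not part of the diff) of what the `clamp_amax` flag above changes: when the caller has already clamped amax, as the FSDP precompute path does, skipping the internal clamp yields an identical scale. The EPS value below is an assumption meant to mirror the library's constant.

import torch

EPS = 1e-12  # assumed to mirror float8_experimental.float8_utils.EPS
amax = torch.tensor(0.25)

# default path: amax_to_scale clamps internally
scale_default = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=EPS)

# FSDP precompute path: amax was already clamped after optimizer.step,
# so clamp_amax=False skips the redundant clamp and gives the same scale
scale_precomputed = torch.finfo(torch.float8_e4m3fn).max / amax

assert torch.equal(scale_default, scale_precomputed)  # 448.0 / 0.25 == 1792.0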
58 changes: 58 additions & 0 deletions float8_experimental/fsdp_utils.py
@@ -0,0 +1,58 @@
import math
import warnings
from typing import List

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import linear_requires_sync
from float8_experimental.float8_utils import EPS


def precompute_float8_amax_for_fsdp(module: nn.Module) -> None:
"""
Contributor Author:
improve docstring with example API usage

Contributor:
nice! can we add this to the README?

Contributor Author:
just added API usage to README

Contributor:
nit: maybe we can make sure dynamic is in the name, since this is specific to dynamic scaling?

Contributor Author:
renaming to precompute_float8_dynamic_scale_for_fsdp

Calculate amax for all float8 parameters after optimizer step
It performs a single all-reduce instead of many all-reduces for each parameter
Exmaple usage:
Contributor:
nit (typo): suggested change: `Exmaple usage:` -> `Example usage:`

@vkuzo I assume that there are no docs builds for float8_experimental, so this example is for users who will read the code itself? Otherwise, we might need to check the formatting -- I recall the format for examples being a bit different.

model(input).sum().backward()
optim.step()
precompute_float8_amax_for_fsdp(model)
"""
from torch.distributed._tensor import DTensor

if any(
isinstance(m, Float8Linear)
and linear_requires_sync(
m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY
)
for m in module.modules()
):
raise NotImplementedError("Only supports delayed scaling")
float8_linears: List[Float8Linear] = [
Contributor:
is this expensive for real models? if yes, maybe we can offer an option to precompute this?

Contributor:
My intuition is that this should be pretty fast as the number of nn.Modules in the model is usually at most in the thousands and this is pure Python overhead. @weifengpy you can check the traces you have if you see any noticeable gaps from this.

Contributor Author @weifengpy, Jul 11, 2024:
just checked the profiler traces. it's roughly 0.15ms of CPU overhead (5% of precompute_float8_dynamic_scale_for_fsdp and a tiny portion of one training loop). no CUDA kernels are launched, so I am keeping it as is for simplicity.
[screenshot: profiler trace, 2024-07-11]

m
for m in module.modules()
if isinstance(m, Float8Linear)
and isinstance(m.weight, DTensor)
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_amaxes(weights: List[DTensor]):
# inf-norm is equivalent to max(abs(w))
max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
amax_tensor = torch.vstack(max_weights) # Partial
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
amaxes = torch.split(amax_tensor, 1) # Replicate
return amaxes

if weights:
amaxes = compute_amaxes(weights)
for amax, float8_linear in zip(amaxes, float8_linears):
float8_linear.weight._local_tensor._precomputed_amax = amax._local_tensor
else:
warnings.warn(
"Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
Contributor:
function name in the warning needs to be updated

I am okay with not including this warning by the way. This was also to help debugging to make sure we actually found weights.

Contributor Author:
got you. I am removing the warnings for simplicity

)
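A single-process sanity sketch (not part of the PR) of the amax math in `compute_amaxes`: the per-tensor inf-norm equals `max(abs(w))`, and stacking every amax into one tensor before the clamp is what lets the DTensor version issue a single all-reduce for all float8 weights. Plain local tensors stand in for the sharded DTensor parameters here, and the EPS value is assumed.

import math

import torch

# hypothetical local weights standing in for sharded DTensor parameters
weights = [torch.randn(64, 64), torch.randn(128, 64)]

# inf-norm per tensor is equivalent to max(abs(w)) per tensor
max_weights = torch._foreach_norm(weights, ord=math.inf)
assert all(torch.allclose(n, w.abs().max()) for n, w in zip(max_weights, weights))

# stacking into one tensor means the clamp runs once for all parameters;
# with DTensor (Partial placement) this is also where the single all-reduce
# turns partial amaxes into replicated ones
amax_tensor = torch.vstack(max_weights)
amax_tensor = torch.clamp(amax_tensor, 1e-12)  # EPS assumed
amaxes = torch.split(amax_tensor, 1)
assert len(amaxes) == len(weights)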
7 changes: 5 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -1,12 +1,12 @@
import contextlib
from typing import List, Type
from typing import List

import float8_experimental.config as config

import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp
Contributor:
I might be very confused, but where is this fsdp_utils file?
[screenshot, 2024-07-09]
and a nit: maybe we should be consistent in whether to use pre_compute vs. precompute across the PR (either way is fine to me; precompute, maybe, since it is more concise)

Contributor Author:
> I might be very confused, but where is this fsdp_utils file?

just added float8_experimental/fsdp_utils.py

> maybe we should be consistent in whether to use pre_compute vs. precompute across the PR

good suggestion. I will use precompute consistently



def check_parity_no_mp(
Expand All @@ -16,6 +16,7 @@ def check_parity_no_mp(
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
precompute: bool = False,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
Expand All @@ -29,6 +30,8 @@ def check_parity_no_mp(
param.grad.div_(dist.get_world_size())
# TODO(future): add amax syncing once delayed scaling is supported
optim.step()
if model is fsdp_model and precompute:
precompute_float8_amax_for_fsdp(model)
test_cls.assertEqual(losses[0], losses[1])


24 changes: 18 additions & 6 deletions test/test_fsdp2/test_fsdp2_eager.py
@@ -1,5 +1,4 @@
import copy
import itertools
import threading
import unittest
from typing import Any, List
Expand All @@ -9,7 +8,7 @@
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from test_fsdp2_common import (
check_parity_bf16_mp,
@@ -87,10 +86,21 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_transformer_parity_dynamic(self):
for enable_fsdp_fp8_all_gather in [False, True]:
self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
self.run_subtests(
{
"enable_fsdp_fp8_all_gather": [False, True],
"precompute": [False, True],
},
self._test_transformer_parity_dynamic,
)

def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
def _test_transformer_parity_dynamic(
self,
enable_fsdp_fp8_all_gather: bool,
precompute: bool,
):
if not enable_fsdp_fp8_all_gather and precompute:
return
# NOTE: Weight-tying does not compose with fp8 all-gather because the
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
Expand All @@ -110,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
local_inp = torch.randint(
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)
check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
check_parity_no_mp(
self, ref_module, ref_optim, module, optim, local_inp, precompute
)

@skip_if_lt_x_gpu(2)
def test_transformer_memory(self):