This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

delete Float8DynamicLinear (#304)
Summary:
Pull Request resolved: #304

We are standardizing on `Float8Linear` as the only float8 linear object:
1. The stack ending with #300 moved all of the functionality of
   `Float8DynamicLinear` into `Float8Linear`; the default settings of
   `Float8Linear` now use dynamic scaling.
2. This PR deletes `Float8DynamicLinear` from the codebase and patches
   the relevant callsites in fbsource (a brief migration sketch of the new
   API follows below).

Reviewed By: drisspg

Differential Revision: D59342767

fbshipit-source-id: cfb09dd5f6517cfbf41d8b46eb6d7d6a5266006a
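
For readers updating their own callsites, here is a minimal migration sketch of the post-PR API, pieced together from the benchmark changes in this diff. The toy layer sizes, the CUDA device, and the choice of delayed scaling for all three tensors are illustrative assumptions rather than part of the commit, and real float8 compute also requires supported hardware (or `emulate=True`).

```python
# Minimal sketch of the post-PR float8_experimental API, assuming the callsite
# patterns shown in the benchmark diffs below; shapes and device are illustrative
# and this experimental API may have changed after the repository was archived.
import copy

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# Toy model (illustrative); emulate=False assumes float8-capable hardware.
m = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()

# There is no class argument anymore: Float8Linear is the only float8 linear,
# and scaling behavior is chosen per tensor via TensorScalingType
# (dynamic scaling is the default).
swap_linear_with_float8_linear(
    m,
    emulate=False,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

# A single nn.Linear can also be converted directly.
ref_lin = nn.Linear(4096, 4096).cuda()
lin_fp8 = Float8Linear.from_float(copy.deepcopy(ref_lin), emulate=False)

# Delayed scaling still requires syncing amax/scale history before each step;
# note that linear_requires_sync no longer takes a LinearType argument.
x = torch.randn(16, 4096, device="cuda", requires_grad=True)
if linear_requires_sync(
    TensorScalingType.DELAYED, TensorScalingType.DELAYED, TensorScalingType.DELAYED
):
    sync_float8_amax_and_scale_history(m)
m(x).sum().backward()
```

With the all-dynamic defaults, `linear_requires_sync` should return False, so the explicit sync step above can be skipped.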
vkuzo authored and facebook-github-bot committed Jul 5, 2024
1 parent d4cf2ad commit 8e9623a
Showing 16 changed files with 182 additions and 536 deletions.
65 changes: 15 additions & 50 deletions benchmarks/bench_linear_float8.py
@@ -14,11 +14,9 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import ScaledMMConfig
@@ -69,7 +67,6 @@ class Experiment:
dtype: torch.dtype
compiled: bool
use_fast_accum: bool
linear_type: str
scaling_repr: str

# 3 Times since we are calculating forward backward
@@ -98,7 +95,6 @@ def main(
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
linear_type_filter: Optional[str] = None,
scaling_type_x: str = "delayed",
scaling_type_w: str = "delayed",
scaling_type_dL_dY: str = "delayed",
@@ -123,44 +119,28 @@
use_fast_accum = [fast_accum_filter]
else:
use_fast_accum = [True, False]
if linear_type_filter is not None:
linear_types = [linear_type_filter]
else:
linear_types = ["delayed", "dynamic"]
if shape_name_filter is not None:
k = shape_name_filter
name_to_shapes_70b = {k: name_to_shapes_70b[k]}
experiment_list: List[Experiment] = []
dtype = torch.bfloat16
for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
for idx, (fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
):
if n_limit is not None and idx >= n_limit:
break
linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
device=device, dtype=dtype
)
linear_type_enum = (
LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
)

if linear_type == "delayed":
linear_float8 = get_float8_linear(
linear_type_enum,
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
)
scaling_repr = linear_float8.scaling_repr()
else:
linear_float8 = get_float8_linear(
linear_type_enum,
copy.deepcopy(linear_ref),
emulate=False,
)
scaling_repr = None
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
)
scaling_repr = linear_float8.scaling_repr()

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -172,19 +152,10 @@ def main(
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

if linear_type_enum == LinearType.DELAYED:

def float8_forw_backward():
if linear_requires_sync(
linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

else:

def float8_forw_backward():
linear_float8(input_tensor).sum().backward()
def float8_forw_backward():
if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
@@ -224,7 +195,6 @@ def wrapper(*args, **kwargs):
dtype,
compile,
use_fast_accum=fast_accum,
linear_type=linear_type,
scaling_repr=scaling_repr,
)
print(experiment)
@@ -237,7 +207,6 @@ def wrapper(*args, **kwargs):
"M",
"K",
"N",
"linear_type",
"scaling_repr",
"ref_dtype",
"compiled",
@@ -257,7 +226,6 @@ def wrapper(*args, **kwargs):
experiment.shape[0],
experiment.shape[1],
experiment.shape[2],
experiment.linear_type,
experiment.scaling_repr,
experiment.dtype,
experiment.compiled,
@@ -287,7 +255,6 @@ def wrapper(*args, **kwargs):
[
"name",
"shape",
"linear_type",
"scaling_repr",
"compiled",
"use_fast_accum",
@@ -311,7 +278,6 @@ def invoke_main() -> None:
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--linear_type_filter", type=str, required=False)
parser.add_argument("--scaling_type_x", type=str, required=False)
parser.add_argument("--scaling_type_w", type=str, required=False)
parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
@@ -330,7 +296,6 @@ def invoke_main() -> None:
args.n_limit,
args.fast_accum_filter,
args.shape_name_filter,
args.linear_type_filter,
**kwargs,
)

10 changes: 8 additions & 2 deletions benchmarks/bench_multi_gpu.py
@@ -14,7 +14,7 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
@@ -65,7 +65,13 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
modules.append(nn.ReLU())
m = nn.Sequential(*modules)
if is_fp8:
swap_linear_with_float8_linear(m, Float8Linear, emulate=False)
swap_linear_with_float8_linear(
m,
emulate=False,
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
)
return m


49 changes: 23 additions & 26 deletions benchmarks/profile_linear_float8.py
@@ -18,11 +18,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
@@ -206,19 +204,25 @@ def profile_function(
def main(
profile_path_prefix: Path,
compile: bool = True,
linear_type: str = "dynamic",
scaling_type_x: str = "delayed",
scaling_type_w: str = "delayed",
scaling_type_dL_dY: str = "delayed",
scaling_type_x: str = "dynamic",
scaling_type_w: str = "dynamic",
scaling_type_dL_dY: str = "dynamic",
model_type: str = "linear",
dtype_filter: str = "both",
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

print(f"Compile is set to | {compile}")
print(f"Using Linear type: | {linear_type}")
print(f"model_type is set to | {model_type}")
scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
scaling_repr = "_".join(
[s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
)

print(f"Compile is set to | {compile}")
print(f"model_type is set to | {model_type}")
print(f"scaling_repr is set to | {scaling_repr}")

device = "cuda"
ref_dtype = torch.bfloat16
@@ -249,21 +253,14 @@ def main(

m_ref = m_ref.to(device).to(ref_dtype)

linear_type = LinearType[linear_type.upper()]
linear_cls = (
Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
)
extra_kwargs = {}
scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
if linear_type is LinearType.DELAYED:
extra_kwargs["scaling_type_x"] = scaling_type_x
extra_kwargs["scaling_type_w"] = scaling_type_w
extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY
extra_kwargs = {
"scaling_type_x": scaling_type_x,
"scaling_type_w": scaling_type_w,
"scaling_type_dL_dY": scaling_type_dL_dY,
}

m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)
swap_linear_with_float8_linear(m_float8, **extra_kwargs)

def ref_forw_backward(x):
out = m_ref(x)
@@ -281,9 +278,7 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
if linear_requires_sync(
linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
):
if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)
@@ -345,7 +340,9 @@ def float8_forw_backward_wrapper(x):
if dtype_filter != "bfloat16":
# Profile Float8 Model
print("profiling float8")
float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
float8_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
)
float8_path = profile_path_prefix + float8_suffix
profile_config = ProfileConfig(
float8_path,
@@ -53,64 +53,6 @@ def backward(ctx, gradY):
return fp8_tensor, None


class Float8DynamicLinear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
conversion to fp8 of the input and weight tensors.
"""

def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
return y

@classmethod
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
"""
with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
"out_features": mod.out_features,
"bias": False,
}
new_mod = cls(**super_kwargs)

new_mod.forward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=not bool(emulate),
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
new_mod.backward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=False,
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
else:
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
) -> Float8Tensor:
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear.py
@@ -16,7 +16,7 @@

import torch

from float8_experimental.float8_dynamic_linear import (
from float8_experimental.float8_dynamic_utils import (
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
WeightWithDynamicFloat8CastTensor,
@@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

def scaling_repr(self):
# add scaling settings without using too many characters
# example: "x:del,w:del,dldy:dyn"
return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
# example: "x_del_w_del_dldy_dyn"
return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"

def extra_repr(self):
s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
