[1/x] clean up casting functions
Summary:

This is a start of a cleanup of private casting functions in preparation for
rowwise scaling.  In this PR:
1. create `float8_scaling_utils.py` to unify the functions that take a
   high-precision tensor and return a float8 tensor, taking care of
   scaling
2. delete `Float8Tensor.to_float8` and move callsites to
   `ToFloat8ConstrFunc`, since the two functions do the same thing

The end result is a slightly cleaner state; future PRs will do more
cleanups.
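
For reference, a minimal sketch of the new callsite pattern from item 2 above,
assuming `float8_experimental` at this commit and a recent PyTorch with float8
dtypes; the tensor name, shapes, and the inline scale computation are
illustrative placeholders rather than code from this PR:

```
# Sketch only: A, scale_a, mm_config are placeholder names, not from the PR.
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    ToFloat8ConstrFunc,
)

A = torch.randn(16, 32, dtype=torch.bfloat16)
fp8_dtype = torch.float8_e4m3fn

# inline dynamic scale (fp8_max / amax); the repo has helpers for this
scale_a = torch.finfo(fp8_dtype).max / A.abs().max().float()

# the benchmark below bundles the same ScaledMMConfig three times into a LinearMMConfig
mm_config = ScaledMMConfig(
    emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)
linear_mm_config = LinearMMConfig(mm_config, mm_config, mm_config)

# before this PR:
#   a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=mm_config)
# after this PR, the autograd construction function is called directly:
a_fp8 = ToFloat8ConstrFunc.apply(
    A,
    scale_a,
    fp8_dtype,
    None,  # amax_buffer
    linear_mm_config,
    GemmInputRole.INPUT,
)
```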

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 30f7160c6751aa8bf82dfa35e9d8449d20a4028d
Pull Request resolved: #339
vkuzo committed Jul 26, 2024
1 parent 4cc99da commit 96162b3
Showing 11 changed files with 310 additions and 274 deletions.
29 changes: 25 additions & 4 deletions benchmarks/bench_padding.py
@@ -4,7 +4,12 @@
import fire

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor import (
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
@@ -50,9 +55,25 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
b_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)

a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
a_config = LinearMMConfig(a_config, a_config, a_config)
b_config = LinearMMConfig(b_config, b_config, b_config)

a_fp8 = ToFloat8ConstrFunc.apply(
A,
scale_a,
fp8_dtype,
None, # amax_buffer
a_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
B,
scale_b,
fp8_dtype,
None, # amax_buffer
b_config,
GemmInputRole.WEIGHT,
)

return a_fp8 @ b_fp8

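Read together, the two casts in this hunk build `Float8Tensor` operands with
different gemm input roles and then multiply them with a plain `@`. A hedged
end-to-end sketch of that pattern follows; it assumes a CUDA device with
float8 `torch._scaled_mm` support, and the shapes and inline scale computation
are placeholders for the benchmark's own setup:

```
# Sketch of the benchmark's new flow; requires float8-capable hardware to run.
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    ToFloat8ConstrFunc,
)


def fp8_matmul_sketch(A, B, fp8_dtype=torch.float8_e4m3fn):
    # same per-gemm config as the benchmark, reused for all three LinearMMConfig slots
    cfg = ScaledMMConfig(
        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
    )
    lin_cfg = LinearMMConfig(cfg, cfg, cfg)

    # inline dynamic scales (fp8_max / amax); placeholders for the repo's helpers
    scale_a = torch.finfo(fp8_dtype).max / A.abs().max().float()
    scale_b = torch.finfo(fp8_dtype).max / B.abs().max().float()

    a_fp8 = ToFloat8ConstrFunc.apply(
        A, scale_a, fp8_dtype, None, lin_cfg, GemmInputRole.INPUT
    )
    b_fp8 = ToFloat8ConstrFunc.apply(
        B, scale_b, fp8_dtype, None, lin_cfg, GemmInputRole.WEIGHT
    )

    # dispatches to the scaled matmul; pad_inner_dim=True pads the shared dim if needed
    return a_fp8 @ b_fp8


out = fp8_matmul_sketch(
    torch.randn(64, 96, device="cuda", dtype=torch.bfloat16),
    torch.randn(96, 128, device="cuda", dtype=torch.bfloat16),
)
```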
71 changes: 0 additions & 71 deletions float8_experimental/float8_dynamic_utils.py

This file was deleted.

112 changes: 10 additions & 102 deletions float8_experimental/float8_linear.py
@@ -16,61 +16,29 @@

from float8_experimental.config import Float8LinearConfig, ScalingType

from float8_experimental.float8_dynamic_utils import (
from float8_experimental.float8_scaling_utils import (
_maybe_initialize_amaxes_scales_for_float8_cast,
cast_to_float8_delayed,
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
)

from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
to_fp8_no_autograd,
)

from float8_experimental.float8_utils import (
amax_history_to_scale,
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
)
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax

from float8_experimental.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
)


def _maybe_initialize_amaxes_scales_for_float8_cast(
x,
cur_amax,
amax_history,
scale,
scale_fn_name,
float8_dtype,
is_initialized,
reduce_amax,
):
"""
If x is about to be cast to `float8` and the amax buffers are not initialized,
initializes them inplace.
"""
if is_initialized:
return
with torch.no_grad():
# Note: we need to enable distributed reduction here in order
# to match numerics between single GPU and multi GPU code for
# activations and gradients
new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
cur_amax.fill_(new_amax)
amax_history[0] = new_amax
new_scale = amax_history_to_scale(
amax_history, float8_dtype, x.dtype, scale_fn_name
)
scale.copy_(new_scale)


# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
@torch._dynamo.allow_in_graph
class manual_float8_matmul(torch.autograd.Function):
@@ -127,66 +95,6 @@ def backward(ctx, grad_output_fp8):
return grad_input, grad_weight.t()


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
"""
Forward: no-op
Backward: convert to float8_e5m2, initialize if needed
"""

@staticmethod
def forward(
ctx,
tensor,
fp8_amax_grad_output,
fp8_amax_history_grad_output,
fp8_scale_grad_output,
scale_fn_name,
is_amax_initialized,
linear_mm_config: LinearMMConfig,
):
ctx.save_for_backward(
fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output
)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.linear_mm_config = linear_mm_config
return tensor

@staticmethod
def backward(ctx, go):
(
fp8_amax_grad_output,
fp8_amax_history_grad_output,
fp8_scale_grad_output,
) = ctx.saved_tensors
scale_fn_name = ctx.scale_fn_name
is_amax_initialized = ctx.is_amax_initialized

_maybe_initialize_amaxes_scales_for_float8_cast(
go,
fp8_amax_grad_output,
fp8_amax_history_grad_output,
fp8_scale_grad_output,
scale_fn_name,
e5m2_dtype,
is_amax_initialized,
reduce_amax=True,
)

fp8_amax_grad_output.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(
go,
fp8_scale_grad_output,
e5m2_dtype,
linear_mm_config=ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads


class Float8Linear(torch.nn.Linear):
"""
Note: this is **not** a public API and is only intended to be used
@@ -352,7 +260,7 @@ def cast_input_to_float8(
is_amax_initialized,
reduce_amax=True,
)
input_fp8 = Float8Tensor.to_float8(
input_fp8 = cast_to_float8_delayed(
input,
self.fp8_scale_input,
e4m3_dtype,
@@ -384,7 +292,7 @@ def cast_weight_to_float8(
reduce_amax=False,
)

weight_fp8 = Float8Tensor.to_float8(
weight_fp8 = cast_to_float8_delayed(
weight,
self.fp8_scale_weight,
e4m3_dtype,
@@ -407,7 +315,7 @@ def cast_weight_to_float8(
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
output = NoopFwToFloat8E5M2Bw.apply(
output = NoopFwToFloat8E5M2BwDelayed.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
@@ -418,7 +326,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
)
else:
assert self.scaling_type_grad_output is ScalingType.DYNAMIC
output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
return output

def float8_pre_forward(self, input):
(diff of the remaining 8 changed files not shown)