This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Merge branch 'pytorch-labs:main' into fsdp2
weifengpy authored Jul 21, 2024
2 parents 969f91f + c58fb5d commit f475c40
Showing 11 changed files with 430 additions and 185 deletions.
float8_experimental/__init__.py (9 changes: 7 additions & 2 deletions)
@@ -5,11 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+)

 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals

-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])

 __all__ = ["Float8Tensor", "Float8Linear"]
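Side note on the add_safe_globals change above: registering these classes is what lets checkpoints containing Float8Tensor (and the config objects it references) pass the allowlist used by torch.load(..., weights_only=True). A minimal sketch, assuming a checkpoint that actually contains such tensors; the file name is hypothetical:

# Minimal sketch: the checkpoint path is hypothetical, and the checkpoint is
# assumed to contain Float8Tensor objects (e.g. weights saved while wrapped
# for FSDP fp8 all-gather).
import torch
import float8_experimental  # importing the package runs add_safe_globals([...])

# With Float8Tensor, ScaledMMConfig, GemmInputRole and LinearMMConfig
# registered as safe globals, the restricted unpickler behind
# weights_only=True accepts them instead of raising an UnpicklingError.
state_dict = torch.load("float8_checkpoint.pt", weights_only=True)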
float8_experimental/float8_dynamic_utils.py (30 changes: 22 additions & 8 deletions)
Expand Up @@ -8,7 +8,8 @@

from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
GemmInputRole,
LinearMMConfig,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
Expand All @@ -26,9 +27,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
def forward(
ctx,
tensor,
mm_config: ScaledMMConfig,
linear_mm_config: LinearMMConfig,
):
ctx.mm_config = mm_config
ctx.linear_mm_config = linear_mm_config
return tensor

@staticmethod
Expand All @@ -37,21 +38,34 @@ def backward(ctx, gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
gradY,
gradY_scale,
e5m2_dtype,
linear_mm_config=ctx.linear_mm_config,
gemm_input_role=GemmInputRole.DL_DY,
)
return fp8_tensor, None


def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
inpt_tensor: torch.Tensor,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.X,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
e4m3_dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)


def cast_to_float8_e5m2_dynamic_bw(
gradY: torch.Tensor, mm_config: ScaledMMConfig
gradY: torch.Tensor, linear_mm_config: LinearMMConfig
) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config)
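To make the new signatures concrete, here is a hedged usage sketch (not part of this commit): it assumes a LinearMMConfig built from three positional ScaledMMConfig entries, mirroring the Float8Linear.__init__ change below; the tensor shapes and config values are arbitrary.

# Illustration only; shapes and ScaledMMConfig values are arbitrary, and the
# LinearMMConfig construction mirrors the float8_linear.py diff below.
import torch

from float8_experimental.float8_dynamic_utils import (
    cast_to_float8_e4m3_dynamic,
    cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)

cfg = ScaledMMConfig(False, True, False, False)
linear_mm_config = LinearMMConfig(cfg, cfg, cfg)  # one ScaledMMConfig per entry

x = torch.randn(16, 32)  # activation
w = torch.randn(64, 32)  # weight

# Activations use the default GemmInputRole.X; weights are tagged explicitly.
x_fp8 = cast_to_float8_e4m3_dynamic(x, linear_mm_config)
w_fp8 = cast_to_float8_e4m3_dynamic(
    w, linear_mm_config, gemm_input_role=GemmInputRole.W
)

# cast_to_float8_e5m2_dynamic_bw is a no-op in the forward pass; the e5m2 cast
# is applied to the incoming gradient during the backward pass.
y = torch.randn(16, 64, requires_grad=True)
y = cast_to_float8_e5m2_dynamic_bw(y, linear_mm_config)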
float8_experimental/float8_linear.py (50 changes: 33 additions & 17 deletions)
@@ -23,6 +23,8 @@

 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor

     @staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):

         self.create_buffers()

-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )

         # Note: is_amax_initialized is not a buffer to avoid data dependent
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8

     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y

     def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
-                        new_mod.forward_config,
+                        new_mod.linear_mm_config,
                     )
                 )
             else:
@@ -468,7 +484,7 @@
                         new_mod.fp8_amax_w,
                         new_mod.fp8_amax_history_w,
                         new_mod.fp8_scale_w,
-                        new_mod.forward_config,
+                        new_mod.linear_mm_config,
                         new_mod.is_amax_initialized,
                     )
                 )
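For readers tracking the forward_config/backward_config to linear_mm_config change: a linear layer runs three gemms, and each gemm pairs a different combination of the x, w and dL_dY operands. That is why every Float8Tensor is now constructed with both a GemmInputRole tag and the full LinearMMConfig (one ScaledMMConfig per operand, per the # x / # w / # dL_dY comments in the __init__ hunk above). A plain-PyTorch sketch of those three gemms, for illustration only:

# Sketch (no float8 involved): the three gemms behind a linear layer and the
# operand roles that meet in each one.
import torch

x = torch.randn(16, 32)       # role X:     layer input
w = torch.randn(64, 32)       # role W:     weight
y = x @ w.t()                 # forward gemm:      X with W
dL_dY = torch.randn_like(y)   # role DL_DY: incoming output gradient
dL_dX = dL_dY @ w             # backward gemm 1:   DL_DY with W
dL_dW = dL_dY.t() @ x         # backward gemm 2:   DL_DY with X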
(Diff for the remaining 8 of the 11 changed files is not shown here.)
