This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

add an option to pad inner_dims
drisspg committed Jun 23, 2024
1 parent ef603c5 commit 53921b2
Showing 6 changed files with 38 additions and 12 deletions.
6 changes: 6 additions & 0 deletions float8_experimental/config.py
@@ -23,3 +23,9 @@
 # If True, use 'fnuz' float8 types for calculations.
 # Currently, ROCm only supports fnuz variants.
 use_fnuz_dtype = False
+
+# If True, then prior to performing the fp8 scaled matmul we will pad the inner
+# dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls using
+# _scaled_mm, since it has the strong constraint that for M, N, K, the dims N and
+# K must be multiples of 16. This can cause a memory spike, however, so we keep
+# this off by default.
+pad_inner_dim = False
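
For reference, a minimal sketch of turning the new knob on, plus the rounding it implies; the round_up helper below is illustrative, not part of the library:

import float8_experimental.config as config

# Opt in to zero-padding of the inner dims before the fp8 scaled matmul.
config.pad_inner_dim = True

def round_up(size: int, multiple: int = 16) -> int:
    # Smallest multiple of `multiple` that is >= size.
    return ((size + multiple - 1) // multiple) * multiple

print(round_up(100))  # 112: a K of 100 would be padded to 112 before _scaled_mm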
8 changes: 6 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -88,8 +88,12 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, not bool(emulate), False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
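A sketch of what the two configs resolve to when emulation is off and config.pad_inner_dim is True; field order follows the ScaledMMConfig definition in float8_tensor.py below (import path assumed):

from float8_experimental.float8_tensor import ScaledMMConfig  # path assumed

# Forward uses fast accumulation when not emulating; backward does not.
# Both now carry the pad_inner_dim flag as the fourth field.
forward_config = ScaledMMConfig(False, True, False, True)
backward_config = ScaledMMConfig(False, False, False, True)
print(forward_config.pad_inner_dim)  # True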
8 changes: 6 additions & 2 deletions float8_experimental/float8_linear.py
@@ -347,6 +347,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.create_buffers()
         # Defines the behavior of the matmul in the forward and backward
         # Forward we use fast_accum, backwards we do not
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         return new_mod
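A hedged end-to-end usage sketch: set the flag before conversion and both matmul configs pick it up (the Float8Linear class name and import path are assumed here, not confirmed by this diff):

import torch.nn as nn
import float8_experimental.config as config
from float8_experimental.float8_linear import Float8Linear  # name/path assumed

config.pad_inner_dim = True  # opt in before converting modules

lin = nn.Linear(100, 200, bias=False)  # K=100 is not a multiple of 16
fp8_lin = Float8Linear.from_float(lin, emulate=False)
assert fp8_lin.forward_config.pad_inner_dim
assert fp8_lin.backward_config.pad_inner_dim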
7 changes: 6 additions & 1 deletion float8_experimental/float8_ops.py
@@ -13,7 +13,8 @@
     merge_mm_configs,
     ScaledMMConfig,
 )
-from float8_experimental.float8_utils import is_row_major
+from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
+
 from torch.utils._pytree import tree_map

 aten = torch.ops.aten
@@ -121,6 +122,10 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data

+    if a._mm_config.pad_inner_dim:
+        a_data = pad_tensor_for_matmul(a_data, dims=1)
+        b_data = pad_tensor_for_matmul(b_data, dims=0)
+
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
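To make the dim choice concrete: for a @ b with a of shape (M, K) and b of shape (K, N), the shared inner dimension is dim 1 of a and dim 0 of b, so only K grows and the product is unchanged. A small sketch using plain torch.nn.functional.pad rather than the repo helper:

import torch
import torch.nn.functional as F

a = torch.randn(8, 100)   # (M, K): K=100 is not a multiple of 16
b = torch.randn(100, 32)  # (K, N)

pad_k = (-100) % 16       # 12 zeros bring K up to 112
a_pad = F.pad(a, (0, pad_k))        # pad the last dim of a
b_pad = F.pad(b, (0, 0, 0, pad_k))  # pad the first dim of b

# Zero-padding the shared K dim leaves the result numerically unchanged.
torch.testing.assert_close(a @ b, a_pad @ b_pad)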
6 changes: 4 additions & 2 deletions float8_experimental/float8_tensor.py
@@ -23,10 +23,11 @@
 # emulate: whether to emulate the matmuls in fp32
 # use_fast_accum: whether to use the fast-accumulation option for scaled_mm
 # fp8_output: whether to output the result of the scaled_mm in fp8
+# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
 ScaledMMConfig = namedtuple(
     "ScaledMMConfig",
-    ["emulate", "use_fast_accum", "fp8_output"],
-    defaults=[False, False, False],
+    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
+    defaults=[False, False, False, False],
 )

@@ -48,6 +49,7 @@ def merge_mm_configs(
         emulate=a_mm_config.emulate,
         use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
         fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
+        pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
     )

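A small sketch of the namedtuple's behavior after this change: the new field defaults to False, and merging only keeps pad_inner_dim when both operands request it (import path and keyword call convention assumed):

from float8_experimental.float8_tensor import ScaledMMConfig, merge_mm_configs  # paths assumed

default = ScaledMMConfig()
print(default.pad_inner_dim)  # False by default

a_cfg = ScaledMMConfig(pad_inner_dim=True)
b_cfg = ScaledMMConfig(pad_inner_dim=False)
merged = merge_mm_configs(a_mm_config=a_cfg, b_mm_config=b_cfg)
print(merged.pad_inner_dim)   # False: both sides must opt in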
15 changes: 10 additions & 5 deletions float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Literal, Tuple
+from typing import Iterable, Literal, Tuple, Union

 import float8_experimental.config as config

@@ -197,7 +197,9 @@ def get_min_alignment(size: int, alignment_value: int):
     return (1 + (size // alignment_value)) * alignment_value


-def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
+def pad_tensor_for_matmul(
+    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
+) -> torch.Tensor:
     """
     Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required for H100s.
@@ -211,9 +213,12 @@ def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
     assert tensor.dim() == 2
     dim1, dim2 = tensor.shape

-    # Calculate aligned dimensions
-    dim2_aligned = get_min_alignment(dim2, 16)
-    dim1_aligned = get_min_alignment(dim1, 16) if both else dim1
+    if isinstance(dims, int):
+        dims = (dims,)
+
+    # Calculate aligned dimensions based on the specified dims
+    dim1_aligned = get_min_alignment(dim1, 16) if 0 in dims else dim1
+    dim2_aligned = get_min_alignment(dim2, 16) if 1 in dims else dim2

     # Check if padding is needed for either dimension
     if dim1 == dim1_aligned and dim2 == dim2_aligned:
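A self-contained sketch of the new dims-aware padding, assuming the same 16-alignment rule; the helper name below is illustrative, not the repo's:

import torch
import torch.nn.functional as F

def pad_dims_to_16(t: torch.Tensor, dims) -> torch.Tensor:
    # Zero-pad the requested dims of a 2D tensor up to the next multiple of 16.
    assert t.dim() == 2
    if isinstance(dims, int):
        dims = (dims,)
    pad0 = (-t.shape[0]) % 16 if 0 in dims else 0
    pad1 = (-t.shape[1]) % 16 if 1 in dims else 0
    if pad0 == 0 and pad1 == 0:
        return t
    return F.pad(t, (0, pad1, 0, pad0))

x = torch.randn(100, 100)
print(pad_dims_to_16(x, dims=1).shape)       # torch.Size([100, 112])
print(pad_dims_to_16(x, dims=(0, 1)).shape)  # torch.Size([112, 112])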
