This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[wip] add scaling granularity #338

Open

wants to merge 2 commits into base: gh/vkuzo/47/base
12 changes: 12 additions & 0 deletions float8_experimental/config.py
@@ -21,6 +21,18 @@ def short_str(self):
        return "dyn"


class ScalingGranularity(enum.Enum):
    """
    Defines the granularity of scaling strategies for casting to float8
    """

    # A single scaling factor for the entire tensor
    TENSORWISE = "tensorwise"
    # Scaling factors computed along one axis of the tensor, reducing that
    # axis to size 1
    AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
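Note (not part of this diff): a minimal sketch of what the two granularities compute for a 2D tensor, using plain torch ops rather than this repo's helpers; the shapes and dim choice are illustrative.

import torch

x = torch.randn(16, 32)

# TENSORWISE: one amax (and therefore one scale) for the whole tensor
amax_tensorwise = torch.max(torch.abs(x))  # 0-dim scalar tensor

# AXISWISE: reduce along one axis, keeping it as size 1, so every slice
# along the remaining axes gets its own amax
amax_axiswise = torch.amax(torch.abs(x), dim=0, keepdim=True)  # shape [1, 32]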
7 changes: 6 additions & 1 deletion float8_experimental/float8_dynamic_utils.py
@@ -4,8 +4,11 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
@@ -52,10 +55,12 @@ def cast_to_float8_e4m3_dynamic(
    linear_mm_config: LinearMMConfig,
    reduce_amax: bool = False,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
    granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
    dim: Optional[Union[int, Tuple[int]]] = None,
) -> Float8Tensor:
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax, granularity, dim)
    return Float8Tensor.to_float8(
        inpt_tensor,
        scale,
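Note (not part of this diff): a hedged usage sketch of the extended cast_to_float8_e4m3_dynamic signature. It assumes LinearMMConfig is importable from float8_experimental.float8_tensor and relies on the default granularity staying TENSORWISE, so existing callers are unchanged.

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import LinearMMConfig

x = torch.randn(16, 32, dtype=torch.bfloat16)

# tensorwise cast: same behavior as before this PR
x_fp8_tensorwise = cast_to_float8_e4m3_dynamic(x, LinearMMConfig())

# axiswise cast: one scale per slice along dim=0, mirroring the new test below
x_fp8_axiswise = cast_to_float8_e4m3_dynamic(
    x,
    LinearMMConfig(),
    granularity=ScalingGranularity.AXISWISE,
    dim=0,
)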
6 changes: 0 additions & 6 deletions float8_experimental/float8_tensor.py
@@ -290,12 +290,6 @@ def __new__(
        linear_mm_config: Optional[LinearMMConfig],
        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
    ):
        assert (
            scale.numel() == 1
        ), "Scale should contain a single value, but got: {} elements".format(
            scale.numel()
        )

        self = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
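Note (not part of this diff): the deleted assert previously forced the scale to be a single element; with axiswise scaling it becomes a tensor that broadcasts against the data. A rough sketch of the cast math with a per-column scale, using plain torch (448.0 is the e4m3 max; the values are illustrative and this is not the library's actual code path):

import torch

x = torch.randn(16, 32, dtype=torch.bfloat16)
scale = torch.full((1, 32), 2.0)  # per-column scale, broadcasts against x

# scale, saturate to the e4m3 range, then cast
x_fp8_data = (x.float() * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)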
25 changes: 20 additions & 5 deletions float8_experimental/float8_utils.py
@@ -4,9 +4,10 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Literal, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config
from float8_experimental.config import ScalingGranularity

import torch
import torch.distributed as dist
@@ -100,8 +101,18 @@


@torch.no_grad()
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
    amax = torch.max(torch.abs(x))
def tensor_to_amax(
    x: torch.Tensor,
    reduce_amax: bool = False,
    granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
    dim: Optional[Union[int, Tuple[int]]] = None,
) -> torch.Tensor:
    if granularity is ScalingGranularity.TENSORWISE:
        amax = torch.max(torch.abs(x))
    else:
        assert granularity is ScalingGranularity.AXISWISE, f"unsupported granularity {granularity}"
        assert dim is not None, "axiswise scaling requires dim to be specified"
        amax = torch.amax(torch.abs(x), dim=dim, keepdim=True)

    # If the user asked for distributed reduction, do it.
    # If the user did not ask for it, assume that it will
@@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:

@torch.no_grad()
def tensor_to_scale(
    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
    x: torch.Tensor,
    float8_dtype: torch.dtype,
    reduce_amax: bool = False,
    granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
    dim: Optional[Union[int, Tuple[int]]] = None,
) -> torch.Tensor:
    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
    amax = tensor_to_amax(x, reduce_amax=reduce_amax, granularity=granularity, dim=dim)
    return amax_to_scale(amax, float8_dtype, x.dtype)


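Note (not part of this diff): a quick sketch of the scale shapes tensor_to_scale now produces, passing granularity explicitly and using torch.float8_e4m3fn directly; the import paths follow this diff and are assumptions about the surrounding package layout.

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 32, dtype=torch.bfloat16)

# tensorwise: a single scalar scale for the whole tensor
scale_t = tensor_to_scale(
    x, torch.float8_e4m3fn, granularity=ScalingGranularity.TENSORWISE
)
assert scale_t.numel() == 1

# axiswise over dim=0 with keepdim=True: the scale broadcasts against x
scale_a = tensor_to_scale(
    x, torch.float8_e4m3fn, granularity=ScalingGranularity.AXISWISE, dim=0
)
assert scale_a.shape == (1, 32)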
11 changes: 11 additions & 0 deletions test/test_base.py
@@ -143,6 +143,17 @@ def test_weights_only_load(self):
        buffer.seek(0)
        _ = torch.load(buffer, weights_only=True)

    def test_axiswise_dynamic_cast(self):
        a = torch.randn(16, 32, dtype=torch.bfloat16)
        linear_mm_config = LinearMMConfig()
        a_fp8 = cast_to_float8_e4m3_dynamic(
            a,
            linear_mm_config,
            granularity=ScalingGranularity.AXISWISE,
            dim=0,
        )
        print(a_fp8)


class TestFloat8Linear:
    def _test_linear_impl(