Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vkuzo committed Jan 24, 2025
1 parent 2505639 commit 7bcb6c4
Showing 1 changed file with 54 additions and 6 deletions.
60 changes: 54 additions & 6 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Zeros: N/A
"""

from enum import Enum, auto
from typing import Dict, Union

import torch
Expand Down Expand Up @@ -53,11 +54,38 @@
unpack_uint4,
)

# TODO(before land): read from somewhere else?
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2


class ScaleCalculationMode(Enum):
"""
Enum representing the different methods for calculating MX block scaling.
There are three methods available:
FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
It result in overflow issues for large values and bad for gradient quantization.
CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
It uses X = 2^ceil(log2(max_abs(v))-max_exp).
EVEN: This method is a trade-off between Option 1 and Option 2. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
It provides better accuracy for MX4 training compared to FLOOR and CEIL.
By default, we use the EVEN method for better accuracy.
"""

FLOOR = auto()
CEIL = auto()
EVEN = auto()


def to_mx(
data_hp: torch.Tensor,
elem_dtype: Union[torch.dtype, str],
block_size: int,
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
"""
Takes a high precision tensor and converts to MX scale and raw data, in
Expand Down Expand Up @@ -88,25 +116,45 @@ def to_mx(
# where the values are zero.
eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)

# Find largest power of 2 less than or equal to max_abs.
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))

# Set X to be the largest power-of-two less than or equal to
# max_abs(v), divided by the largest power of two representable
# in the element data type
# in the element data type, and get the mbits at the same time
if elem_dtype == torch.float8_e4m3fn:
target_max_pow2 = F8E4M3_MAX_POW2
mbits = MBITS_F8_E4M3
elif elem_dtype == torch.float8_e5m2:
target_max_pow2 = F8E5M2_MAX_POW2
mbits = MBITS_F8_E5M2
elif elem_dtype == DTYPE_FP6_E2M3:
target_max_pow2 = F6_E2M3_MAX_POW2
mbits = MBITS_F6_E2M3
elif elem_dtype == DTYPE_FP6_E3M2:
target_max_pow2 = F6_E3M2_MAX_POW2
mbits = MBITS_F6_E3M2
elif elem_dtype == DTYPE_FP4:
target_max_pow2 = F4_E2M1_MAX_POW2
mbits = MBITS_F4_E2M1
else:
raise AssertionError("unsupported")
scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2
raise AssertionError("unsupported element dtype")

# rounding before calculating the largest power of 2
# X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
if scaling_mode == ScaleCalculationMode.EVEN:
nan_mask = torch.isnan(max_abs)
max_abs = max_abs.to(torch.float32).view(torch.int32)
val_to_add = 1 << (MBITS_F32 - mbits - 1)
mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
max_abs = (max_abs + val_to_add) & mask
max_abs = max_abs.view(torch.float32)
max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)

# Calculate the scale for different modes
if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
elif scaling_mode == ScaleCalculationMode.CEIL:
scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
else:
raise AssertionError("unsupported scaling calculation mode")

# Clamp to exponents that can be represented in e8m0
scale_e8m0_unbiased = torch.clamp(
Expand Down

0 comments on commit 7bcb6c4

Please sign in to comment.