add separate quantization primitives for float8
ghstack-source-id: 51628a9d0c9bcdc03a77b1ddcb5ab002f49f856e
ghstack-comment-id: 2608048970
Pull Request resolved: #1597
danielvegamyhre committed Jan 22, 2025
1 parent 32d9b0b commit 44951b2
Showing 2 changed files with 78 additions and 13 deletions.
20 changes: 7 additions & 13 deletions torchao/quantization/quant_api.py
@@ -66,21 +66,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 from torchao.dtypes import SemiSparseLayout
-int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
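The deprecation message above points callers at the layout kwarg of int8_dynamic_activation_int8_weight. A minimal migration sketch, assuming torchao's quantize_ entry point and a CUDA device (the 2:4 sparse kernels expect one); the toy model is illustrative only:

import torch
from torchao.dtypes import SemiSparseLayout
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

# Toy model; the 2:4 sparse kernels generally expect half precision weights on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).half().cuda()

# Replacement for int8_dynamic_activation_int8_semi_sparse_weight():
# pass the sparse layout through the layout kwarg instead.
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))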
71 changes: 71 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -10,6 +10,12 @@

 import torch
 
+from torchao.float8.float8_utils import (
+    ScalingGranularity,
+)
+from torchao.float8.float8_utils import (
+    tensor_to_scale as tensor_to_float8_scale,
+)
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -39,6 +45,9 @@
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"choose_qparams_affine_float8",
"quantize_affine_float8",
"dequantize_affine_float8",
]


@@ -1300,3 +1309,65 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor.
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
+    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
+    scale = tensor_to_float8_scale(
+        tensor,
+        float8_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=1,
+    )
+    return scale
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to multiply with the scale.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) * scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision floating point tensor.
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor / scale
+    return hp_tensor.to(output_dtype)
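A rough round-trip sketch of the three new primitives, assuming they are importable from torchao.quantization.quant_primitives as the updated __all__ above indicates; the input shape and the printed error check are illustrative, not taken from this diff:

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(16, 32)

# Derive the scale from the high precision tensor (axiswise over dim=1, per the NOTE above),
# then scale, clamp to the float8 range, and cast down.
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)

# Cast back up and divide the scale back out.
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

# The round trip should be close to the input; float8 e4m3 keeps only a few mantissa bits,
# so a few percent of relative error is expected.
print((x - x_hp).abs().max())

Note that the clamp in quantize_affine_float8 enforces saturation semantics: values outside the representable float8 range are pinned to ±max before the cast.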
