Merge remote-tracking branch 'origin/main' into api_ref_dtypes
jainapurva committed Jan 24, 2025
2 parents 3a0dd4e + fb335e0 commit 9eea875
Showing 5 changed files with 67 additions and 124 deletions.
37 changes: 0 additions & 37 deletions docs/source/sg_execution_times.rst

This file was deleted.

6 changes: 1 addition & 5 deletions torchao/dtypes/__init__.py
@@ -4,14 +4,12 @@
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .floatx import (
Float8Layout,
FloatxTensor,
FloatxTensorCoreLayout,
to_affine_quantized_fpx,
)
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
@@ -54,6 +52,4 @@
"MarlinQQQLayout",
"Int4CPULayout",
"CutlassInt4PackedLayout",
"FloatxTensor",
"FloatxTensorCoreLayout",
]
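
A minimal sketch (not part of the diff) of what this export move means for callers: `to_affine_quantized_fpx` stays importable from `torchao.dtypes`; it is simply bound in `torchao.dtypes.affine_quantized_tensor` now instead of the floatx module.

# Hedged sketch: the public import path survives the refactor; only the
# defining module changes. Nothing below comes from the diff itself.
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.affine_quantized_tensor import (
    to_affine_quantized_fpx as _fpx_ctor,
)

assert to_affine_quantized_fpx is _fpx_ctor  # same object, new home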
87 changes: 66 additions & 21 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -14,9 +14,12 @@
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
quantize_affine,
quantize_affine_floatx,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
@@ -33,6 +36,7 @@
"to_affine_quantized_floatx",
"to_affine_quantized_intx_static",
"to_affine_quantized_floatx_static",
"to_affine_quantized_fpx",
]


@@ -118,28 +122,40 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
if output_dtype is None:
output_dtype = self.dtype

data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout
from torchao.dtypes.floatx import FloatxTensorCoreLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq
if isinstance(self._layout, FloatxTensorCoreLayout):
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)
else:
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

def __tensor_flatten__(self):
return ["tensor_impl"], [
@@ -379,6 +395,33 @@ def from_hp_to_floatx_static(
f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
)

@classmethod
def from_hp_to_fpx(
cls,
input_float: torch.Tensor,
_layout: Layout,
):
from torchao.dtypes.floatx import FloatxTensorCoreLayout

assert isinstance(
_layout, FloatxTensorCoreLayout
), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
# per axis quantization, where axis = 1
block_size = list(input_float.shape)
block_size[1] = 1

ebits, mbits = _layout.ebits, _layout.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed = _layout.post_process(floatx_unpacked)

tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@property
def _layout(self) -> Layout:
return self.tensor_impl._layout
@@ -434,6 +477,8 @@ def _apply_fn_to_data(self, fn):
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
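
A hedged end-to-end sketch of the API this file now owns: `from_hp_to_fpx` (exposed as `to_affine_quantized_fpx`) quantizes a high-precision weight into the floatx tensor-core packing, and the new `FloatxTensorCoreLayout` branch in `dequantize` reverses it via `dequantize_affine_floatx`. The weight shape, the fp16 dtype, and the fp6 choice of `ebits=3, mbits=2` are illustrative assumptions rather than values taken from this commit, and the packed layout may impose dimension-divisibility constraints not shown here.

import torch

from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

# Illustrative assumption: an fp16 linear weight quantized to fp6 (e3m2).
weight = torch.randn(256, 64, dtype=torch.float16)

layout = FloatxTensorCoreLayout(ebits=3, mbits=2)
fpx_weight = to_affine_quantized_fpx(weight, layout)  # AffineQuantizedTensor

# The layout-dispatched branch added to dequantize() routes this layout to
# dequantize_affine_floatx; other layouts keep the dequantize_affine path.
recovered = fpx_weight.dequantize(output_dtype=torch.float16)
print((weight - recovered).abs().max())  # quantization error, not exactly zero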
4 changes: 0 additions & 4 deletions torchao/dtypes/floatx/__init__.py
@@ -1,9 +1,7 @@
from .float8_layout import Float8Layout
from .floatx_tensor_core_layout import (
FloatxTensor,
FloatxTensorCoreLayout,
from_scaled_tc_floatx,
to_affine_quantized_fpx,
to_scaled_tc_floatx,
)

@@ -12,6 +10,4 @@
"to_scaled_tc_floatx",
"from_scaled_tc_floatx",
"Float8Layout",
"to_affine_quantized_fpx",
"FloatxTensor",
]
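
The floatx namespace keeps its low-level pack/unpack helpers, so code that never went through the removed `FloatxTensor` subclass is unaffected. A small round-trip sketch, under the assumption (not visible in this diff) that `to_scaled_tc_floatx` returns a `(packed, scale)` pair and `from_scaled_tc_floatx` accepts that scale back; shapes and the e3m2 (fp6) choice are examples only.

import torch

from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx

# Assumed signatures; nothing below is taken from the diff itself.
x = torch.randn(256, 64, dtype=torch.float16)
packed, scale = to_scaled_tc_floatx(x, ebits=3, mbits=2)
x_hat = from_scaled_tc_floatx(packed, ebits=3, mbits=2, scale=scale)
print((x - x_hat).abs().max())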
57 changes: 0 additions & 57 deletions torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -11,7 +11,6 @@

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
get_tensor_impl_constructor,
register_layout,
)
from torchao.dtypes.utils import (
@@ -23,11 +22,6 @@
_floatx_unpacked_to_f32,
_n_ones,
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine_floatx,
dequantize_affine_floatx,
quantize_affine_floatx,
)

aten = torch.ops.aten
_ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -464,54 +458,6 @@ class FloatxTensorCoreLayout(Layout):
mbits: int


class FloatxTensor(AffineQuantizedTensor):
"""
Floatx quantized tensor subclass which inherits AffineQuantizedTensor class. It uses floating-point format defined by ebits (exponent bits) and mbits (mantissa bits) and supports float1 - float7 tensor types.
For details about float8 tensor type, please refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/float8_layout.py.
To see what happens during choose_qparams_and_quantize_affine_fpx, quantization and dequantization for floatx quantization,
please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
and check the two quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx and dequantize_affine_floatx.
"""

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if output_dtype is None:
output_dtype = self.dtype
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)

@classmethod
def from_hp_to_floatx(
cls,
input_float: torch.Tensor,
_layout: Layout,
):
assert isinstance(
_layout, FloatxTensorCoreLayout
), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
# per axis quantization, where axis = 1
block_size = list(input_float.shape)
block_size[1] = 1

ebits, mbits = _layout.ebits, _layout.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed = _layout.post_process(floatx_unpacked)

tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)


@register_layout(FloatxTensorCoreLayout)
class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl):
"""FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b),
@@ -713,6 +659,3 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
out += bias

return out.view(*act.shape[:-1], out_dim).to(act.dtype)


to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx

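With `FloatxTensor` gone, fpx quantization produces a plain `AffineQuantizedTensor`, and behaviour is selected by the layout object rather than by a tensor subclass. A short sketch of that dispatch follows; the shape, dtype, and `ebits=3, mbits=2` values are illustrative assumptions.

import torch

from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.dtypes.floatx import FloatxTensorCoreLayout

w = torch.randn(256, 64, dtype=torch.float16)
qw = to_affine_quantized_fpx(w, FloatxTensorCoreLayout(ebits=3, mbits=2))

# No FloatxTensor subclass any more: the quantized weight is a plain
# AffineQuantizedTensor, and dequantize() picks the floatx path by layout.
assert type(qw) is AffineQuantizedTensor
assert isinstance(qw._layout, FloatxTensorCoreLayout)

This mirrors the `isinstance(self._layout, FloatxTensorCoreLayout)` check added to `dequantize` in affine_quantized_tensor.py above.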