diff --git a/README.md b/README.md index 663094e60b..b28ff522ba 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ The best example we have combining the composability of lower bit dtype with com We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow -1. [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())` +1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))` 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 6c7d46a945..e9f9d21398 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -1,15 +1,15 @@ import torch import pandas as pd import torch.nn.functional as F -from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType +from torchao.dtypes import to_affine_quantized_floatx +from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm def benchmark(m: int, k: int, n: int): float_data = torch.randn(n, k, dtype=torch.half, device="cuda") - fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2)) + fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2)) fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") diff --git a/test/dtypes/test_fpx.py b/test/dtypes/test_floatx.py similarity index 53% rename from test/dtypes/test_fpx.py rename to test/dtypes/test_floatx.py index 130bdadf38..b4776f95e1 100644 --- a/test/dtypes/test_fpx.py +++ b/test/dtypes/test_floatx.py @@ -8,14 +8,14 @@ parametrize, run_tests, ) -from torchao.dtypes.fpx import ( - FpxTensorCoreAQTLayout, - FpxTensorCoreLayoutType, - to_scaled_tc_fpx, - from_scaled_tc_fpx, +from torchao.dtypes.floatx import ( + FloatxTensorCoreAQTLayout, + FloatxTensorCoreLayoutType, + to_scaled_tc_floatx, + from_scaled_tc_floatx, ) -from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6 -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 +from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6 +from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32 from torchao.quantization import ( quantize_, fpx_weight_only, @@ -25,71 +25,71 @@ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) -_FPx_DTYPES = [(3, 2), (2, 2)] +_Floatx_DTYPES = [(3, 2), (2, 2)] -class TestFpxTensorCoreAQTLayout(TestCase): +class TestFloatxTensorCoreAQTLayout(TestCase): @parametrize("device", _DEVICES) def test_pack_tc_fp6_correctness(self, device): x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) - expected = _pack_tc_fpx(x, 6) + expected = _pack_tc_floatx(x, 6) actual = _pack_tc_fp6(x) 
torch.testing.assert_close(actual, expected) - @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("device", _DEVICES) - def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device): + def test_to_scaled_tc_floatx_compile(self, ebits, mbits, device): x = torch.randn(256, 64, device=device) - expected = to_scaled_tc_fpx(x, ebits, mbits) - actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits) + expected = to_scaled_tc_floatx(x, ebits, mbits) + actual = torch.compile(to_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits) torch.testing.assert_close(actual, expected) - @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("device", _DEVICES) - def test_from_tc_fpx_correctness(self, ebits, mbits, device): + def test_from_tc_floatx_correctness(self, ebits, mbits, device): x = torch.randn(256, 64, device=device) * 100 - # quantize and dequantize so that the values are exactly representable in FPx - x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits) + # quantize and dequantize so that the values are exactly representable in Floatx + x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits) - tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits) - actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale) + tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits) + actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale) torch.testing.assert_close(actual, x) - @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("device", _DEVICES) - def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device): + def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): M, N = 256, 64 nbits = 1 + ebits + mbits x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device) scale = torch.randn(M, device=device) - expected = from_scaled_tc_fpx(x, ebits, mbits, scale) - actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale) + expected = from_scaled_tc_floatx(x, ebits, mbits, scale) + actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale) torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("ebits,mbits", _Floatx_DTYPES) def test_to_copy_device(self, ebits, mbits): from torchao.quantization.quant_primitives import ( - choose_qparams_affine_fpx, - quantize_affine_fpx, + choose_qparams_affine_floatx, + quantize_affine_floatx, ) x = torch.randn(256, 64) - scale = choose_qparams_affine_fpx(x, ebits, mbits) - x = quantize_affine_fpx(x, scale, ebits, mbits) - layout_type = FpxTensorCoreLayoutType(ebits, mbits) - fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda() - assert fpx_layout_tensor.device.type == "cuda" - fpx_layout_tensor = fpx_layout_tensor.cpu() - assert fpx_layout_tensor.device.type == "cpu" + scale = choose_qparams_affine_floatx(x, ebits, mbits) + x = quantize_affine_floatx(x, scale, ebits, mbits) + layout_type = FloatxTensorCoreLayoutType(ebits, mbits) + floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda() + assert floatx_layout_tensor.device.type == "cuda" + floatx_layout_tensor = floatx_layout_tensor.cpu() + assert floatx_layout_tensor.device.type == "cpu" @pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") - @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) def test_fpx_weight_only(self, ebits, mbits, bias): N, OC, IC = 4, 256, 64 @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias): torch.testing.assert_close(actual, expected) -instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout) +instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout) if __name__ == "__main__": diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 97cdc6f7d3..2d689a0c09 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -4,7 +4,7 @@ import torch -from torchao.dtypes.uintx.Uintx import to_uintx +from torchao.dtypes.uintx.uintx import to_uintx from torchao.quantization.quant_api import quantize_, uintx_weight_only from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, diff --git a/test/test_ops.py b/test/test_ops.py index e62766756c..31000eafc2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,7 @@ ) from torch.testing._internal.optests import opcheck from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff -from torchao.dtypes.fpx import from_scaled_tc_fpx +from torchao.dtypes.floatx import from_scaled_tc_floatx from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24 import pytest @@ -33,13 +33,13 @@ class TestOps(TestCase): - def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): + def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): # Randomly initialize each byte nbits = 1 + ebits + mbits - fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) + floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) scale = torch.rand(OC).half() + 0.5 fp16_act = torch.rand(BS, IC).half() + 0.5 - return fpx_weight.to(device), scale.to(device), fp16_act.to(device) + return floatx_weight.to(device), scale.to(device), fp16_act.to(device) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @@ -48,28 +48,28 @@ def test_quant_llm_linear(self, ebits, mbits): OC = 256 IC = 256 splitK = 1 - fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda") # smoke test - torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) + torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils) + opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): # adapted from 
https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py - fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") + floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda") - results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) + results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK) - fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half() + fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half() results_fp16 = fp16_act @ fp16_weight.T - error = (results_fpx - results_fp16).abs().mean() + error = (results_floatx - results_fp16).abs().mean() gt = results_fp16.abs().mean() relative_error = error / gt assert relative_error < 1e-3 @@ -319,7 +319,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] MARLIN_TEST_PARAMS = list(itertools.product( - MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, + MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS )) @@ -405,7 +405,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto workspace_24 = marlin_24_workspace(size_n) fn_inputs = ( - input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, + input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out, ) output = torchao.ops.marlin_24_gemm(*fn_inputs) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 418e75d039..142a49a368 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -12,9 +12,9 @@ int_scaled_matmul, choose_qparams_and_quantize_affine_hqq, FP8_TYPES, - choose_qparams_affine_fpx, - quantize_affine_fpx, - dequantize_affine_fpx, + choose_qparams_affine_floatx, + quantize_affine_floatx, + dequantize_affine_floatx, ) from torchao.quantization.utils import ( pack_tinygemm_scales_and_zeros, @@ -199,10 +199,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.fpx import FpxTensorCoreLayoutType - if isinstance(self.layout_type, FpxTensorCoreLayoutType): + from torchao.dtypes.floatx import FloatxTensorCoreLayoutType + if isinstance(self.layout_type, FloatxTensorCoreLayoutType): int_data, scale = self.layout_tensor.get_plain() - return dequantize_affine_fpx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) + return dequantize_affine_floatx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) else: data, scale, zero_point = self.layout_tensor.get_plain() return dequantize_affine( @@ -389,8 +389,8 @@ def from_hp_to_fpx( input_float: torch.Tensor, layout_type: LayoutType, ): - from torchao.dtypes.fpx import FpxTensorCoreLayoutType - assert isinstance(layout_type, FpxTensorCoreLayoutType), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}" + from torchao.dtypes.floatx import FloatxTensorCoreLayoutType + assert isinstance(layout_type, FloatxTensorCoreLayoutType), f"Only FloatxTensorCoreLayoutType is supported for floatx, got {layout_type}" original_shape = 
input_float.shape input_float = layout_type.pre_process(input_float) # per axis quantization, where axis = 1 @@ -399,12 +399,12 @@ def from_hp_to_fpx( ebits, mbits = layout_type.ebits, layout_type.mbits # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_fpx(input_float, ebits, mbits) - fpx_unpacked = quantize_affine_fpx(input_float, scale, ebits, mbits) - fpx_packed = layout_type.post_process(fpx_unpacked) + scale = choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + floatx_packed = layout_type.post_process(floatx_unpacked) layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(fpx_packed, scale, None, layout_type) + layout_tensor = layout_tensor_ctr(floatx_packed, scale, None, layout_type) return cls( layout_tensor, block_size, @@ -502,7 +502,7 @@ class MarlinSparseLayoutType(LayoutType): def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format - - 2º: tensor is injected with 2:4 sparsity + - 2º: tensor is injected with 2:4 sparsity - 3º: transposes it again because the quantization process will compute the scales for dim=-1 Args: @@ -673,8 +673,8 @@ def from_plain( @register_layout_cls(MarlinSparseLayoutType) class MarlinSparseAQTLayout(AQTLayout): """ - Layout storage class for sparse_marlin_24 layout for affine quantized tensor. - + Layout storage class for sparse_marlin_24 layout for affine quantized tensor. + Can be used with 4 bits and 8 bits quantization. Original marlin documentation and information: @@ -760,9 +760,9 @@ def __tensor_unflatten__( def get_plain(self): from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import int_data_expanded, scales_expanded = unpack_from_marlin_24( - self.int_data, - self.scale, - self.meta, + self.int_data, + self.scale, + self.meta, self.original_shape, self.group_size, self.num_bits, @@ -794,7 +794,7 @@ def from_plain( if q_w_24.dtype != torch.int32: raise ValueError("Only `torch.int32` weights are supported.") - + in_features, out_features = q_w_24.shape if in_features % 128 != 0 or out_features != 256 == 0: raise ValueError( @@ -824,11 +824,11 @@ def from_plain( marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) return cls( - marlin_24_q_w_comp, marlin_24_s, zero_point, + marlin_24_q_w_comp, marlin_24_s, zero_point, meta, layout_type, q_w_24.shape, group_size, num_bits ) - + def get_layout_type(self) -> LayoutType: return self.layout_type @@ -956,7 +956,7 @@ def __repr__(self): f"scale={scale},\n" f"transposed={self.transposed}, " f"layout_type={layout_type})") - + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): @@ -1308,16 +1308,16 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y -def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.fpx import FpxTensorCoreLayoutType +def _linear_f16_act_floatx_weight_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import FloatxTensorCoreLayoutType return ( # input is native float32 tensor not is_traceable_wrapper_subclass(input_tensor) and input_tensor.is_floating_point() and input_tensor.dtype 
== torch.float16 and - # weight is fpx Tensor + # weight is floatx Tensor isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) and + isinstance(weight_tensor.layout_type, FloatxTensorCoreLayoutType) and ( # weight is using fp6 quantization (weight_tensor.layout_type.ebits == 3 and @@ -1332,8 +1332,8 @@ def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias): ) ) -def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): - from torchao.dtypes.fpx import _SPLIT_K_MAP +def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear act = input_tensor @@ -1350,7 +1350,7 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): weight.layout_type.ebits, weight.layout_type.mbits, act_reshaped, - weight.layout_tensor.packed_fpx_data, + weight.layout_tensor.packed_floatx_data, weight.layout_tensor.scale, splitK=splitK, ) @@ -1378,10 +1378,10 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): """ Ensures input tensor is correctly formated for _scaled_mm """ input_scale = input_scale.unsqueeze(-1) - + if input_scale.dim() > 2: input_scale = input_scale.reshape(-1, input_scale.shape[-1]) - + return input_scale def _linear_fp_act_fp8_weight_impl( @@ -1457,7 +1457,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b workspace_24 = marlin_24_workspace(original_shape[1]) out = marlin_24_gemm( - input_2d, sparse_w_int4, meta, scale, + input_2d, sparse_w_int4, meta, scale, workspace_24, num_bits, size_m, size_n, size_k ) @@ -1476,7 +1476,7 @@ def _register_aqt_quantized_linear_dispatches(): (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), + (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl), (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/fpx/README.md b/torchao/dtypes/floatx/README.md similarity index 81% rename from torchao/dtypes/fpx/README.md rename to torchao/dtypes/floatx/README.md index 1de60b0cc1..f4cbf51a03 100644 --- a/torchao/dtypes/fpx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -1,6 +1,6 @@ # Quant-LLM -This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to FPx and integration with torchao API. +This is a FP16 x Floatx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to Floatx and integration with torchao API. ## Usage @@ -13,7 +13,7 @@ from torchao.quantization import ( model = ... 
model.half() # not necessary, but recommeneded to maintain accuracy -# for generic FPx EyMz where x = 1 + y + z +# for generic Floatx EyMz where x = 1 + y + z # fp6 with ebits = 3 and mbits = 2 quantize_(model, fpx_weight_only(3, 2)) @@ -25,7 +25,7 @@ It's also possible to pre-process the weight and call the kernel directly. ```python import torch -from torchao.dtypes.fpx import to_scaled_tc_fpx +from torchao.dtypes.floatx import to_scaled_tc_floatx from torchao.ops import quant_llm_linear fp32_weight = torch.randn(1024, 512).cuda() @@ -33,7 +33,7 @@ ebits, mbits = 3, 2 # pre-process the weight. this will quantize the weight to FP6 and pack it in a special # layout for tensor cores. refer to paper for more details. -fp6_weight, scales = to_scaled_tc_fpx(fp32_weight, ebits, mbits) +fp6_weight, scales = to_scaled_tc_floatx(fp32_weight, ebits, mbits) fp16_act = torch.randn(1, 512).cuda().half() outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape (1, 1024) @@ -48,7 +48,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py) which generates text in a latency optimized way (batchsize=1). wikitext perplexity is measured using [eval.py](../../_models/llama/eval.py) which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). -FPx quantization is run with `--precision float16`. The rest uses the default precision of `bfloat16`. +Floatx quantization is run with `--precision float16`. The rest uses the default precision of `bfloat16`. 
Quantization | wikitext perplexity | tokens/s --------------------|---------------------|---------- diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py new file mode 100644 index 0000000000..0eb1e70529 --- /dev/null +++ b/torchao/dtypes/floatx/__init__.py @@ -0,0 +1 @@ +from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/dtypes/floatx/floatx.py similarity index 76% rename from torchao/dtypes/fpx/fpx.py rename to torchao/dtypes/floatx/floatx.py index 6afa22f560..dcbfd5f69c 100644 --- a/torchao/dtypes/fpx/fpx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -4,7 +4,7 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones +from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones from torchao.dtypes.utils import ( LayoutType, ) @@ -62,7 +62,7 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h -def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: +def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: assert tensor.ndim == 2, tensor.dtype == torch.uint8 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) @@ -91,7 +91,7 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: return torch.cat(fragments, dim=0).view(M, -1) -# more optimized version of _pack_tc_fpx() for FP6 by merging ops +# more optimized version of _pack_tc_floatx() for FP6 by merging ops def _pack_tc_fp6(tensor: Tensor) -> Tensor: assert tensor.ndim == 2, tensor.dtype == torch.uint8 M, N = tensor.shape @@ -112,13 +112,13 @@ def _pack_tc_fp6(tensor: Tensor) -> Tensor: # currently only optimize for TC-FP6 packing -def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: +def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: if nbits == 6: return _pack_tc_fp6(tensor) - return _pack_tc_fpx(tensor, nbits) + return _pack_tc_floatx(tensor, nbits) -def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: +def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -130,13 +130,13 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te tensor = tensor.float() scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) - tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits) - return tensor_tc_fpx, scale.half() + tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) + return tensor_tc_floatx, scale.half() -# inverse of _pack_tc_fpx() -def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: +# inverse of _pack_tc_floatx() +def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: assert tensor.ndim == 2 and tensor.dtype == torch.uint8 M = tensor.shape[0] size = tensor.numel() @@ -144,7 +144,7 @@ def _unpack_tc_fpx(tensor: 
Tensor, nbits: int) -> Tensor: offset = 0 used_bits = 0 - tensor_fpx = None + tensor_floatx = None for y in [1, 2, 4]: if nbits & y: @@ -159,20 +159,20 @@ def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor_ybit << (nbits - used_bits - y) used_bits += y - if tensor_fpx is None: - tensor_fpx = tensor_ybit + if tensor_floatx is None: + tensor_floatx = tensor_ybit else: - tensor_fpx |= tensor_ybit + tensor_floatx |= tensor_ybit # undo Pass 1 - tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2) - tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8) - tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6) - tensor_fpx = tensor_fpx.reshape(M, -1) - return tensor_fpx + tensor_floatx = tensor_floatx.view(32, -1, 2).permute(1, 0, 2) + tensor_floatx = tensor_floatx.reshape(M // 64, -1, 4, 2, 2, 8, 8) + tensor_floatx = tensor_floatx.permute(0, 2, 4, 5, 1, 3, 6) + tensor_floatx = tensor_floatx.reshape(M, -1) + return tensor_floatx -# more optimized version of _unpack_tc_fpx() for FP6 by merging ops +# more optimized version of _unpack_tc_floatx() for FP6 by merging ops # inverse of _unpack_tc_fp6() def _unpack_tc_fp6(tensor: Tensor) -> Tensor: assert tensor.ndim == 2 and tensor.dtype == torch.uint8 @@ -199,15 +199,15 @@ def _unpack_tc_fp6(tensor: Tensor) -> Tensor: return tensor_fp6 -def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: +def unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: if nbits == 6: return _unpack_tc_fp6(tensor) - return _unpack_tc_fpx(tensor, nbits) + return _unpack_tc_floatx(tensor, nbits) -def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: - fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits) - tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits) +def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: + floatx_unpacked = unpack_tc_floatx(tensor, 1 + ebits + mbits) + tensor = _floatx_unpacked_to_f32(floatx_unpacked, ebits, mbits) if scale is not None: tensor = tensor * scale.float().view(-1, 1) return tensor @@ -353,17 +353,17 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te # quantization api integrations @dataclass(frozen=True) -class FpxTensorCoreLayoutType(LayoutType): - """Layout type for FpxTensorCoreAQTLayout +class FloatxTensorCoreLayoutType(LayoutType): + """Layout type for FloatxTensorCoreAQTLayout """ ebits: int mbits: int -@register_layout_cls(FpxTensorCoreLayoutType) -class FpxTensorCoreAQTLayout(AQTLayout): - """FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b), - it has a internal tensor field of "packed_fpx_data", which is packed from the - uint8 unpacked data (the output of `quantize_affine_fpx` operator) +@register_layout_cls(FloatxTensorCoreLayoutType) +class FloatxTensorCoreAQTLayout(AQTLayout): + """FloatxTensorCoreAQTLayout represents a Tensor with dtype floatx(ebits=a, mbits=b), + it has a internal tensor field of "packed_floatx_data", which is packed from the + uint8 unpacked data (the output of `quantize_affine_floatx` operator) The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm @@ -377,81 +377,81 @@ class FpxTensorCoreAQTLayout(AQTLayout): If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be (M, N // 8 * nbit) - FpxTensorCoreAQTLayout.from_plain takes an unpacked uint8 fpx Tensor of shape (M, N), with 
format of + FloatxTensorCoreAQTLayout.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor - FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit) + it will then pack the weight and instantiate the FloatxTensorCoreAQTLayout tensor + FloatxTensorCoreAQTLayout.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ def __new__( cls, - packed_fpx_data: torch.Tensor, + packed_floatx_data: torch.Tensor, scale: torch.Tensor, layout_type: LayoutType, ): - assert packed_fpx_data.ndim == 2 - assert packed_fpx_data.dtype == torch.uint8 - shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8) + assert packed_floatx_data.ndim == 2 + assert packed_floatx_data.dtype == torch.uint8 + shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8) kwargs = {} - kwargs["device"] = packed_fpx_data.device + kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout + kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout ) - kwargs["dtype"] = packed_fpx_data.dtype + kwargs["dtype"] = packed_floatx_data.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, - packed_fpx_data: torch.Tensor, + packed_floatx_data: torch.Tensor, scale: torch.Tensor, layout_type: LayoutType, ): - self.packed_fpx_data = packed_fpx_data + self.packed_floatx_data = packed_floatx_data self.scale = scale self.layout_type = layout_type def __tensor_flatten__(self): - return ["packed_fpx_data", "scale"], [self.layout_type] + return ["packed_floatx_data", "scale"], [self.layout_type] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"] + packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] layout_type, = tensor_attributes - return cls(packed_fpx_data, scale, layout_type) + return cls(packed_floatx_data, scale, layout_type) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits) - return unpacked_fpx_data, self.scale + unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self.layout_type.ebits + self.layout_type.mbits) + return unpacked_floatx_data, self.scale @classmethod def from_plain( cls, - unpacked_fpx_data: torch.Tensor, + unpacked_floatx_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], layout_type: LayoutType, ): """ - Format for `unpacked_fpx_data` will be: + Format for `unpacked_floatx_data` will be: zero padding bits | sign bit | exponent bits | mantissa bits For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent bit, M is mantissa bit """ - assert isinstance(layout_type, FpxTensorCoreLayoutType) - packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits) - return cls(packed_fpx_data, scale, layout_type) + assert isinstance(layout_type, 
FloatxTensorCoreLayoutType) + packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + layout_type.ebits + layout_type.mbits) + return cls(packed_floatx_data, scale, layout_type) def __repr__(self): - unpacked_fpx_data, scale = self.get_plain() + unpacked_floatx_data, scale = self.get_plain() layout_type = self.get_layout_type() - return f"{self.__class__.__name__}(unpacked_fpx_data={unpacked_fpx_data}, scale={scale}, layout_type={layout_type})" + return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, layout_type={layout_type})" def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.packed_fpx_data), + fn(self.packed_floatx_data), fn(self.scale), self.layout_type, ) @@ -460,7 +460,7 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") return self.__class__( - self.packed_fpx_data.to(device), + self.packed_floatx_data.to(device), self.scale.to(device), self.layout_type, ) @@ -483,7 +483,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"FpxTensorCoreAQTLayout dispatch: attempting to run {func}, this is not supported" + f"FloatxTensorCoreAQTLayout dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/dtypes/fpx/__init__.py b/torchao/dtypes/fpx/__init__.py deleted file mode 100644 index af77685fac..0000000000 --- a/torchao/dtypes/fpx/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP diff --git a/torchao/dtypes/uintx/Uintx.py b/torchao/dtypes/uintx/uintx.py similarity index 100% rename from torchao/dtypes/uintx/Uintx.py rename to torchao/dtypes/uintx/uintx.py diff --git a/torchao/ops.py b/torchao/ops.py index 5bb8271638..7f7adab864 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -28,7 +28,7 @@ def quant_llm_linear( EXPONENT: number of exponent bits MANTISSA: number of mantissa bits _in_feats: input activations in FP16 - _weights: packed FPx weights + _weights: packed Floatx weights _scales: scale splitK: split K @@ -74,7 +74,7 @@ def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Ten return torch.ops.torchao.unpack_tensor_core_tiled_layout.default( packed_w=packed_w, inner_k_tiles=inner_k_tiles ) - + @register_custom_op(f"torchao::unpack_tensor_core_tiled_layout") def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor: @@ -111,7 +111,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens - packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `inner_k_tiles = 2 | 4 | 8`" - packed scales_and_zeros were generated with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` - qGroupSize is 32 | 64 | 128 | 256 - + Args: packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2`, dtype is torch.int32 scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize @@ -125,7 +125,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens return torch.ops.torchao.dequantize_tensor_core_tiled_layout.default( packed_w, scales_and_zeros, group_size, inner_k_tiles ) - + @register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout") def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) 
-> Tensor: @@ -157,7 +157,7 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles torch._check(scales_and_zeros.size(0) == K // group_size, lambda: "scales_and_zeros must have K // qGroupSize at dim 0") torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") - + return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md index 1024d635c0..e70154e796 100644 --- a/torchao/prototype/README.md +++ b/torchao/prototype/README.md @@ -9,7 +9,7 @@ - `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507) - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm - `galore/docs` - implementation notes and discussion of issues faced in kernel design. -- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) +- [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) - [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers). #### Roadmap diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index 3af11f1710..41a9e399ef 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -7,9 +7,9 @@ # This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3). # It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain: # 1. No encodings are reserved for special values (+/-inf, NaN). -# 2. When downcasting from FP32 to FPx, +# 2. When downcasting from FP32 to Floatx, # - Rounding mode is round to nearest, ties to even. -# - Values outside the representable range of FPx after rounding are clamped to the maximum FPx +# - Values outside the representable range of Floatx after rounding are clamped to the maximum Floatx # magnitude (sign is preserved). import torch @@ -24,7 +24,7 @@ def _n_ones(n: int) -> int: F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) -def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: +def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: """Convert FP32 numbers to sub-byte floating point numbers with the given number of exponent and mantissa bits. @@ -35,8 +35,8 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding Note: there are no special values (NaN, inf) support in this code. Values - outside the representable range of FPx after rounding are clamped to the - maximum FPx magnitude (sign is preserved). + outside the representable range of Floatx after rounding are clamped to the + maximum Floatx magnitude (sign is preserved). Code below is an adaptation of https://fburl.com/code/ciwofcg4 @@ -142,7 +142,7 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: # TODO(future): check if LUT for everything is faster than bit shifting, # especially for fp4 (only 2^4=16 unique values). 
-def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: +def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: """Convert sub-byte floating point numbers with the given number of exponent and mantissa bits to FP32. diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index d346b212cc..ede5164824 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -10,7 +10,7 @@ from torch.utils._triton import has_triton from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 +from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32 # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert @@ -54,7 +54,7 @@ def f32_to_f4_unpacked(x): Output: torch.Tensor of dtype torch.uint8, with bits 0-3 empty and bits 4-7 in fp4_e2m1 """ - return _f32_to_fpx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1) + return _f32_to_floatx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1) def f32_to_f6_e2m3_unpacked(x): @@ -63,7 +63,7 @@ def f32_to_f6_e2m3_unpacked(x): Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and bits 2-7 in fp6_e2m3 """ - return _f32_to_fpx_unpacked(x, EBITS_F6_E2M3, MBITS_F6_E2M3) + return _f32_to_floatx_unpacked(x, EBITS_F6_E2M3, MBITS_F6_E2M3) def f32_to_f6_e3m2_unpacked(x): @@ -72,7 +72,7 @@ def f32_to_f6_e3m2_unpacked(x): Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and bits 2-7 in fp6_e3m2 """ - return _f32_to_fpx_unpacked(x, EBITS_F6_E3M2, MBITS_F6_E3M2) + return _f32_to_floatx_unpacked(x, EBITS_F6_E3M2, MBITS_F6_E3M2) def f4_unpacked_to_f32(x: torch.Tensor): @@ -81,7 +81,7 @@ def f4_unpacked_to_f32(x: torch.Tensor): containing an fp4_e2m1 encoding Output: torch.Tensor of dtype fp32 with the dequantized value """ - return _fpx_unpacked_to_f32(x, EBITS_F4_E2M1, MBITS_F4_E2M1) + return _floatx_unpacked_to_f32(x, EBITS_F4_E2M1, MBITS_F4_E2M1) def f6_e2m3_unpacked_to_f32(x: torch.Tensor): @@ -90,7 +90,7 @@ def f6_e2m3_unpacked_to_f32(x: torch.Tensor): containing an fp6_e3m2 encoding Output: torch.Tensor of dtype fp32 with the dequantized value """ - return _fpx_unpacked_to_f32(x, EBITS_F6_E2M3, MBITS_F6_E2M3) + return _floatx_unpacked_to_f32(x, EBITS_F6_E2M3, MBITS_F6_E2M3) def f6_e3m2_unpacked_to_f32(x: torch.Tensor): @@ -99,7 +99,7 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor): containing an fp6_e3m2 encoding Output: torch.Tensor of dtype fp32 with the dequantized value """ - return _fpx_unpacked_to_f32(x, EBITS_F6_E3M2, MBITS_F6_E3M2) + return _floatx_unpacked_to_f32(x, EBITS_F6_E3M2, MBITS_F6_E3M2) if has_triton(): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bab6e3a070..cf5aab2800 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -23,7 +23,7 @@ from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple import types -from torchao.dtypes.uintx.Uintx import UintxLayoutType +from torchao.dtypes.uintx.uintx import UintxLayoutType from torchao.dtypes import ( to_affine_quantized_intx, to_affine_quantized_floatx, @@ -543,7 +543,7 @@ def apply_int4_weight_only_quant(weight): zero_point_domain = ZeroPointDomain.FLOAT # Sparse Marlin only supports symmetric quantization. 
- # NOTE: If we start having lots of layouts that require different configurations, + # NOTE: If we start having lots of layouts that require different configurations, # we should consider moving this logic somewhere else. if isinstance(layout_type, MarlinSparseLayoutType): mapping_type = MappingType.SYMMETRIC @@ -716,7 +716,7 @@ def float8_dynamic_activation_float8_weight( granularity: The granularity for quantization. Can be either a single granularity (applied to both activations and weights) or a tuple of two granularities (one for activations, one for weights). - If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And + If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And only PerTensor and PerRow are supported. mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. @@ -818,10 +818,10 @@ def fpx_weight_only(ebits: int, mbits: int): """ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: - from torchao.dtypes.fpx import FpxTensorCoreLayoutType + from torchao.dtypes.floatx import FloatxTensorCoreLayoutType from torchao.dtypes import to_affine_quantized_fpx - assert weight.dim() == 2, f"fpx only works for 2-d Tensor, got: {weight.dim()}" + assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" out_dim, in_dim = weight.shape if (in_dim % 64 != 0) or (out_dim % 256 != 0): logger.info( @@ -830,7 +830,7 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: "expected in_dim % 64 == 0 and out_dim % 256 == 0") return weight - layout_type = FpxTensorCoreLayoutType(ebits, mbits) + layout_type = FloatxTensorCoreLayoutType(ebits, mbits) return to_affine_quantized_fpx(weight, layout_type) return _get_linear_subclass_inserter(apply_quant_llm) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 3e24ce5cc4..b1561e4cff 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -15,7 +15,7 @@ TORCH_VERSION_AT_LEAST_2_5, ) from torchao.utils import _register_custom_op, _is_float8_type -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones +from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones __all__ = [ @@ -23,11 +23,11 @@ "int_scaled_matmul", "choose_qparams_affine", "choose_qparams_affine_with_min_max", - "choose_qparams_affine_fpx", + "choose_qparams_affine_floatx", "quantize_affine", "dequantize_affine", - "quantize_affine_fpx", - "dequantize_affine_fpx", + "quantize_affine_floatx", + "dequantize_affine_floatx", "fake_quantize_affine", "fake_quantize_affine_cachemask", "choose_qparams_and_quantize_affine_hqq", @@ -946,7 +946,7 @@ def choose_qparams_and_quantize_affine_hqq( return W_q, scale, zero, shape -def choose_qparams_affine_fpx(tensor: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: +def choose_qparams_affine_floatx(tensor: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -960,16 +960,16 @@ def choose_qparams_affine_fpx(tensor: torch.Tensor, ebits: int, mbits: int) -> t scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal return scale.half() -def quantize_affine_fpx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: +def 
quantize_affine_floatx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor: """Quantizes the float32 high precision floating point tensor to low precision floating point number and converts the result to unpacked floating point format with the format of 00SEEEMM (for fp6_e3m2) where S means sign bit, e means exponent bit and m means mantissa bit """ tensor = tensor.float() - tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) - return tensor_fpx + tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + return tensor_floatx -def dequantize_affine_fpx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int, output_dtype: torch.dtype = torch.float32) -> torch.Tensor: - tensor = _fpx_unpacked_to_f32(tensor, ebits, mbits) +def dequantize_affine_floatx(tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int, output_dtype: torch.dtype = torch.float32) -> torch.Tensor: + tensor = _floatx_unpacked_to_f32(tensor, ebits, mbits) tensor = tensor * scale.float().view(-1, 1) tensor = tensor.to(dtype=output_dtype) return tensor
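
Taken together, the diff renames the old `torchao.dtypes.fpx` entry points to `torchao.dtypes.floatx` while keeping their signatures. Below is a minimal usage sketch that mirrors the examples in the updated `torchao/dtypes/floatx/README.md` above; it assumes a CUDA device, a torchao build that ships the `quant_llm_linear` kernel, and weight shapes satisfying `in_dim % 64 == 0` and `out_dim % 256 == 0` (otherwise `fpx_weight_only` logs a message and leaves the weight unchanged, per `quant_api.py`).

```python
# Minimal sketch of the renamed floatx API, using only names introduced or kept by this diff.
import torch
import torch.nn as nn

from torchao.quantization import quantize_, fpx_weight_only
from torchao.dtypes.floatx import to_scaled_tc_floatx
from torchao.ops import quant_llm_linear

# High-level path: quantize Linear weights to fp6 (ebits=3, mbits=2).
# Shapes are chosen so that in_dim % 64 == 0 and out_dim % 256 == 0.
model = nn.Sequential(nn.Linear(512, 1024)).half().cuda()
quantize_(model, fpx_weight_only(3, 2))

# Low-level path: pack a weight into the tensor-core floatx layout and call the kernel directly.
ebits, mbits = 3, 2
fp32_weight = torch.randn(1024, 512, device="cuda")
fp6_weight, scales = to_scaled_tc_floatx(fp32_weight, ebits, mbits)
fp16_act = torch.randn(1, 512, dtype=torch.half, device="cuda")
out = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales)  # shape (1, 1024)
```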