From 1d350d6611bcd0434abe80f697aff9b540555f1f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 Jan 2025 23:18:56 +0800 Subject: [PATCH 1/5] add w4a4 --- torchao/csrc/cuda/int4_cutlass.cu | 231 ++++++++++++++++++++++++++++++ torchao/ops.py | 53 +++++++ 2 files changed, 284 insertions(+) create mode 100644 torchao/csrc/cuda/int4_cutlass.cu diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu new file mode 100644 index 0000000000..452abcceaa --- /dev/null +++ b/torchao/csrc/cuda/int4_cutlass.cu @@ -0,0 +1,231 @@ +#include +#include + +// copied from s8s4_linear_cutlass.cu +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) +#define BUILD_INT4_MM_CUTLASS +#endif + +#if defined(BUILD_INT4_MM_CUTLASS) +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : Got CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } +#endif + +namespace torchao { + +#if defined(BUILD_INT4_MM_CUTLASS) +// define common params +using ElementA = cutlass::int4b_t; +using ElementB = cutlass::int4b_t; +using ElementAccumulator = int32_t; +using OpClass = cutlass::arch::OpClassTensorOp; +using ArchTag = cutlass::arch::Sm80; + +// how many elements to load at a time -> load 128-bit = 32 x 4-bit +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; +#endif + +// we will do input checks in python. 
A and B are stored as int8 +torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) { +#if defined(BUILD_INT4_MM_CUTLASS) + int M = A.size(0); + int K = A.size(1) * 2; + int N = B.size(1); + torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32)); + + // some configs for int4 mma + // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu + // using default config. this can be tuned. + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + // static int const kStages = 3; + using ElementC = int32_t; + using Gemm = cutlass::gemm::device::Gemm< + ElementA, cutlass::layout::RowMajor, // A matrix + ElementB, cutlass::layout::ColumnMajor, // B matrix + ElementC, cutlass::layout::RowMajor, // C matrix + ElementAccumulator, OpClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape + >; + Gemm::Arguments args { + {M, N, K}, + {reinterpret_cast(A.data_ptr()), K}, + {reinterpret_cast(B.data_ptr()), K}, + {C.data_ptr(), N}, + {C.data_ptr(), N}, + {1, 0} // epilogue + }; + Gemm gemm_op; + CUTLASS_STATUS_CHECK(gemm_op(args)); + return C; +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +template< + typename ElementC, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int numStages> +void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) { + // problem shape + int M = A.size(0); + int K = A.size(1) * 2; + int N = B.size(1); + + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 8 for BF16/FP16 + using ElementEpilogue = float; + constexpr int numEpilogueStages = 1; + + // build epilogue visitor tree + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + 
ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest; + using Multiply = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode + >; + + // (1, N) + using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, ElementC, + cute::Stride // MNL + >; + using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; + + // (M, 1) + using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, ElementC, + cute::Stride // MNL + >; + using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementC, RoundMode, + cute::Stride // MNL + >; + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT; + + using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, cutlass::layout::RowMajor, AlignmentC, + ElementAccumulator, ElementEpilogue, OpClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, + EVTOutput, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + numStages, + cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work + numEpilogueStages + >::GemmKernel; + using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // col_scale, row_scale, and C must have the same dtype + const ElementA *A_ptr = reinterpret_cast(A.data_ptr()); + const ElementB *B_ptr = reinterpret_cast(B.data_ptr()); + const ElementC *col_scale_ptr = reinterpret_cast(col_scale.data_ptr()); + const ElementC *row_scale_ptr = reinterpret_cast(row_scale.data_ptr()); + ElementC *C_ptr = 
reinterpret_cast(C.data_ptr()); + + typename EVTOutput::Arguments callback_args{ + { + { + {}, // Accum + {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale + {} // Multiply + }, // EVTCompute0 + {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale + {} // Multiply + }, // EVTCompute1 + {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput + }; + + typename DeviceGemm::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + cutlass::gemm::GemmCoord{M, N, K}, + 1, // batch_split + callback_args, + A_ptr, B_ptr, nullptr, nullptr, // unused C_ptr and D_ptr + M * K, N * K, 0, 0, // batch_stride A, B, C, D + K, K, 0, 0 // stride A, B, C, D + ); + + DeviceGemm gemm_op; + auto stream = at::cuda::getCurrentCUDAStream(); + CUTLASS_STATUS_CHECK(gemm_op.can_implement(args)); + CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream)); +} + +// we will do input checks in python. A and B are stored as int8 +// this function is based on the following cutlass example +// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu +// also with the help of emitted code from cutlass Python +torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) { +#if defined(BUILD_INT4_MM_CUTLASS) + int M = A.size(0); + int N = B.size(1); + torch::Tensor C = torch::empty({M, N}, row_scale.options()); + + // some configs for int4 mma + // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu + // using default config. this can be tuned. 
+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + constexpr int numStages = 3; + + AT_DISPATCH_SWITCH( + row_scale.scalar_type(), + "scaled_int4_mm_cutlass", + AT_DISPATCH_CASE( + torch::ScalarType::Half, + [&]() { + using ElementC = cutlass::half_t; + scaled_int4_mm_cutlass_dispatch< + ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( + A, B, row_scale, col_scale, C); + } + ) + AT_DISPATCH_CASE( + torch::ScalarType::BFloat16, + [&]() { + using ElementC = cutlass::bfloat16_t; + scaled_int4_mm_cutlass_dispatch< + ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( + A, B, row_scale, col_scale, C); + } + ) + ); + + return C; +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass); + m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass); +} + +} // namespace torchao diff --git a/torchao/ops.py b/torchao/ops.py index f4b55c4951..840dbc0e97 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -22,6 +22,10 @@ lib.define( "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) +lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor") +lib.define( + "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor" +) def register_custom_op(name): @@ -615,3 +619,52 @@ def _( dtype=input_scale.dtype, device=input.device, ) + + +def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor: + """ + CUTLASS-based W4A4 matmul. + Args: + A: first INT4 tensor, packed in INT8 dtype, row-major layout. + B: second INT4 tensor, packed in INT8 dtype, column-major layout. + Returns: + output: result tensor, in row-major layout. 
+ """ + assert A.dtype == B.dtype == torch.int8 + assert A.ndim == B.ndim == 2 + assert A.shape[1] == B.shape[0] + assert A.is_contiguous() and B.T.is_contiguous() + return torch.ops.torchao.int4_mm_cutlass.default(A, B) + + +@register_custom_op("torchao::int4_mm_cutlass") +def _(A: Tensor, B: Tensor) -> Tensor: + return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32) + + +def scaled_int4_mm_cutlass( + A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor +) -> Tensor: + """ + CUTLASS-based W4A4 scaled-matmul. + Args: + A: first INT4 tensor, packed in INT8 dtype, row-major layout. + B: second INT4 tensor, packed in INT8 dtype, column-major layout. + row_scale: scaling for each output row. + col_scale: scaling for each output column. + Returns: + output: result tensor, in row-major layout. + """ + assert A.dtype == B.dtype == torch.int8 + assert A.ndim == B.ndim == 2 + assert A.shape[1] == B.shape[0] + assert A.is_contiguous() and B.T.is_contiguous() + assert row_scale.ndim == col_scale.ndim == 1 + assert row_scale.dtype == col_scale.dtype + assert row_scale.dtype in (torch.float16, torch.bfloat16) + return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale) + + +@register_custom_op("torchao::scaled_int4_mm_cutlass") +def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + return row_scale.new_empty(A.shape[0], B.shape[1]) From 7e277df04ab75ceeda1ec0efc71116c91e9c6083 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 Jan 2025 23:41:15 +0800 Subject: [PATCH 2/5] add test --- test/test_ops.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index c5821eed44..a7e05843ec 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -536,5 +536,63 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") 
+@pytest.mark.parametrize( + "M,N,K", + [(1, 256, 512), (18, 512, 256), (17, 256, 512)] +) +def test_int4_mm_cutlass(M, N, K): + A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") + B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") + actual = torchao.ops.int4_mm_cutlass(A, B.T) + + # NOTE: A >> 4 will perform sign-bit extension + unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) + unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) + expected = (unpacked_A.float() @ unpacked_B.float().T).to(torch.int32) + + torch.testing.assert_close(actual, expected) + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.int4_mm_cutlass, + (A, B), + test_utils=test_utils, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "M,N,K", + [(1, 256, 512), (18, 512, 256), (17, 256, 512)] +) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +def test_scaled_int4_mm_cutlass(M, N, K, dtype): + A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") + B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") + row_scale = torch.randn(M, dtype=dtype, device="cuda") + col_scale = torch.randn(N, dtype=dtype, device="cuda") + actual = torchao.ops.scaled_int4_mm_cutlass(A, B.T, row_scale, col_scale) + + # NOTE: A >> 4 will perform sign-bit extension + unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) + unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) + + expected = unpacked_A.float() @ unpacked_B.float().T + expected = expected * row_scale.view(-1, 1) * col_scale.view(1, -1) + expected = expected.to(dtype) + + torch.testing.assert_close(actual, expected) + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + 
opcheck( + torch.ops.torchao.scaled_int4_mm_cutlass, + (A, B, row_scale, col_scale), + test_utils=test_utils, + ) + + if __name__ == "__main__": run_tests() From a44df9e421d45826a6003a554c7a0718a1889b3e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 Jan 2025 17:40:52 +0700 Subject: [PATCH 3/5] hook up to AQT --- torchao/dtypes/affine_quantized_tensor_ops.py | 6 ++ .../uintx/cutlass_int4_packed_layout.py | 36 ++++++++++++ torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 56 +++++++++++++++++++ 4 files changed, 100 insertions(+) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 76df949852..f844a6a89a 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,6 +21,8 @@ _linear_int8_act_int8_weight_block_sparse_impl, ) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) @@ -151,6 +153,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ), + ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index a6412ec88c..018c2e2ad5 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -154,3 +154,39 @@ def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias) return out + + +def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + 
return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int4(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + ) + + +def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import scaled_int4_mm_cutlass + + weight = weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + batch_dims = input_tensor.shape[:-2] + input = input.view(-1, input.shape[-1]) + input_scale = input_scale.view(-1) + out = scaled_int4_mm_cutlass(input, weight.T, input_scale, weight_scale) + if bias is not None: + out = out + bias + out = out.view(*batch_dims, out.shape[-1]) + + return out diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index d0d29cf4be..49e75822b2 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -51,6 +51,7 @@ fpx_weight_only, gemlite_uintx_weight_only, int4_weight_only, + int4_dynamic_activation_int4_weight, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, int8_dynamic_activation_int8_weight, @@ -102,6 +103,7 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", + "int4_dynamic_activation_int4_weight", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int8_dynamic_activation_int8_semi_sparse_weight", diff --git a/torchao/quantization/quant_api.py 
b/torchao/quantization/quant_api.py index b2eff196fd..801911b01e 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -654,6 +654,62 @@ def int8_dynamic_activation_int4_weight( ) +def int4_dynamic_activation_int4_weight( + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear + + Args: + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + """ + + if not isinstance(layout, CutlassInt4PackedLayout): + raise NotImplementedError( + f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." 
+ ) + if mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only mapping_type=SYMMETRIC is supported.") + if act_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") + + def _int4_symm_per_token_quant_cutlass(x): + return to_affine_quantized_intx( + x, + mapping_type=act_mapping_type, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=1e-5, + _layout=layout, + ) + + def apply_int4_dynamic_activation_int4_weight_quant(weight): + weight = to_affine_quantized_intx( + weight, + mapping_type=mapping_type, + block_size=(1, weight.shape[1]), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=torch.finfo(torch.float32).eps, + _layout=layout, + ) + weight = to_linear_activation_quantized( + weight, + _int4_symm_per_token_quant_cutlass, + ) + return weight + + return _get_linear_subclass_inserter( + apply_int4_dynamic_activation_int4_weight_quant + ) + + def gemlite_uintx_weight_only( group_size: Optional[int] = 64, bit_width: int = 4, From de167f062ccf4d4b5f3f7385558587d4769cb469 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 Jan 2025 21:21:59 +0800 Subject: [PATCH 4/5] fix quant api test --- test/dtypes/test_affine_quantized.py | 2 + test/test_ops.py | 10 +-- .../uintx/cutlass_int4_packed_layout.py | 2 +- torchao/quantization/__init__.py | 2 +- torchao/quantization/quant_api.py | 84 +++++++++++-------- 5 files changed, 53 insertions(+), 47 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..eb79b05332 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -11,6 +11,7 @@ from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, 
int8_dynamic_activation_int8_weight, @@ -57,6 +58,7 @@ def get_quantization_functions( layout=CutlassInt4PackedLayout(), ) ) + base_functions.append(int4_dynamic_activation_int4_weight()) if do_sparse: base_functions.append( diff --git a/test/test_ops.py b/test/test_ops.py index 8d66bc516c..38797d8cab 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -608,10 +608,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "M,N,K", - [(1, 256, 512), (18, 512, 256), (17, 256, 512)] -) +@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)]) def test_int4_mm_cutlass(M, N, K): A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") @@ -634,10 +631,7 @@ def test_int4_mm_cutlass(M, N, K): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "M,N,K", - [(1, 256, 512), (18, 512, 256), (17, 256, 512)] -) +@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) def test_scaled_int4_mm_cutlass(M, N, K, dtype): A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index 018c2e2ad5..d7374c8d50 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -181,7 +181,7 @@ def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale - batch_dims = input_tensor.shape[:-2] + batch_dims = input_tensor.shape[:-1] input = input.view(-1, input.shape[-1]) input_scale = 
input_scale.view(-1) out = scaled_int4_mm_cutlass(input, weight.T, input_scale, weight_scale) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 49e75822b2..aa4a51d497 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -50,8 +50,8 @@ float8_weight_only, fpx_weight_only, gemlite_uintx_weight_only, - int4_weight_only, int4_dynamic_activation_int4_weight, + int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, int8_dynamic_activation_int8_weight, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 801911b01e..b209f28043 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -654,19 +654,12 @@ def int8_dynamic_activation_int4_weight( ) -def int4_dynamic_activation_int4_weight( +def apply_int4_dynamic_activation_int4_weight_quant( + weight: torch.Tensor, layout=CutlassInt4PackedLayout(), mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, ): - """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear - - Args: - `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now - `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric - `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric - """ - if not isinstance(layout, CutlassInt4PackedLayout): raise NotImplementedError( f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." 
@@ -676,37 +669,40 @@ def int4_dynamic_activation_int4_weight( if act_mapping_type != MappingType.SYMMETRIC: raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") - def _int4_symm_per_token_quant_cutlass(x): - return to_affine_quantized_intx( - x, - mapping_type=act_mapping_type, - block_size=_get_per_token_block_size(x), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - eps=1e-5, - _layout=layout, - ) + weight = to_affine_quantized_intx( + weight, + mapping_type=mapping_type, + block_size=(1, weight.shape[1]), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=torch.finfo(torch.float32).eps, + _layout=layout, + ) + weight = to_linear_activation_quantized( + weight, + _int4_symm_per_token_quant_cutlass, + ) + return weight - def apply_int4_dynamic_activation_int4_weight_quant(weight): - weight = to_affine_quantized_intx( - weight, - mapping_type=mapping_type, - block_size=(1, weight.shape[1]), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - eps=torch.finfo(torch.float32).eps, - _layout=layout, - ) - weight = to_linear_activation_quantized( - weight, - _int4_symm_per_token_quant_cutlass, - ) - return weight +def int4_dynamic_activation_int4_weight( + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear + + Args: + `layout`: layout type for quantized weight tensor, only supports `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls whether the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls whether the activation quantization is symmetric or asymmetric + """ return _get_linear_subclass_inserter( - apply_int4_dynamic_activation_int4_weight_quant + apply_int4_dynamic_activation_int4_weight_quant, + layout=layout, + 
mapping_type=mapping_type, + act_mapping_type=act_mapping_type, ) @@ -909,6 +905,19 @@ def _int8_symm_per_token_reduced_range_quant_cutlass( ) +def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: + return to_affine_quantized_intx( + x, + mapping_type=MappingType.SYMMETRIC, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=1e-5, + _layout=CutlassInt4PackedLayout(), + ) + + def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, @@ -1348,6 +1357,7 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, _int8_symm_per_token_reduced_range_quant_cutlass, + _int4_symm_per_token_quant_cutlass, _input_activation_quant_func_fp8, ] ) From fe1f0eb4af306a31175df8f6464473e829d6bb14 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 Jan 2025 22:17:11 +0800 Subject: [PATCH 5/5] fix test --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 38797d8cab..d99bc6b055 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -641,8 +641,8 @@ def test_scaled_int4_mm_cutlass(M, N, K, dtype): actual = torchao.ops.scaled_int4_mm_cutlass(A, B.T, row_scale, col_scale) # NOTE: A >> 4 will perform sign-bit extension - unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) - unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) + unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=2).reshape(M, K) + unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=2).reshape(N, K) expected = unpacked_A.float() @ unpacked_B.float().T expected = expected * row_scale.view(-1, 1) * col_scale.view(1, -1)