diff --git a/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py new file mode 100644 index 0000000000..d547dd4127 --- /dev/null +++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py @@ -0,0 +1,68 @@ +import pandas as pd +import torch +from tqdm import tqdm + +from torchao.ops import ( + rowwise_scaled_linear_cutlass_s4s4, + rowwise_scaled_linear_cutlass_s8s4, +) +from torchao.utils import benchmark_torch_function_in_microseconds + + +def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): + assert A_nbits in (4, 8) and B_nbits in (4, 8) + + dev = torch.device("cuda") + A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) + A_scale = torch.randn((m,), dtype=torch.half, device=dev) + B = torch.randint( + -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev + ) + B_scale = torch.randn((n,), dtype=torch.half, device=dev) + C = None + + return A, A_scale, B, B_scale, C + + +def benchmark(m: int, k: int, n: int): + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((n, k), dtype=torch.half, device=dev) + fp16_time = benchmark_torch_function_in_microseconds( + torch.nn.functional.linear, A_ref, B_ref + ) + + A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) + rowwise_scaled_linear_cutlass_s8s4_time = benchmark_torch_function_in_microseconds( + rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C + ) + + A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) + rowwise_scaled_linear_cutlass_s4s4_time = benchmark_torch_function_in_microseconds( + rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C + ) + + return { + "m": m, + "k": k, + "n": n, + "fp16_latency (us)": fp16_time, + "rowwise_scaled_linear_cutlass_s8s4 latency (us)": rowwise_scaled_linear_cutlass_s8s4_time, + "speedup_s8s4 (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, + "rowwise_scaled_linear_cutlass_s4s4 latency (us)": rowwise_scaled_linear_cutlass_s4s4_time, + "speedup_s4s4 (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time, + } + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, k, n)) + + df = pd.DataFrame(results) + df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_s8s4_cutlass.py deleted file mode 100644 index fbf07ebb35..0000000000 --- a/benchmarks/benchmark_s8s4_cutlass.py +++ /dev/null @@ -1,52 +0,0 @@ -import pandas as pd -import torch -from tqdm import tqdm - -from torchao.ops import s8s4_linear_cutlass -from torchao.utils import benchmark_torch_function_in_microseconds - - -def get_problem(m, n, k): - dev = torch.device("cuda") - A_ref = torch.randn((m, k), dtype=torch.half, device=dev) - B_ref = torch.randn((k, n), dtype=torch.half, device=dev) - - A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) - A_scale = torch.randn((m,), dtype=torch.half, device=dev) - B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev) - B_scale = torch.randn((n,), dtype=torch.half, device=dev) - C = None - - return A_ref, B_ref, A, A_scale, B, B_scale, C - - -def benchmark(m: int, k: int, n: int): - A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k) - - fp16_time
= benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref) - s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds( - s8s4_linear_cutlass, A, A_scale, B, B_scale, C - ) - - return { - "m": m, - "k": k, - "n": n, - "fp16_latency (ms)": fp16_time, - "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time, - "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time, - } - - -if __name__ == "__main__": - k_vals = (8192, 8192, 8192, 28672) - n_vals = (8192, 10240, 57344, 8192) - - results = [] - for m in tqdm([1 << i for i in range(10)]): - for n, k in zip(n_vals, k_vals): - results.append(benchmark(m, k, n)) - - df = pd.DataFrame(results) - df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False) - print(df.to_markdown(index=False)) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 8be0652e9a..52b25dab82 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -11,6 +11,7 @@ from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -61,6 +62,7 @@ def get_quantization_functions( layout=CutlassInt4PackedLayout(), ) ) + base_functions.append(int4_dynamic_activation_int4_weight()) if do_sparse: base_functions.append( diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py new file mode 100644 index 0000000000..d6203ab9a4 --- /dev/null +++ b/test/test_rowwise_scaled_linear_cutlass.py @@ -0,0 +1,104 @@ +import itertools + +import pytest +import torch + +from torchao.ops import ( + rowwise_scaled_linear_cutlass_s4s4, + rowwise_scaled_linear_cutlass_s8s4, +) +from torchao.quantization.utils import group_quantize_tensor_symmetric + +ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] +ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] +ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True] +ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list( + itertools.product( + ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE, + ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE, + ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK, + ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS, + ) +) + + +def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias): + assert xq_bits in [4, 8] + assert wq_bits in [4, 8] + + size_m, size_n, size_k = size_mnk + + x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") + w = torch.rand((size_n, size_k), dtype=dtype, device="cuda") + bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + + x_2d = x.view(-1, x.shape[-1]) + xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric( + x_2d, xq_bits, size_k, dtype + ) + assert torch.all(xq_2d_zeros == 0) + xq_s8 = xq_2d_s8.reshape(x.shape) + if xq_bits == 4: + xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF) + else: + xq = xq_s8 + xq_scales = xq_2d_scales.reshape(x.shape[:-1]) + + wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric( + w, wq_bits, size_n, dtype + ) + assert torch.all(wq_zeros == 0) + if wq_bits == 4: + wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF) + else: + wq = wq_s8 + + # If torch.nn.functional.linear(x, w, 
bias) used as reference, the + # error would be too big. The calculation below is approximately + # what rowwise_scaled_linear_cutlass kernel is doing (except that + # matrix multiplication is over integers there). + size_m_2d = x_2d.shape[0] + output_ref = ( + (xq_2d_s8.float() @ wq_s8.float().T) + * xq_2d_scales.view(size_m_2d, 1) + * wq_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,)) + + fn_inputs = (xq, xq_scales, wq, wq_scales, bias) + try: + output = op(*fn_inputs) + except NotImplementedError: + pytest.xfail("operator not implemented") + + torch.testing.assert_close(output, output_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS +) +def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): + run_test_for_op( + rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS +) +def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): + run_test_for_op( + rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias + ) diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py deleted file mode 100644 index 6510adaea3..0000000000 --- a/test/test_s8s4_linear_cutlass.py +++ /dev/null @@ -1,77 +0,0 @@ -import itertools - -import pytest -import torch - -from torchao.ops import s8s4_linear_cutlass -from torchao.quantization.utils import group_quantize_tensor_symmetric -from torchao.utils import compute_max_diff - -S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] -S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -S8S4_LINEAR_CUTLASS_SIZE_MNK = [ - (2, 512, 128), - (3, 2048, 2048), - (4, 3584, 640), - (13, 8704, 8576), - (26, 18944, 1664), - (67, 6656, 1408), -] -S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True] -S8S4_LINEAR_CUTLASS_TEST_PARAMS = list( - itertools.product( - S8S4_LINEAR_CUTLASS_DTYPE, - S8S4_LINEAR_CUTLASS_BATCH_SIZE, - S8S4_LINEAR_CUTLASS_SIZE_MNK, - S8S4_LINEAR_CUTLASS_USE_BIAS, - ) -) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS -) -def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): - size_m, size_n, size_k = size_mnk - - input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") - weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") - bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None - - input_2d = input.view(-1, input.shape[-1]) - input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( - input_2d, 8, size_k, dtype - ) - assert torch.all(input_2d_zeros == 0) - input_s8 = input_2d_s8.reshape(input.shape) - input_scales = input_2d_scales.reshape(input.shape[:-1]) - - weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( - weight, 4, size_n, dtype - ) - assert torch.all(weight_zeros == 0) - weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF) - - # If torch.nn.functional.linear(input, weight, bias) used as - # reference, the error 
would be too big. The calculation below is - approximately what s8s4_linear_cutlass kernel is doing (except - that matrrix multiplication is over integers there)). - size_m_2d = input_2d.shape[0] - output_ref = ( - (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T) - * input_2d_scales.view(size_m_2d, 1) - * weight_scales.view(1, size_n) - ) - if bias is not None: - output_ref += bias - output_ref = output_ref.reshape(input.shape[:-1] + (size_n,)) - - fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) - try: - output = s8s4_linear_cutlass(*fn_inputs) - except NotImplementedError: - pytest.xfail("s8s4_linear_cutlass() op not implemented") - - max_diff = compute_max_diff(output, output_ref) - assert max_diff < 5e-3 diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md new file mode 100644 index 0000000000..7c36f7c7ed --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md @@ -0,0 +1,52 @@ +This directory is intended to contain implementations for all of the +CUTLASS-based row-wise scaled linear operators, for non-sparse inputs +of both same and mixed data types. + +The implementation uses a single kernel per SM generation, which +should reside in the `rowwise_scaled_linear_cutlass.cuh` file. At +the moment, only SM8.x architectures are supported, through the +`rowwise_scaled_linear_kernel_cutlass_sm8x` kernel, but SM9.x, and +eventually higher, can and will be supported too. + +The rest of the source files, besides the +`rowwise_scaled_linear_cutlass.cuh` file, contain just the +corresponding template instantiation and PyTorch operator declaration +for the given operator. + +In order to support a new combination of data types, copy one of the +existing `.cu` files, for example +`rowwise_scaled_linear_cutlass_s8s4.cu`, rename the new file, +as well as the operator defined inside, to reflect the data types to be +supported, and also change the `using ElementA` and `using ElementB` +directives accordingly. + +In the `.cuh` file, looking from the bottom up, the changes needed are as +follows: + +1. Optionally, in the `rowwise_scaled_linear_cutlass_check_inputs` +template, changes may be needed at the places where the last dimension +of the first operand is checked - this check only has to be updated +for inputs of mixed data types, where the wider data type is not +exactly twice as wide as the other data type. +2. In the `select_config` template, a section should be added to +choose optimal configuration(s) for your kernel. The configuration +selection is critical for the performance of any CUTLASS-based kernel, so +this is where most of the time should and will be spent when making +changes. +3. Optionally, in the `rowwise_scaled_linear_kernel_cutlass_sm8x` +template, the `using Operator` directive may need to be adjusted; namely, +for some combinations of operands, `OpMultiplyAdd` may have to be used. + +After making these changes, the test file +`test/test_rowwise_scaled_linear_cutlass.py` should be updated too - +add a test for the new operator, analogous to the existing tests. + +To restrict build times, the implementation in the `.cuh` file has some +restrictions at the moment, for example: scale tensors can only be +of `float16` or `bfloat16` data type, the output is produced in +the same data type as the first input scale tensor, scale tensors are not +optional while bias is optional, etc.
If any of these restrictions +should be removed, or if any alike changes are needed, or if support +for other architectures is needed, or if you need any kind of help in +extending this code to support other data type combinations - get in +touch with the developers. diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh similarity index 79% rename from torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu rename to torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh index 411343f0da..ab7cda07f6 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh @@ -1,4 +1,4 @@ -#include +#pragma once #include #include @@ -7,15 +7,16 @@ #if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) -#define BUILD_S8S4_LINEAR_CUTLASS +#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS #endif -#if defined(BUILD_S8S4_LINEAR_CUTLASS) +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) #include #include #include #include #include +#include #define CUTLASS_STATUS_CHECK(status) \ { \ @@ -27,41 +28,52 @@ namespace torchao { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) template< typename ThreadblockShape, typename WarpShape, typename InstructionShape, + typename ThreadblockSwizzle, int NumStages, typename ElementA, typename ElementB, - typename ElementAccumulator, - typename Operator, - typename ElementAScale, - typename ElementBScale, + typename ElementOutput, typename ElementC, typename UseTensorC, - typename ElementOutput> -void s8s4_linear_kernel_cutlass_sm8x( + typename ElementAScale, + typename ElementBScale> +void rowwise_scaled_linear_kernel_cutlass_sm8x( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { + static_assert((cutlass::sizeof_bits::value >= 8 || + 8 % cutlass::sizeof_bits::value == 0) && + (cutlass::sizeof_bits::value >= 8 || + 8 % cutlass::sizeof_bits::value == 0)); + using SmArch = cutlass::arch::Sm80; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutOutput = cutlass::layout::RowMajor; - using ElementEpilogue = float; + // TODO: use FP32 if either ElementA/B is FP + using ElementAccumulator = int32_t; + using Operator = + std::conditional_t::value, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAddMixedInputUpcast>; - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + using ElementEpilogue = float; constexpr auto NumEVTEpilogueStages = 1; const int m = tensor_a.size(0); const int n = tensor_b.size(0); - const int k = tensor_a.size(1); + int k = tensor_a.size(1); + if constexpr (cutlass::sizeof_bits::value < 8) { + k *= 8 / cutlass::sizeof_bits::value; + } constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentAScale = @@ -196,7 +208,7 @@ void s8s4_linear_kernel_cutlass_sm8x( NumEVTEpilogueStages >::GemmKernel; - using Gemm = cutlass::gemm::device::GemmUniversalBase; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; cutlass::gemm::GemmCoord problem_size(m, n, k); constexpr auto SplitKFactor = 1; @@ -242,7 +254,6 @@ void s8s4_linear_kernel_cutlass_sm8x( }, // EVTApplySum output_arguments // Output }; // 
EVTOutput - constexpr auto AvailSms = -1; typename Gemm::Arguments arguments( cutlass::gemm::GemmUniversalMode::kGemm, @@ -260,8 +271,8 @@ void s8s4_linear_kernel_cutlass_sm8x( problem_size.k(), // stride A problem_size.k(), // stride B 0, // stride C (unused) - 0, // stride D (unused) - AvailSms); + 0 // stride D (unused) + ); Gemm gemm_op; @@ -293,14 +304,31 @@ template static void select_config( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const at::Tensor& tensor_c, at::Tensor& tensor_d) { const auto dprops = at::cuda::getCurrentDeviceProperties(); const auto is_sm8x = dprops->major == 8; if (is_sm8x) { - if constexpr (std::is_same::value && + if constexpr (std::is_same::value && + std::is_same::value) { + // TODO: add some tuning + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + constexpr auto NumStages = 3; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if constexpr (std::is_same::value && std::is_same::value) { using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // A minimal heuristic to improve performance for small number // of inputs cases. @@ -308,27 +336,27 @@ static void select_config( using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } else if (tensor_a.size(0) <= 32) { using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } else { using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; constexpr auto NumStages = 4; - s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } @@ -341,36 +369,11 @@ static void select_config( dprops->minor, " for given operands"); } -template -static void -dispatch_on_tensor_a_and_tensor_b( - const at::Tensor& tensor_a, 
const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - if (tensor_a.scalar_type() == at::ScalarType::Char) { - if (tensor_b.scalar_type() == at::ScalarType::Char) { - if (tensor_a.size(1) == 2 * tensor_b.size(1)) { - using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; - select_config< - ElementA, ElementB, ElementAccumulator, Operator, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - return; - } - } - - TORCH_CHECK(false, - __func__, " : Operator not supported for combination of data ", - "types ", tensor_a.scalar_type(), " for first operand and ", - tensor_b.scalar_type(), " for second operand"); -} - - -template +template< + typename ElementA, + typename ElementB, + typename ElementOutput, + typename... Types> static void dispatch_on_tensor_c( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, @@ -379,8 +382,8 @@ dispatch_on_tensor_c( if (tensor_c.numel() == 0) { using ElementC = ElementOutput; using UseTensorC = std::false_type; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; @@ -389,15 +392,15 @@ dispatch_on_tensor_c( using UseTensorC = std::true_type; if (tensor_c.scalar_type() == at::ScalarType::Half) { using ElementC = cutlass::half_t; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { using ElementC = cutlass::bfloat16_t; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; @@ -408,6 +411,7 @@ dispatch_on_tensor_c( tensor_c.scalar_type(), " for addend"); } +template static void dispatch_on_tensor_a_scale_and_tensor_b_scale( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, @@ -423,7 +427,8 @@ dispatch_on_tensor_a_scale_and_tensor_b_scale( using ElementAScale = cutlass::half_t; using ElementBScale = cutlass::half_t; using ElementOutput = cutlass::half_t; - dispatch_on_tensor_c( + dispatch_on_tensor_c( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && @@ -431,7 +436,8 @@ dispatch_on_tensor_a_scale_and_tensor_b_scale( using ElementAScale = cutlass::bfloat16_t; using ElementBScale = cutlass::bfloat16_t; using ElementOutput = cutlass::bfloat16_t; - dispatch_on_tensor_c( + dispatch_on_tensor_c( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } @@ -443,8 +449,9 @@ dispatch_on_tensor_a_scale_and_tensor_b_scale( " for second operand scale"); } +template void -check_inputs( +rowwise_scaled_linear_cutlass_check_inputs( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, const at::Tensor& w_scale, const at::Tensor& bias) { // Validate layouts 
of arguments. @@ -483,9 +490,10 @@ check_inputs( // Validate sizes of arguments. const auto xq_sizes = xq.sizes().vec(); - TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1), - __func__, " : Expected xq argument to have ", 2 * wq.size(1), - " columns, but got ", xq_sizes.back()); + TORCH_CHECK(xq_sizes.back() == wq.size(1) || + xq_sizes.back() == 2 * wq.size(1), + __func__, " : Expected xq argument to have ", wq.size(1), " or ", + 2 * wq.size(1), " columns, but got ", xq_sizes.back()); const auto x_scale_sizes = x_scale.sizes().vec(); for (auto i = 0; i < x_scale_sizes.size(); ++i) TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], @@ -525,57 +533,48 @@ check_inputs( } #endif -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (xq * x_scale) @ (wq * w_scale).T + bias -// Notes: The "x_scale" tensor is expected to be a vector, of size -// equal to number of rows of "xq" tensor. The "w_scale" tensor is -// expected to be a vector, of size equal to number of rows of "wq" -// tensor. The "bias" tensor is expected to be a vector, of size equal -// to number of rows of "wq" tensor. +// Perform linear operation, using corresponding CUTLASS datatypes +// GEMM kernel, to given arguments - result produced is: +// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c +// +// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors. +// The "tensor_a_scale" tensor is expected to be a vector, of size +// equal to number of rows of "tensor_a" tensor. The "tensor_b_scale" +// tensor is expected to be a vector, of size equal to number of rows +// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a +// vector, of size equal to number of rows of "tensor_b" tensor. +template at::Tensor -s8s4_linear_cutlass( +rowwise_scaled_linear_cutlass( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, const at::Tensor& w_scale, const at::Tensor& bias) { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) // Check inputs. - check_inputs(xq, x_scale, wq, w_scale, bias); + rowwise_scaled_linear_cutlass_check_inputs( + xq, x_scale, wq, w_scale, bias); // Squash the input tensors as appropriate. const auto xq_sizes = xq.sizes().vec(); const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); - const auto x_scale_sizes = x_scale.sizes().vec(); const auto x_scale_1d = x_scale.reshape({-1}); const auto w_scale_1d = w_scale.reshape({-1}); - // Introduce alias names for arguments, according to the CUTLASS - // naming conventions. - const auto& tensor_a = xq_2d; - const auto& tensor_a_scale = x_scale_1d; - const auto& tensor_b = wq; - const auto& tensor_b_scale = w_scale_1d; - const auto& tensor_c = bias; - - // Create output tensor. - at::Tensor tensor_d = - tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); + // Create result tensor. + at::Tensor result = + x_scale.new_empty({xq_2d.size(0), wq.size(0)}); // Dispatch to appropriate kernel template. - dispatch_on_tensor_a_scale_and_tensor_b_scale( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + dispatch_on_tensor_a_scale_and_tensor_b_scale( + xq_2d, x_scale_1d, wq, w_scale_1d, bias, result); - // Reshape and return output tensor. - auto tensor_d_sizes = xq_sizes; - tensor_d_sizes.back() = wq.size(0); - return tensor_d.reshape(tensor_d_sizes); + // Reshape and return result tensor. 
+ auto result_sizes = xq_sizes; + result_sizes.back() = wq.size(0); + return result.reshape(result_sizes); #else TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); return at::Tensor{}; #endif } -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass); -} - } // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu new file mode 100644 index 0000000000..9a64b2bdfb --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -0,0 +1,28 @@ +#include + +#include "rowwise_scaled_linear_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_cutlass_s4s4( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate input datatypes. + TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", xq.dtype(), + " for xq and ", wq.dtype(), " for wq is not supported"); + + // Dispatch to appropriate kernel template. + using ElementA = cutlass::int4b_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4", + &rowwise_scaled_linear_cutlass_s4s4); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu new file mode 100644 index 0000000000..752c557e79 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -0,0 +1,28 @@ +#include + +#include "rowwise_scaled_linear_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_cutlass_s8s4( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate input datatypes. + TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", xq.dtype(), + " for xq and ", wq.dtype(), " for wq is not supported"); + + // Dispatch to appropriate kernel template. 
+ using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4", + &rowwise_scaled_linear_cutlass_s8s4); +} + +} // namespace torchao diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index ef8691699e..54f4a72811 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,6 +21,8 @@ _linear_int8_act_int8_weight_block_sparse_impl, ) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) @@ -155,6 +157,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ), + ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + ), ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index a6412ec88c..037ae1f3ad 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -144,13 +144,47 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import s8s4_linear_cutlass + from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 weight = weight_tensor.tensor_impl.int_data weight_scale = weight_tensor.tensor_impl.scale input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale - out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias) + out = rowwise_scaled_linear_cutlass_s8s4( + input, input_scale, weight, weight_scale, bias + ) + + return out + + +def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int4(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + ) + + +def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_cutlass_s4s4 + + weight = weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + out = rowwise_scaled_linear_cutlass_s4s4( + input, input_scale, weight, weight_scale, bias + ) return out diff --git a/torchao/ops.py b/torchao/ops.py index f4b55c4951..8b573876f2 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -20,7 +20,10 @@ "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int 
size_m, int size_n, int size_k) -> Tensor" ) lib.define( - "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" +) +lib.define( + "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) @@ -514,7 +517,7 @@ def _( return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) -def s8s4_linear_cutlass( +def rowwise_scaled_linear_cutlass_s8s4( input: Tensor, input_scale: Tensor, weight: Tensor, @@ -522,23 +525,23 @@ def s8s4_linear_cutlass( bias: Tensor, ) -> Tensor: """ - CUTLASS-based W4A8 linear operator. + CUTLASS-based row-wise scaled W4A8 linear operator. Args: - input: input tensor, quantized to 8-bit integer values. + input: quantized input tensor, in row-major layout. input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. - weight: weight matrix, quantized to 4-bit integer values, in row-major layout. + weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). bias: a vector of size equal to number of rows of weight tensor, or None. Returns: output: result tensor, in row-major layout. """ - return torch.ops.torchao.s8s4_linear_cutlass.default( + return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default( input, input_scale, weight, weight_scale, bias ) -@register_custom_op("torchao::s8s4_linear_cutlass") +@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4") def _( input: Tensor, input_scale: Tensor, @@ -546,72 +549,46 @@ def _( weight_scale: Tensor, bias: Tensor, ) -> Tensor: - # Validate dtypes. - torch._check( - input.dtype == torch.int8, - lambda: f"input dtype {input.dtype} instead of {torch.int8}", - ) - torch._check( - input_scale.dtype in (torch.float16, torch.bfloat16), - lambda: f"input_scale dtype {input_scale.dtype} instead of {torch.float16} or {torch.bfloat16}", - ) - torch._check( - weight.dtype == torch.int8, - lambda: f"weight dtype {weight.dtype} instead of {torch.int8}", - ) - torch._check( - weight_scale.dtype == input_scale.dtype, - lambda: f"weight_scale dtype {weight_scale.dtype} instead of {input_scale.dtype}", - ) - if bias is not None: - torch._check( - bias.dtype == input_scale.dtype, - lambda: f"bias dtype {weight_scale.dtype} instead of {input_scale.dtype}", - ) - - # Validate dims. - torch._check(input.dim() >= 2, lambda: f"input is {input.dim()}D instead of >=2D") - torch._check( - input_scale.dim() == input.dim() - 1, - lambda: f"input_scale is {input_scale.dim()}D instead of {input.dim() - 1}D", - ) - torch._check(weight.dim() == 2, lambda: f"weight is {weight.dim()}D instead of 2D") - torch._check( - weight_scale.dim() == 1 or weight_scale.dim() == 2, - lambda: f"weight_scale is {weight_scale.dim()}D instead of 1D or 2D", - ) - if bias is not None: - torch._check(bias.dim() == 1, lambda: f"bias is {bias.dim()}D instead of 1D") - - # Validate shapes. 
- torch._check( - input.shape[-1] == 2 * weight.shape[-1], - lambda: "input and weight shapes do not match for matrix product", - ) - for i in range(input_scale.dim()): - torch._check( - input_scale.shape[i] == input.shape[i], - lambda: f"input_scale and input shapes do not match at position {i}", - ) - torch._check( - weight_scale.numel() == weight.shape[0], - lambda: f"weight_scale has {weight_scale.numel()} elements instead of {weight.shape[0]}", - ) - if bias is not None: - torch._check( - bias.numel() == weight.shape[0], - lambda: f"bias has {bias.numel()} elements instead of {weight.shape[0]}", - ) - - # Validate strides (input, input_scales and weight_scales will be - # reshape()-d by the operator, so no need to check strides for - # them). - torch._check(weight.stride(-1) == 1, lambda: "weight is not in row-major layout") - if bias is not None: - torch._check(bias.is_contiguous(), lambda: "bias is not contiguous") + # No checks here, as detailed checks are performed by the + # operator itself. return torch.empty( (*input.shape[:-1], weight.shape[0]), dtype=input_scale.dtype, device=input.device, ) + + +def rowwise_scaled_linear_cutlass_s4s4( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + """ + CUTLASS-based row-wise scaled W4A4 linear operator. + Args: + input: quantized input tensor, in row-major layout. + input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. + weight: quantized weight matrix, in row-major layout. + weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). + bias: a vector of size equal to number of rows of weight tensor, or None. + Returns: + output: result tensor, in row-major layout. + """ + + return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default( + input, input_scale, weight, weight_scale, bias + ) + + +@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index d0d29cf4be..aa4a51d497 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -50,6 +50,7 @@ float8_weight_only, fpx_weight_only, gemlite_uintx_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, @@ -102,6 +103,7 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", + "int4_dynamic_activation_int4_weight", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int8_dynamic_activation_int8_semi_sparse_weight", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3a73b97ad1..95908a79eb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -654,6 +654,58 @@ def int8_dynamic_activation_int4_weight( ) +def apply_int4_dynamic_activation_int4_weight_quant( + weight: torch.Tensor, + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + if not isinstance(layout, CutlassInt4PackedLayout): + raise NotImplementedError( + f"Only CutlassInt4PackedLayout layout is supported. 
Received {layout}." ) + if mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only mapping_type=SYMMETRIC is supported.") + if act_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") + + weight = to_affine_quantized_intx( + weight, + mapping_type=mapping_type, + block_size=(1, weight.shape[1]), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=torch.finfo(torch.float32).eps, + _layout=layout, + ) + weight = to_linear_activation_quantized( + weight, + _int4_symm_per_token_quant_cutlass, + ) + return weight + + +def int4_dynamic_activation_int4_weight( + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + """Applies int4 dynamic per-token symmetric activation quantization and int4 per-row symmetric weight quantization to linear layers + + Args: + `layout`: layout type for quantized weight tensor, only supports `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls whether weight quantization is symmetric or asymmetric (only symmetric is supported for now) + `act_mapping_type`: quantization type for activation, controls whether activation quantization is symmetric or asymmetric (only symmetric is supported for now) + """ + return _get_linear_subclass_inserter( + apply_int4_dynamic_activation_int4_weight_quant, + layout=layout, + mapping_type=mapping_type, + act_mapping_type=act_mapping_type, + ) + + def gemlite_uintx_weight_only( group_size: Optional[int] = 64, bit_width: int = 4, @@ -855,6 +907,19 @@ def _int8_symm_per_token_reduced_range_quant_cutlass( ) +def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: + return to_affine_quantized_intx( + x, + mapping_type=MappingType.SYMMETRIC, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=1e-5, + _layout=CutlassInt4PackedLayout(), + ) + + def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, @@ -1294,6 +1359,7 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, _int8_symm_per_token_reduced_range_quant_cutlass, + _int4_symm_per_token_quant_cutlass, _input_activation_quant_func_fp8, ] )
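Usage note (reviewer commentary, not part of the diff): the new user-facing entry point added in `quant_api.py` is `int4_dynamic_activation_int4_weight`, applied through `quantize_`. A minimal sketch follows; it assumes a CUDA build of torchao with the CUTLASS kernels compiled in, and the layer sizes are made up for illustration.

```python
import torch
from torchao.quantization import int4_dynamic_activation_int4_weight, quantize_

# Toy half-precision model on CUDA; any nn.Linear whose weight satisfies the
# _linear_int4_act_int4_weight_cutlass_check conditions should be handled.
model = torch.nn.Sequential(torch.nn.Linear(1024, 2048, bias=False)).half().cuda()

# Replace the Linear weight with an int4 dynamic-activation / int4-weight
# affine-quantized tensor; the dispatch registered in
# affine_quantized_tensor_ops.py then routes the matmul to
# rowwise_scaled_linear_cutlass_s4s4.
quantize_(model, int4_dynamic_activation_int4_weight())

x = torch.randn(16, 1024, dtype=torch.half, device="cuda")
with torch.no_grad():
    y = model(x)
```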
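For completeness, here is a sketch of calling one of the new ops directly, mirroring the reference math in `test/test_rowwise_scaled_linear_cutlass.py` and the op docstring, i.e. `(xq * x_scale) @ (wq * w_scale).T + bias`. The shapes, random data, and small scale factors are illustrative only; the int4 nibble packing follows the `(high << 4) | (low & 0xF)` convention used by the test.

```python
import torch
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4

m, n, k = 32, 256, 512
# int8 activations with per-row scales; small scales keep fp16 outputs in range.
xq = torch.randint(-16, 16, (m, k), dtype=torch.int8, device="cuda")
x_scale = torch.randn(m, dtype=torch.half, device="cuda") * 0.01

# int4 weight values, packed two per int8 byte along the last dimension.
wq_int4 = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda")
wq = (wq_int4[:, 1::2] << 4) | (wq_int4[:, 0::2] & 0xF)
w_scale = torch.randn(n, dtype=torch.half, device="cuda") * 0.01

out = rowwise_scaled_linear_cutlass_s8s4(xq, x_scale, wq, w_scale, None)

# Reference, as in the test: integer matmul, then row-wise / column-wise scaling.
ref = (
    (xq.float() @ wq_int4.float().T)
    * x_scale.float().view(-1, 1)
    * w_scale.float().view(1, -1)
).half()
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-3)
```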