Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CUTLASS-based W4A4 #1515

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
Expand Down Expand Up @@ -57,6 +58,7 @@ def get_quantization_functions(
layout=CutlassInt4PackedLayout(),
)
)
base_functions.append(int4_dynamic_activation_int4_weight())

if do_sparse:
base_functions.append(
Expand Down
52 changes: 52 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,5 +607,57 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)])
def test_int4_mm_cutlass(M, N, K):
A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda")
actual = torchao.ops.int4_mm_cutlass(A, B.T)

# NOTE: A >> 4 will perform sign-bit extension
unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K)
unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K)
expected = (unpacked_A.float() @ unpacked_B.float().T).to(torch.int32)

torch.testing.assert_close(actual, expected)

# Performs opcheck
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
opcheck(
torch.ops.torchao.int4_mm_cutlass,
(A, B),
test_utils=test_utils,
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
def test_scaled_int4_mm_cutlass(M, N, K, dtype):
A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda")
row_scale = torch.randn(M, dtype=dtype, device="cuda")
col_scale = torch.randn(N, dtype=dtype, device="cuda")
actual = torchao.ops.scaled_int4_mm_cutlass(A, B.T, row_scale, col_scale)

# NOTE: A >> 4 will perform sign-bit extension
unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=2).reshape(M, K)
unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=2).reshape(N, K)

expected = unpacked_A.float() @ unpacked_B.float().T
expected = expected * row_scale.view(-1, 1) * col_scale.view(1, -1)
expected = expected.to(dtype)

torch.testing.assert_close(actual, expected)

# Performs opcheck
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
opcheck(
torch.ops.torchao.scaled_int4_mm_cutlass,
(A, B, row_scale, col_scale),
test_utils=test_utils,
)


if __name__ == "__main__":
run_tests()
231 changes: 231 additions & 0 deletions torchao/csrc/cuda/int4_cutlass.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

// copied from s8s4_linear_cutlass.cu
#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
#define BUILD_INT4_MM_CUTLASS
#endif

#if defined(BUILD_INT4_MM_CUTLASS)
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#define CUTLASS_STATUS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
__func__, " : Got CUTLASS error: ", \
cutlassGetStatusString(status)); \
}
#endif

namespace torchao {

#if defined(BUILD_INT4_MM_CUTLASS)
// define common params
using ElementA = cutlass::int4b_t;
using ElementB = cutlass::int4b_t;
using ElementAccumulator = int32_t;
using OpClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;

// how many elements to load at a time -> load 128-bit = 32 x 4-bit
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
#endif

// we will do input checks in python. A and B are stored as int8
torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) {
#if defined(BUILD_INT4_MM_CUTLASS)
int M = A.size(0);
int K = A.size(1) * 2;
int N = B.size(1);
torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32));

// some configs for int4 mma
// https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
// using default config. this can be tuned.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
// static int const kStages = 3;
using ElementC = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if the universal gemm api can be used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look into it. I wrote this quite some time ago...

ElementA, cutlass::layout::RowMajor, // A matrix
ElementB, cutlass::layout::ColumnMajor, // B matrix
ElementC, cutlass::layout::RowMajor, // C matrix
ElementAccumulator, OpClass, ArchTag,
ThreadblockShape, WarpShape, InstructionShape
>;
Gemm::Arguments args {
{M, N, K},
{reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()), K},
{reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()), K},
{C.data_ptr<ElementC>(), N},
{C.data_ptr<ElementC>(), N},
{1, 0} // epilogue
};
Gemm gemm_op;
CUTLASS_STATUS_CHECK(gemm_op(args));
return C;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
return at::Tensor{};
#endif
}

template<
typename ElementC,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
int numStages>
void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) {
// problem shape
int M = A.size(0);
int K = A.size(1) * 2;
int N = B.size(1);

constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // 8 for BF16/FP16
using ElementEpilogue = float;
constexpr int numEpilogueStages = 1;

// build epilogue visitor tree
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages
>;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest;
using Multiply = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode
>;

// (1, N)
using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<cute::_0, cute::_1, int32_t> // MNL
>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, Accum, ColScale>;

// (M, 1)
using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<cute::_1, cute::_0, int32_t> // MNL
>;
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, EVTCompute0, RowScale>;

using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementC, RoundMode,
cute::Stride<int64_t, cute::_1, int64_t> // MNL
>;
using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<Output, EVTCompute1>;

using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, cutlass::layout::RowMajor, AlignmentC,
ElementAccumulator, ElementEpilogue, OpClass, ArchTag,
ThreadblockShape, WarpShape, InstructionShape,
EVTOutput,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
numStages,
cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work
numEpilogueStages
>::GemmKernel;
using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;

// col_scale, row_scale, and C must have the same dtype
const ElementA *A_ptr = reinterpret_cast<ElementA *>(A.data_ptr<int8_t>());
const ElementB *B_ptr = reinterpret_cast<ElementB *>(B.data_ptr<int8_t>());
const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr());
const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr());
ElementC *C_ptr = reinterpret_cast<ElementC *>(C.data_ptr());

typename EVTOutput::Arguments callback_args{
{
{
{}, // Accum
{col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale
{} // Multiply
}, // EVTCompute0
{row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale
{} // Multiply
}, // EVTCompute1
{C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput
};

typename DeviceGemm::Arguments args(
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmCoord{M, N, K},
1, // batch_split
callback_args,
A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr
M * K, N * K, 0, 0, // batch_stride A, B, C, D
K, K, 0, 0 // stride A, B, C, D
);

DeviceGemm gemm_op;
auto stream = at::cuda::getCurrentCUDAStream();
CUTLASS_STATUS_CHECK(gemm_op.can_implement(args));
CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream));
}

// we will do input checks in python. A and B are stored as int8
// this function is based on the following cutlass example
// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
// also with the help of emitted code from cutlass Python
torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) {
#if defined(BUILD_INT4_MM_CUTLASS)
int M = A.size(0);
int N = B.size(1);
torch::Tensor C = torch::empty({M, N}, row_scale.options());

// some configs for int4 mma
// https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
// using default config. this can be tuned.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
constexpr int numStages = 3;

AT_DISPATCH_SWITCH(
drisspg marked this conversation as resolved.
Show resolved Hide resolved
row_scale.scalar_type(),
"scaled_int4_mm_cutlass",
AT_DISPATCH_CASE(
torch::ScalarType::Half,
[&]() {
using ElementC = cutlass::half_t;
scaled_int4_mm_cutlass_dispatch<
ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
A, B, row_scale, col_scale, C);
}
)
AT_DISPATCH_CASE(
torch::ScalarType::BFloat16,
[&]() {
using ElementC = cutlass::bfloat16_t;
scaled_int4_mm_cutlass_dispatch<
ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
A, B, row_scale, col_scale, C);
}
)
);

return C;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass);
m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass);
}

} // namespace torchao
6 changes: 6 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
_linear_int4_act_int4_weight_cutlass_check,
_linear_int4_act_int4_weight_cutlass_impl,
_linear_int8_act_int4_weight_cutlass_check,
_linear_int8_act_int4_weight_cutlass_impl,
)
Expand Down Expand Up @@ -151,6 +153,10 @@ def _register_aqt_quantized_linear_dispatches():
_linear_int8_act_int4_weight_cutlass_check,
_linear_int8_act_int4_weight_cutlass_impl,
),
(
_linear_int4_act_int4_weight_cutlass_check,
_linear_int4_act_int4_weight_cutlass_impl,
),
]:
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

Expand Down
36 changes: 36 additions & 0 deletions torchao/dtypes/uintx/cutlass_int4_packed_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,39 @@ def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias)
out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias)

return out


def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int4(input_tensor)
and input_tensor.dtype in (torch.float16, torch.bfloat16)
and len(input_tensor.shape) >= 2
and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype
and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
and isinstance(weight_tensor, AffineQuantizedTensor)
and _aqt_is_int4(weight_tensor)
and weight_tensor.dtype == input_tensor.dtype
and len(weight_tensor.shape) == 2
and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype
and len(weight_tensor.tensor_impl.scale.shape) == 1
)


def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
from torchao.ops import scaled_int4_mm_cutlass

weight = weight_tensor.tensor_impl.int_data
weight_scale = weight_tensor.tensor_impl.scale
input = input_tensor.tensor_impl.int_data
input_scale = input_tensor.tensor_impl.scale

batch_dims = input_tensor.shape[:-1]
input = input.view(-1, input.shape[-1])
input_scale = input_scale.view(-1)
out = scaled_int4_mm_cutlass(input, weight.T, input_scale, weight_scale)
if bias is not None:
out = out + bias
out = out.view(*batch_dims, out.shape[-1])

return out
Loading
Loading