CUDA: use arch list for compatibility check #11775

Merged
2 changes: 0 additions & 2 deletions ggml/src/ggml-common.h
@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
GGML_TABLE_END()

//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
GGML_TABLE_END()
//#endif


GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
65 changes: 60 additions & 5 deletions ggml/src/ggml-cuda/common.cuh
@@ -71,6 +71,47 @@
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220

#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
return false;
}

template<class ... Archs>
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
}

constexpr bool ggml_cuda_has_arch(const int arch) {
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}

constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
if (cur == 0) {
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
}
return cur;
}

template<class ... Archs>
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
if (first <= arch && first > cur) {
return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
} else {
return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
}
}

constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
}
#else
static int ggml_cuda_highest_compiled_arch(const int arch) {
return arch;
}
#endif // __CUDA_ARCH_LIST__

// ---------------------------------------------------------------------------------------------------------

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

#if defined(_MSC_VER)
@@ -162,18 +203,32 @@ typedef float2 dfloat2;
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

static constexpr bool fast_fp16_available(const int cc) {
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
}

static bool fast_fp16_available(const int cc) {
return fp16_available(cc) && cc != 610;
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}

// Any FP16 tensor cores are available.
static constexpr bool fp16_mma_available(const int cc) {
// Any FP16 tensor core instructions are available for ggml code.
static bool fp16_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static constexpr bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
static bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
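For context on the new helpers in common.cuh: on toolkits new enough to provide it, nvcc defines __CUDA_ARCH_LIST__ as the comma-separated list of compute capabilities the build actually targets, and the variadic constexpr functions above simply fold over that list at compile time (with GGML_ABORT firing when no compiled arch is usable on the device). The snippet below is a minimal standalone sketch of that fold, not part of the patch; the arch list 610,750, the un-prefixed helper names, and main() are invented for illustration.

#include <cstdio>

// Base case: an empty arch list never matches.
constexpr bool has_arch_impl(int) { return false; }

// Recursive case: match the first arch or keep scanning the rest of the pack.
template <class... Archs>
constexpr bool has_arch_impl(int arch, int first, Archs... rest) {
    return arch == first || has_arch_impl(arch, rest...);
}

// Base case: return the best compiled arch seen so far (the real helper aborts here if cur == 0).
constexpr int highest_compiled_arch_impl(int /*arch*/, int cur) { return cur; }

// Recursive case: keep the largest compiled arch that does not exceed the device arch.
template <class... Archs>
constexpr int highest_compiled_arch_impl(int arch, int cur, int first, Archs... rest) {
    return (first <= arch && first > cur) ? highest_compiled_arch_impl(arch, first, rest...)
                                          : highest_compiled_arch_impl(arch, cur, rest...);
}

int main() {
    // Pretend the build produced kernels for 6.1 and 7.5, i.e. __CUDA_ARCH_LIST__ == 610,750.
    static_assert( has_arch_impl(750, 610, 750), "a 7.5 build exists");
    static_assert(!has_arch_impl(860, 610, 750), "no 8.6 build exists");
    // An 8.6 device falls back to the newest arch that was actually compiled: 7.5.
    static_assert(highest_compiled_arch_impl(860, 0, 610, 750) == 750, "fallback arch");
    std::printf("arch-list fold behaves as expected\n");
    return 0;
}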
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/convert.cu
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
12 changes: 6 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1867,14 +1867,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor

const int cc = ggml_cuda_info().devices[id].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
}

// debug helpers
@@ -3205,8 +3205,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
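The switch to the *_hardware_available variants in ggml_cuda_mul_mat follows the split introduced in common.cuh: the hardware variants describe the silicon and are meant for steering external libraries such as cuBLAS, while the plain variants additionally require that a suitable arch was compiled into this binary. A hypothetical trace, not part of the diff (device and arch list invented for illustration):

// Ada-class device (cc 890) running a binary whose ggml kernels were built only for
// Pascal, i.e. __CUDA_ARCH_LIST__ == 610:
//
//   ggml_cuda_highest_compiled_arch(890) == 610
//
//   fp16_mma_hardware_available(890) -> true    // 890 >= GGML_CUDA_CC_VOLTA: the GPU has
//                                               // tensor cores, so cuBLAS may use them
//   fp16_mma_available(890)          -> false   // best compiled arch (610) < VOLTA: ggml's
//                                               // own tensor-core paths were never compiled
//
// Using the hardware variants keeps the cuBLAS dispatch independent of the arch list, while
// the supports_op check above (seemingly the FlashAttention case, given the F16 K/V
// requirement) switches to fp16_mma_available because it gates ggml's own kernels.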
9 changes: 5 additions & 4 deletions ggml/src/ggml-cuda/mmq.cu
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);

int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
const int cc = ggml_cuda_info().devices[id].cc;

// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};

switch (src0->type) {
@@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return true;
}

if (cc < GGML_CUDA_CC_DP4A) {
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
return false;
}

@@ -145,7 +146,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
#endif //GGML_CUDA_FORCE_MMQ

if (cc < GGML_CUDA_CC_OFFSET_AMD) {
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
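A short trace of the updated ggml_cuda_should_use_mmq decision, again hypothetical and not part of the diff (the device, the arch list, and the assumptions that GGML_CUDA_FORCE_MMQ is undefined and no earlier early-return fires are all invented for illustration):

// Turing device (cc 750), binary built only for compute_61 (__CUDA_ARCH_LIST__ == 610):
//
//   ggml_cuda_highest_compiled_arch(750) == 610   // not below GGML_CUDA_CC_DP4A, so the
//                                                 // integer-intrinsic MMQ kernels do exist
//   fp16_mma_hardware_available(750)     -> true  // Turing has FP16 tensor cores
//
//   => ggml_cuda_should_use_mmq(type, 750, ne11) == (ne11 < MMQ_DP4A_MAX_BATCH_SIZE):
//      small batches keep the compiled DP4A kernels, larger batches fall back to the
//      FP16 path, which can still use the tensor cores through cuBLAS even though no
//      >= Volta arch was compiled.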
14 changes: 8 additions & 6 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -86,12 +86,13 @@ struct tile_x_sizes {
int sc;
};

static constexpr int get_mmq_x_max_host(const int cc) {
static int get_mmq_x_max_host(const int cc) {
return new_mma_available(cc) ? 128 :
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
#ifdef GGML_CUDA_FORCE_MMQ
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
128 : 64;
#else
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
MMQ_DP4A_MAX_BATCH_SIZE : 64;
#endif // GGML_CUDA_FORCE_MMQ
}

@@ -119,8 +120,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
#endif // NEW_MMA_AVAILABLE
}

static constexpr int get_mmq_y_host(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
static int get_mmq_y_host(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
}

static constexpr __device__ int get_mmq_y_device() {
Expand Down Expand Up @@ -2828,7 +2830,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;

int mmq_x_best = 0;
int nparts_best = INT_MAX;
Expand Down