add cuda fallback bf16 for compute_cap < 8.0
haricot committed Jan 7, 2025
1 parent cd63913 commit 1f5a344
Showing 7 changed files with 40 additions and 16 deletions.
3 changes: 3 additions & 0 deletions candle-kernels/src/binary.cu
@@ -17,6 +17,9 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
BINARY_OP(__half, bmul_f16, x * y)
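
For context: recent cuda_bf16.h headers only lower __nv_bfloat16 arithmetic to native instructions on sm_80 and newer; below that, the overloaded operators fall back to f32 conversion, which is what makes instantiating these macros under __CUDA_ARCH__ >= 530 viable. A minimal sketch of what an op like badd_bf16 effectively does on a pre-sm_80 device (the kernel name and shape are illustrative, not from this commit):

#include <cuda_bf16.h>

// Illustrative elementwise bf16 add; mirrors what BINARY_OP expands to.
__global__ void badd_bf16_demo(const __nv_bfloat16 *x, const __nv_bfloat16 *y,
                               __nv_bfloat16 *out, size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
#if __CUDA_ARCH__ >= 800
    out[i] = x[i] + y[i];  // native bf16 add
#else
    // fallback path: widen to f32, add, narrow back to bf16
    out[i] = __float2bfloat16(__bfloat162float(x[i]) + __bfloat162float(y[i]));
#endif
}
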
30 changes: 15 additions & 15 deletions candle-kernels/src/compatibility.cuh
@@ -40,21 +40,21 @@ __device__ double atomicAdd(double* address, double val) {
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
-// unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
-// unsigned int old = *address_as_ui;
-// unsigned int assumed;
-// bool unaligned = (size_t) address & 2;
-// do {
-//    assumed = old;
-//    unsigned int hsum;
-//    hsum = unaligned ? (old >> 16) : (old & 0xffff);
-//    hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
-//    old = atomicCAS(address_as_ui, assumed,
-//        unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
-//    );
-
-// } while (assumed != old);
-// return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+unsigned int old = *address_as_ui;
+unsigned int assumed;
+bool unaligned = (size_t) address & 2;
+do {
+    assumed = old;
+    unsigned int hsum;
+    hsum = unaligned ? (old >> 16) : (old & 0xffff);
+    hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
+    old = atomicCAS(address_as_ui, assumed,
+        unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
+    );
+
+} while (assumed != old);
+return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
}
#endif

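
The activated fallback emulates a 16-bit atomic add by CAS-looping on the aligned 32-bit word containing the __half: the address is rounded down to 4-byte alignment, `unaligned` records whether the target half sits in the high or low 16 bits, and atomicCAS retries until the updated word goes in unchanged. A hedged usage sketch (not part of the commit) — below compute capability 7.0 this resolves to the CAS loop above, on 7.x+ to the hardware instruction:

#include <cuda_fp16.h>

// Illustrative: many threads accumulating into a single __half cell.
__global__ void sum_into_half(__half *acc, const __half *vals, size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        atomicAdd(acc, vals[i]);  // CAS fallback below sm_70, native otherwise
    }
}
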
2 changes: 2 additions & 0 deletions candle-kernels/src/cuda_utils.cuh
@@ -159,6 +159,8 @@ __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
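
Note that the bf16 maxg delegates to __hmax_nan, the NaN-propagating variant: __hmax returns the non-NaN operand when one input is NaN, while __hmax_nan returns NaN, which is the behavior the max reductions rely on. A small hedged demo of the difference (assumes a toolkit whose cuda_bf16.h provides both intrinsics on the targeted arch):

#include <cstdio>
#include <cuda_bf16.h>

__global__ void nan_max_demo() {
    __nv_bfloat16 a = __float2bfloat16(nanf(""));
    __nv_bfloat16 b = __float2bfloat16(2.0f);
    // __hmax skips the NaN and yields 2.0; __hmax_nan propagates it.
    printf("hmax=%f hmax_nan=%f\n",
           __bfloat162float(__hmax(a, b)),
           __bfloat162float(__hmax_nan(a, b)));
}

int main() {
    nan_max_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
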
2 changes: 2 additions & 0 deletions candle-kernels/src/fill.cu
@@ -37,8 +37,10 @@ COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
+#include <cuda_bf16.h>
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
+COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif

#if __CUDA_ARCH__ >= 800
2 changes: 2 additions & 0 deletions candle-kernels/src/indexing.cu
@@ -162,6 +162,8 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__half, int64_t, is_i64_f16)
IS_OP(__half, uint32_t, is_u32_f16)
IS_OP(__half, uint8_t, is_u8_f16)
4 changes: 4 additions & 0 deletions candle-kernels/src/reduce.cu
@@ -581,6 +581,10 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
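
The bf16 softmax instantiation passes float as the macro's accumulator type, matching the existing f16 lines: with only 8 mantissa bits, summing many bf16 terms directly in bf16 drifts quickly, while accumulating in f32 and rounding once at the end does not. A host-side illustration of that rounding drift (an assumption about the macro's intent, not code from the commit):

#include <cstdio>
#include <cuda_bf16.h>

int main() {
    float acc_f32 = 0.0f;
    float acc_bf16 = 0.0f;  // simulated bf16 accumulator, re-rounded each step
    for (int i = 0; i < 1000; ++i) {
        acc_f32 += 0.001f;
        acc_bf16 = __bfloat162float(__float2bfloat16(acc_bf16 + 0.001f));
    }
    // acc_f32 lands near 1.0; the per-step bf16 rounding drifts visibly.
    std::printf("f32: %f  bf16: %f\n", acc_f32, acc_bf16);
    return 0;
}
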
13 changes: 12 additions & 1 deletion candle-kernels/src/unary.cu
@@ -124,7 +124,18 @@ UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#endif

-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
template <typename T>
__device__ __forceinline__ T silu_fwd_fallback(T x) {
const T one = T(1.0f);
const T neg_x = -x;
const T exp_neg_x = expg(neg_x);
return x / (one + exp_neg_x);
}

UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, urecip_f16, recipg(x))
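
silu_fwd_fallback spells SiLU out as x / (1 + exp(-x)), algebraically the same as the usual x * sigmoid(x), using only negation, expg, and division so it instantiates cleanly for __nv_bfloat16 on pre-sm_80 targets. A quick hedged sanity check of the identity and of bf16 rounding (illustrative, not from the commit):

#include <cmath>
#include <cstdio>
#include <cuda_bf16.h>

int main() {
    for (float x = -4.0f; x <= 4.0f; x += 1.0f) {
        float ref = x / (1.0f + std::exp(-x));            // fallback form
        float alt = x * (1.0f / (1.0f + std::exp(-x)));   // x * sigmoid(x)
        float bf  = __bfloat162float(__float2bfloat16(ref));
        std::printf("x=% .1f ref=% .6f alt=% .6f bf16=% .6f\n", x, ref, alt, bf);
    }
    return 0;
}
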
