Skip to content

Commit

Permalink
Add CUDA fallback support for bf16 cast for fallback (#62)
Browse files Browse the repository at this point in the history
* Add cast for fallback

* Only cast for now
  • Loading branch information
EricLBuehler authored Jan 8, 2025
1 parent ea0fa95 commit bac2055
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
42 changes: 42 additions & 0 deletions candle-kernels/src/cast.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "cuda_utils.cuh"
#include<stdint.h>
#include "dummy_bf16.cuh"

template <typename S, typename T>
__device__ void cast_(
Expand Down Expand Up @@ -96,6 +97,28 @@ __device__ void cast_through(
}
}

template <typename T>
__device__ void cast_bf16_dummy(
const size_t numel,
const size_t num_dims,
const size_t *info,
const uint16_t *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(bf16_to_f32(inp[i]));
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = static_cast<T>(bf16_to_f32(inp[strided_i]));
}
}
}

#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
Expand Down Expand Up @@ -143,6 +166,17 @@ extern "C" __global__ void FN_NAME( \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const uint16_t *inp, \
DST_TYPENAME *out \
) { \
cast_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"
#include "cuda_bf16.h"
Expand Down Expand Up @@ -197,6 +231,14 @@ CAST_OP(float, __half, cast_f32_f16)
CAST_OP(double, __half, cast_f64_f16)
CAST_OP(int32_t, __half, cast_i32_f16 )
CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)

#if __CUDA_ARCH__ < 800
CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
CAST_BF16_FALLBACK_OP(float, cast_bf16_f32)
CAST_BF16_FALLBACK_OP(double, cast_bf16_f64)
CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16)
CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)
#endif
#endif

CAST_OP(uint32_t, uint32_t, cast_u32_u32)
Expand Down
56 changes: 56 additions & 0 deletions candle-kernels/src/dummy_bf16.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include<stdint.h>

__device__ __forceinline__ float bf16_to_f32(const uint16_t i)
{
// If NaN, keep current mantissa but also set most significant mantissa bit
if ((i & 0x7FFFu) > 0x7F80u) {
// NaN path
uint32_t tmp = ((static_cast<uint32_t>(i) | 0x0040u) << 16);
union {
uint32_t as_int;
float as_float;
} u;
u.as_int = tmp;
return u.as_float;
// Alternatively:
// return __int_as_float(((static_cast<uint32_t>(i) | 0x0040u) << 16));
} else {
// Normal path
uint32_t tmp = (static_cast<uint32_t>(i) << 16);
union {
uint32_t as_int;
float as_float;
} u;
u.as_int = tmp;
return u.as_float;
// Alternatively:
// return __int_as_float(static_cast<uint32_t>(i) << 16);
}
}

// Convert FP32 (float) to BF16 (unsigned short)
__device__ __forceinline__ uint16_t f32_to_bf16(const float value)
{
// Reinterpret float bits as uint32_t
union {
float as_float;
uint32_t as_int;
} u;
u.as_float = value;
uint32_t x = u.as_int;

// Check for NaN
if ((x & 0x7FFF'FFFFu) > 0x7F80'0000u) {
// Keep high part of current mantissa but also set most significant mantissa bit
return static_cast<uint16_t>((x >> 16) | 0x0040u);
}

// Round and shift
constexpr uint32_t round_bit = 0x0000'8000u; // bit 15
if (((x & round_bit) != 0) && ((x & (3 * round_bit - 1)) != 0)) {
// Round half to even (or to odd) depends on your preference
return static_cast<uint16_t>((x >> 16) + 1);
} else {
return static_cast<uint16_t>(x >> 16);
}
}

0 comments on commit bac2055

Please sign in to comment.