diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index 010f40979..ab60dfefd 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -1,5 +1,6 @@
 #include "cuda_utils.cuh"
 #include <stdint.h>
+#include "dummy_bf16.cuh"
 
 template <typename S, typename T>
 __device__ void cast_(
@@ -96,6 +97,28 @@ __device__ void cast_through(
     }
 }
 
+template <typename T>
+__device__ void cast_bf16_dummy(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const uint16_t *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = static_cast<T>(bf16_to_f32(inp[i]));
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = static_cast<T>(bf16_to_f32(inp[strided_i]));
+        }
+    }
+}
 
 #define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
@@ -143,6 +166,17 @@ extern "C" __global__ void FN_NAME( \
     cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
 } \
 
+#define CAST_BF16_FALLBACK_OP(DST_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const uint16_t *inp, \
+    DST_TYPENAME *out \
+) { \
+    cast_bf16_dummy<DST_TYPENAME>(numel, num_dims, info, inp, out); \
+} \
+
 #if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
@@ -197,6 +231,14 @@ CAST_OP(float, __half, cast_f32_f16)
 CAST_OP(double, __half, cast_f64_f16)
 CAST_OP(int32_t, __half, cast_i32_f16 )
 CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)
+
+#if __CUDA_ARCH__ < 800
+CAST_BF16_FALLBACK_OP(uint32_t, cast_bf16_u32)
+CAST_BF16_FALLBACK_OP(float, cast_bf16_f32)
+CAST_BF16_FALLBACK_OP(double, cast_bf16_f64)
+CAST_BF16_FALLBACK_OP(__half, cast_bf16_f16)
+CAST_BF16_FALLBACK_OP(int32_t, cast_bf16_i32)
+#endif
 #endif
 
 CAST_OP(uint32_t, uint32_t, cast_u32_u32)
diff --git a/candle-kernels/src/dummy_bf16.cuh b/candle-kernels/src/dummy_bf16.cuh
new file mode 100644
index 000000000..ea0f8b3e5
--- /dev/null
+++ b/candle-kernels/src/dummy_bf16.cuh
@@ -0,0 +1,56 @@
+#include <cstdint>
+
+__device__ __forceinline__ float bf16_to_f32(const uint16_t i)
+{
+    // If NaN, keep current mantissa but also set most significant mantissa bit
+    if ((i & 0x7FFFu) > 0x7F80u) {
+        // NaN path
+        uint32_t tmp = ((static_cast<uint32_t>(i) | 0x0040u) << 16);
+        union {
+            uint32_t as_int;
+            float as_float;
+        } u;
+        u.as_int = tmp;
+        return u.as_float;
+        // Alternatively:
+        // return __int_as_float(((static_cast<uint32_t>(i) | 0x0040u) << 16));
+    } else {
+        // Normal path
+        uint32_t tmp = (static_cast<uint32_t>(i) << 16);
+        union {
+            uint32_t as_int;
+            float as_float;
+        } u;
+        u.as_int = tmp;
+        return u.as_float;
+        // Alternatively:
+        // return __int_as_float(static_cast<uint32_t>(i) << 16);
+    }
+}
+
+// Convert FP32 (float) to BF16 (unsigned short)
+__device__ __forceinline__ uint16_t f32_to_bf16(const float value)
+{
+    // Reinterpret float bits as uint32_t
+    union {
+        float as_float;
+        uint32_t as_int;
+    } u;
+    u.as_float = value;
+    uint32_t x = u.as_int;
+
+    // Check for NaN
+    if ((x & 0x7FFF'FFFFu) > 0x7F80'0000u) {
+        // Keep high part of current mantissa but also set most significant mantissa bit
+        return static_cast<uint16_t>((x >> 16) | 0x0040u);
+    }
+
+    // Round and shift
+    constexpr uint32_t round_bit = 0x0000'8000u; // bit 15
+    if (((x & round_bit) != 0) && ((x & (3 * round_bit - 1)) != 0)) {
+        // Round half to even (or to odd, depending on preference)
+        return static_cast<uint16_t>((x >> 16) + 1);
+    } else {
+        return static_cast<uint16_t>(x >> 16);
+    }
+}
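For reference (not part of the patch): a minimal host-side sketch of the same truncate-and-round scheme the fallback helpers use, so the rounding condition `(x & round_bit) && (x & (3 * round_bit - 1))` can be sanity-checked off-device. It rounds up when the dropped low 16 bits exceed half an ULP, or equal exactly half an ULP while the kept LSB is odd (round half to even). Function names and sample values below are illustrative only and do not come from the candle codebase; NaN quieting is omitted from the bf16-to-f32 mirror for brevity.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-side mirror of f32_to_bf16: keep the high 16 bits of the float's
// bit pattern, rounding to nearest with ties to even.
static uint16_t f32_to_bf16_host(float value) {
    uint32_t x;
    std::memcpy(&x, &value, sizeof(x));        // reinterpret float bits
    if ((x & 0x7FFFFFFFu) > 0x7F800000u) {     // NaN: keep payload, force a quiet bit
        return static_cast<uint16_t>((x >> 16) | 0x0040u);
    }
    const uint32_t round_bit = 0x00008000u;    // bit 15, first dropped bit
    if ((x & round_bit) && (x & (3 * round_bit - 1))) {
        return static_cast<uint16_t>((x >> 16) + 1);  // round up
    }
    return static_cast<uint16_t>(x >> 16);            // truncate (tie went to even)
}

// Host-side mirror of bf16_to_f32 (NaN quieting omitted): shift the 16
// stored bits back into the high half of a float's bit pattern.
static float bf16_to_f32_host(uint16_t i) {
    uint32_t x = static_cast<uint32_t>(i) << 16;
    float f;
    std::memcpy(&f, &x, sizeof(f));
    return f;
}

int main() {
    const float samples[] = {1.0f, 3.140625f, 0.1f, 65504.0f};
    for (float s : samples) {
        uint16_t b = f32_to_bf16_host(s);
        std::printf("%-12g -> 0x%04x -> %g\n", s, b, bf16_to_f32_host(b));
    }
}
```

The device versions in `dummy_bf16.cuh` use a union instead of `memcpy` for the bit reinterpretation, but the resulting bf16 pattern is the same.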