bf16 f16 sigmoid:affine binary ternary unary
haricot committed Jan 15, 2025
1 parent 1f5a344 commit 2b888da
Showing 7 changed files with 192 additions and 32 deletions.
1 change: 1 addition & 0 deletions candle-kernels/src/affine.cu
@@ -33,6 +33,7 @@ AFFINE_OP(__nv_bfloat16, affine_bf16)
#endif

#if __CUDA_ARCH__ >= 530
AFFINE_OP(__nv_bfloat16, affine_bf16)
AFFINE_OP(__half, affine_f16)
#endif

7 changes: 1 addition & 6 deletions candle-kernels/src/binary.cu
@@ -1,7 +1,7 @@
#include "binary_op_macros.cuh"
#include<stdint.h>

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 530
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
@@ -14,12 +14,7 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y)
BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y)
BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y)
BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
BINARY_OP(__half, bmul_f16, x * y)
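Note: lowering these guards from __CUDA_ARCH__ >= 800 to >= 530 assumes the CUDA toolkit in use defines __nv_bfloat16 arithmetic below sm_80. On hardware without native bf16 math, that arithmetic amounts to a round-trip through float; here is a minimal standalone sketch of the equivalent explicit form (not part of this commit — the kernel name and grid-stride loop are illustrative assumptions):

#include <cuda_bf16.h>
#include <stdint.h>

// Illustrative only: explicit-conversion bf16 add, the portable equivalent of
// BINARY_OP(__nv_bfloat16, badd_bf16, x + y) on parts without bf16 ALUs.
extern "C" __global__ void badd_bf16_demo(const __nv_bfloat16 *lhs,
                                          const __nv_bfloat16 *rhs,
                                          __nv_bfloat16 *out,
                                          const size_t n) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += blockDim.x * gridDim.x) {
        // Widen to float, add, narrow back to bf16.
        float x = __bfloat162float(lhs[i]);
        float y = __bfloat162float(rhs[i]);
        out[i] = __float2bfloat16(x + y);
    }
}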
1 change: 1 addition & 0 deletions candle-kernels/src/cuda_utils.cuh
@@ -160,6 +160,7 @@ __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
4 changes: 4 additions & 0 deletions candle-kernels/src/reduce.cu
@@ -585,6 +585,10 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)

//LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)

SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
4 changes: 1 addition & 3 deletions candle-kernels/src/ternary.cu
@@ -32,13 +32,11 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 530
WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
WHERE_OP(__half, int64_t, where_i64_f16)
WHERE_OP(__half, uint32_t, where_u32_f16)
WHERE_OP(__half, uint8_t, where_u8_f16)
182 changes: 160 additions & 22 deletions candle-kernels/src/unary.cu
@@ -61,8 +61,68 @@ __device__ __forceinline__ T silu_fwd(T x) {
}

template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x) {
return recipg(static_cast<T>(1) + expg(-x));
__device__ __forceinline__ T sigmoid_fwd(T x);

__device__ __forceinline__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) {
// Convert to double for maximum mantissa precision
double x_double = static_cast<double>(__bfloat162float(x));

// Compute exp in double precision to preserve mantissa bits
double exp_result = exp(x_double);

// Careful conversion back to preserve significant bits
return __float2bfloat16(static_cast<float>(exp_result));
}

__device__ __forceinline__ __half exp_halft(__half x) {
// Convert to double for maximum mantissa precision
double x_double = static_cast<double>(__half2float(x));

// Compute exp in double precision to preserve mantissa bits
double exp_result = exp(x_double);

// Careful conversion back to half
return __float2half(static_cast<float>(exp_result));
}

template<>
__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) {
#if __CUDA_ARCH__ >= 800
return recipg(__float2bfloat16(1.0f) + expg(-x)); // 1 / (1 + exp(-x)) using native bf16 ops on sm_80+
#elif __CUDA_ARCH__ >= 530
__nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x));
__nv_bfloat16 one = __float2bfloat16(1.0f);
return recipg(one + exp_neg_x);
#else
// Fallback using float computation
float x_float = __bfloat162float(x);
float result = 1.0f / (1.0f + expf(-x_float));
return __float2bfloat16(result);
#endif
}

template<>
__device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) {
#if __CUDA_ARCH__ >= 530
__half exp_neg_x = exp_halft(__hneg(x));
__half one = __float2half(1.0f);
return recipg(one + exp_neg_x);
#else
// Fallback using float computation
float x_float = __half2float(x);
float result = 1.0f / (1.0f + expf(-x_float));
return __float2half(result);
#endif
}

template<>
__device__ __forceinline__ float sigmoid_fwd<float>(float x) {
return 1.0f / (1.0f + expf(-x));
}

template<>
__device__ __forceinline__ double sigmoid_fwd<double>(double x) {
return 1.0 / (1.0 + exp(-x));
}

#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
@@ -98,11 +158,102 @@ __device__ T sign_(T t) {


#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))


#elif __CUDA_ARCH__ >= 530
__device__ __forceinline__ __nv_bfloat16 silu_fwd_fallback(__nv_bfloat16 x) {
// Compute silu (x * sigmoid(x)) in float and narrow back to bf16.
float x_float = __bfloat162float(x);
float exp_neg_x = expf(-x_float);

float sigmoid = 1.0f / (1.0f + exp_neg_x);
float result = x_float * sigmoid;
return __float2bfloat16(result);
}

__device__ __nv_bfloat16 gelu_fwd_fallback(__nv_bfloat16 x) {
// Tanh-approximation GELU in float: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
float x_float = __bfloat162float(x);
float inner = 0.7978845608028654f * (x_float + 0.044715f * x_float * x_float * x_float);
float result = 0.5f * x_float * (1.0f + tanhf(inner));
return __float2bfloat16(result);
}

__device__ __nv_bfloat16 gelu_erf_fwd_fallback(__nv_bfloat16 x) {
// Convert to float for computation on older architectures
float x_float = __bfloat162float(x);
float result = x_float * normcdfg(x_float);
return __float2bfloat16(result);
}

__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd_fallback(__nv_bfloat16 x) {
// Fallback using float32 computation
float x_float = __bfloat162float(x);
float result = 1.0f / (1.0f + expf(-x_float));
return __float2bfloat16(result);
}

__device__ __forceinline__ __nv_bfloat16 elu_fwd_fallback(__nv_bfloat16 x, __nv_bfloat16 alpha) {
// Fallback implementation using float32
float x_float = __bfloat162float(x);
float alpha_float = __bfloat162float(alpha);

if (x_float > 0.0f) {
return x;
}

float result = alpha_float * (expf(x_float) - 1.0f);
return __float2bfloat16(result);
}

__device__ __forceinline__ __nv_bfloat16 logg_fallback(__nv_bfloat16 x) {
float x_float = __bfloat162float(x);
// Uses |x|, so negative inputs produce log(|x|) rather than NaN.
float result = logf(fabsf(x_float));
return __float2bfloat16(result);
}

__device__ __forceinline__ __nv_bfloat16 negate_bfloat16(__nv_bfloat16 x) {
union {
__nv_bfloat16 bf16_val;
unsigned short bits;
} u;
u.bf16_val = x;
// Flip the sign bit
u.bits ^= 0x8000;
return u.bf16_val;
}

__device__ __forceinline__ __nv_bfloat16 abs_bfloat16(__nv_bfloat16 x) {
union {
__nv_bfloat16 bf16_val;
unsigned short bits;
} u;
u.bf16_val = x;
// Clear the sign bit (most significant bit)
u.bits &= 0x7FFF;
return u.bf16_val;
}

__device__ __forceinline__ __nv_bfloat16 sqrt_bfloat16_fallback(__nv_bfloat16 x) {
float x_float = __bfloat162float(x);

// Take absolute value before computing square root
float abs_x = fabsf(x_float);
float result = sqrtf(abs_x);

return __float2bfloat16(result);
}

UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg_fallback(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
@@ -113,29 +264,16 @@ UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x))
UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrt_bfloat16_fallback(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd_fallback(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd_fallback(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, __float2bfloat16(relu_fwd(__bfloat162float(x))))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd_fallback(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
template <typename T>
__device__ __forceinline__ T silu_fwd_fallback(T x) {
const T one = T(1.0f);
const T neg_x = -x;
const T exp_neg_x = expg(neg_x);
return x / (one + exp_neg_x);
}

UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, urecip_f16, recipg(x))
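The bf16 fallback helpers above share one pattern: widen to float, evaluate the function in float, then narrow back to bf16. A self-contained sketch of that pattern applied to sigmoid, with a hypothetical kernel name that is not taken from this diff:

#include <cuda_bf16.h>
#include <stdint.h>

// Hypothetical standalone kernel mirroring the convert-compute-convert
// fallback used by the bf16 unary ops above.
extern "C" __global__ void usigmoid_bf16_demo(const __nv_bfloat16 *in,
                                              __nv_bfloat16 *out,
                                              const size_t n) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += blockDim.x * gridDim.x) {
        float x = __bfloat162float(in[i]);   // widen to float
        float y = 1.0f / (1.0f + expf(-x));  // sigmoid evaluated in float
        out[i] = __float2bfloat16(y);        // narrow back to bf16
    }
}

Because bf16 keeps only 8 significant bits, evaluating the transcendental in float already carries far more precision than the result can represent, so the double-precision detour in exp_bf16 above is unlikely to change the rounded bf16 output.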
25 changes: 24 additions & 1 deletion candle-nn/tests/ops.rs
Expand Up @@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, DType, Result, Tensor};

fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@@ -249,6 +249,27 @@ fn sigmoid(device: &Device) -> Result<()> {
Ok(())
}

fn sigmoid_f16(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?.to_dtype(DType::F16)?;
let s1 = candle_nn::ops::sigmoid(&tensor)?;
let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<half::f16>()?;
assert_eq!(diff, half::f16::from_f32(0.));
Ok(())
}

fn sigmoid_bf16(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?.to_dtype(DType::BF16)?;
let s1 = candle_nn::ops::sigmoid(&tensor)?;
let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<half::bf16>()?;
assert_eq!(diff, half::bf16::from_f32(0.));
Ok(())
}


test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
@@ -258,3 +279,5 @@ test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
test_device!(sigmoid_f16, sigmoid_f16_cpu, sigmoid_f16_gpu, sigmoid_f16_metal);
test_device!(sigmoid_bf16, sigmoid_bf16_cpu, sigmoid_bf16_gpu, sigmoid_bf16_metal);
