From 2b888da8e7d85ccf5cca7c6a4c6cf5820c37b829 Mon Sep 17 00:00:00 2001
From: Nicolas <344493+haricot@users.noreply.github.com>
Date: Wed, 15 Jan 2025 16:03:04 +0100
Subject: [PATCH] bf16 f16 sigmoid: affine binary ternary unary

---
 candle-kernels/src/affine.cu      |   2 +-
 candle-kernels/src/binary.cu      |   7 +-
 candle-kernels/src/cuda_utils.cuh |   1 +
 candle-kernels/src/reduce.cu      |   3 +
 candle-kernels/src/ternary.cu     |   4 +-
 candle-kernels/src/unary.cu       | 200 ++++++++++++++++++++++++++-----
 candle-nn/tests/ops.rs            |  24 +++-
 7 files changed, 212 insertions(+), 29 deletions(-)

diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index 540d0819f5..6cd6f0c258 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -31,7 +31,7 @@
-#if __CUDA_ARCH__ >= 800
+#if __CUDA_ARCH__ >= 530
 AFFINE_OP(__nv_bfloat16, affine_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 AFFINE_OP(__half, affine_f16)
 #endif

diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index 7a0e74ef92..d9278200f7 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -1,7 +1,7 @@
 #include "binary_op_macros.cuh"
 #include <stdint.h>
 
-#if __CUDA_ARCH__ >= 800
+#if __CUDA_ARCH__ >= 530
 BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
 BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
 BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
@@ -14,12 +14,7 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y)
 BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y)
 BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y)
 BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
-#endif
 
-#if __CUDA_ARCH__ >= 530
-#include "cuda_bf16.h"
-BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
-BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
 BINARY_OP(__half, badd_f16, x + y)
 BINARY_OP(__half, bdiv_f16, x / y)
 BINARY_OP(__half, bmul_f16, x * y)

diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index 4ae23317b7..f893c1e727 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -160,6 +160,7 @@ __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
 __device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
 #if __CUDA_ARCH__ >= 530
 #include "cuda_bf16.h"
+__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
 __device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }

diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 9cfc6aed71..3f8b39941d 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -585,6 +585,9 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+
+//LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
+
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
 LAYERNORM_OP(__half, layernorm_f16)

diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index aaa8a881fb..049d384bac 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ extern "C" __global__ void FN_NAME( \ } \ } \ -#if __CUDA_ARCH__ >= 800 +#if __CUDA_ARCH__ >= 530 WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) -#endif -#if __CUDA_ARCH__ >= 530 WHERE_OP(__half, int64_t, where_i64_f16) WHERE_OP(__half, uint32_t, where_u32_f16) WHERE_OP(__half, uint8_t, where_u8_f16) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index fdf7310752..2784e4f725 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -61,8 +61,68 @@ __device__ __forceinline__ T silu_fwd(T x) { } template -__device__ __forceinline__ T sigmoid_fwd(T x) { - return recipg(static_cast(1) + expg(-x)); +__device__ __forceinline__ T sigmoid_fwd(T x); + +__device__ __forceinline__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) { + // Convert to double for maximum mantissa precision + double x_double = static_cast(__bfloat162float(x)); + + // Compute exp in double precision to preserve mantissa bits + double exp_result = exp(x_double); + + // Careful conversion back to preserve significant bits + return __float2bfloat16(static_cast(exp_result)); +} + +__device__ __forceinline__ __half exp_halft(__half x) { + // Convert to double for maximum mantissa precision + double x_double = static_cast(__half2float(x)); + + // Compute exp in double precision to preserve mantissa bits + double exp_result = exp(x_double); + + // Careful conversion back to half + return __float2half(static_cast(exp_result)); +} + +template<> +__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) { +#if __CUDA_ARCH__ >= 800 + return x / (static_cast(1) + expg(-x)); +#elif __CUDA_ARCH__ >= 530 + __nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x)); + __nv_bfloat16 one = __float2bfloat16(1.0f); + return recipg(one + exp_neg_x); +#else + // Fallback using float computation + float x_float = __bfloat162float(x); + float result = 1.0f / (1.0f + expf(-x_float)); + return __float2bfloat16(result); +#endif +} + +template<> +__device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) { +#if __CUDA_ARCH__ >= 530 + __half exp_neg_x = exp_halft(__hneg(x)); + __half one = __float2half(1.0f); + return recipg(one + exp_neg_x); +#else + // Fallback using float computation + float x_float = __half2float(x); + float result = 1.0f / (1.0f + expf(-x_float)); + return __float2half(result); +#endif +} + +template<> +__device__ __forceinline__ float sigmoid_fwd(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +template<> +__device__ __forceinline__ double sigmoid_fwd(double x) { + return 1.0 / (1.0 + exp(-x)); } #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \ @@ -98,11 +158,102 @@ __device__ T sign_(T t) { #if __CUDA_ARCH__ >= 800 +UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) +UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x)) +UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x)) +UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x)) +UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param)) +UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x)) +UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x)) + + +#elif __CUDA_ARCH__ >= 530 +template +__device__ __forceinline__ T silu_fwd_fallback(T x) { + float x_float = __bfloat162float(x); + float exp_neg_x = expf(-x_float); + + float sigmoid = 1.0f / (1.0f + exp_neg_x); + float result = x_float * sigmoid; +} + +__device__ __nv_bfloat16 gelu_fwd_fallback(__nv_bfloat16 x) { + __nv_bfloat16 half = __float2bfloat16(0.5f); + __nv_bfloat16 one = 
+template<>
+__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd<__nv_bfloat16>(__nv_bfloat16 x) {
+#if __CUDA_ARCH__ >= 800
+    return recipg(static_cast<__nv_bfloat16>(1) + expg(-x));
+#elif __CUDA_ARCH__ >= 530
+    __nv_bfloat16 exp_neg_x = exp_bf16(__nv_bfloat16(-x));
+    __nv_bfloat16 one = __float2bfloat16(1.0f);
+    return recipg(one + exp_neg_x);
+#else
+    // Fallback using float computation
+    float x_float = __bfloat162float(x);
+    float result = 1.0f / (1.0f + expf(-x_float));
+    return __float2bfloat16(result);
+#endif
+}
+
+template<>
+__device__ __forceinline__ __half sigmoid_fwd<__half>(__half x) {
+#if __CUDA_ARCH__ >= 530
+    __half exp_neg_x = exp_halft(__hneg(x));
+    __half one = __float2half(1.0f);
+    return recipg(one + exp_neg_x);
+#else
+    // Fallback using float computation
+    float x_float = __half2float(x);
+    float result = 1.0f / (1.0f + expf(-x_float));
+    return __float2half(result);
+#endif
+}
+
+template<>
+__device__ __forceinline__ float sigmoid_fwd(float x) {
+    return 1.0f / (1.0f + expf(-x));
+}
+
+template<>
+__device__ __forceinline__ double sigmoid_fwd(double x) {
+    return 1.0 / (1.0 + exp(-x));
 }
 
 #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
@@ -98,11 +161,121 @@ __device__ T sign_(T t) {
 
 
 #if __CUDA_ARCH__ >= 800
+UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
+UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
+UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
+UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
+UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
+UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
+UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
+UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
+UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
+UNARY_OP(__nv_bfloat16, uceil_bf16, ceilg(x))
+UNARY_OP(__nv_bfloat16, ufloor_bf16, floorg(x))
+UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x))
+UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
+UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
+UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
+UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
+UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
+UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
+UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
+UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
+UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
+UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
+UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
+UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
+
+#elif __CUDA_ARCH__ >= 530
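+// Every fallback below widens to f32, computes there, and converts the
+// result back to bf16 once; sm_53..sm_7x has no native bf16 arithmetic.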
+
+__device__ __forceinline__ __nv_bfloat16 silu_fwd_fallback(__nv_bfloat16 x) {
+    float x_float = __bfloat162float(x);
+    float exp_neg_x = expf(-x_float);
+
+    float sigmoid = 1.0f / (1.0f + exp_neg_x);
+    float result = x_float * sigmoid;
+    return __float2bfloat16(result);
+}
+
+__device__ __nv_bfloat16 gelu_fwd_fallback(__nv_bfloat16 x) {
+    // Tanh-based gelu approximation computed in float precision
+    float x_float = __bfloat162float(x);
+    float inner = 0.7978845608028654f * (x_float + 0.044715f * x_float * x_float * x_float);
+    float result = 0.5f * x_float * (1.0f + tanhf(inner));
+    return __float2bfloat16(result);
+}
+
+__device__ __nv_bfloat16 gelu_erf_fwd_fallback(__nv_bfloat16 x) {
+    // Convert to float for computation on older architectures
+    float x_float = __bfloat162float(x);
+    float result = x_float * normcdfg(x_float);
+    return __float2bfloat16(result);
+}
+
+__device__ __forceinline__ __nv_bfloat16 sigmoid_fwd_fallback(__nv_bfloat16 x) {
+    // Fallback using float32 computation
+    float x_float = __bfloat162float(x);
+    float result = 1.0f / (1.0f + expf(-x_float));
+    return __float2bfloat16(result);
+}
+
+__device__ __forceinline__ __nv_bfloat16 elu_fwd_fallback(__nv_bfloat16 x, __nv_bfloat16 alpha) {
+    // Fallback implementation using float32
+    float x_float = __bfloat162float(x);
+    float alpha_float = __bfloat162float(alpha);
+
+    if (x_float > 0.0f) {
+        return x;
+    }
+
+    float result = alpha_float * (expf(x_float) - 1.0f);
+    return __float2bfloat16(result);
+}
+
+__device__ __forceinline__ __nv_bfloat16 logg_fallback(__nv_bfloat16 x) {
+    float x_float = __bfloat162float(x);
+    float result = logf(x_float);
+    return __float2bfloat16(result);
+}
+
+__device__ __forceinline__ __nv_bfloat16 negate_bfloat16(__nv_bfloat16 x) {
+    union {
+        __nv_bfloat16 bf16_val;
+        unsigned short bits;
+    } u;
+    u.bf16_val = x;
+    // Flip the sign bit
+    u.bits ^= 0x8000;
+    return u.bf16_val;
+}
+
+__device__ __forceinline__ __nv_bfloat16 abs_bfloat16(__nv_bfloat16 x) {
+    union {
+        __nv_bfloat16 bf16_val;
+        unsigned short bits;
+    } u;
+    u.bf16_val = x;
+    // Clear the sign bit (most significant bit)
+    u.bits &= 0x7FFF;
+    return u.bf16_val;
+}
+
+__device__ __forceinline__ __nv_bfloat16 sqrt_bfloat16_fallback(__nv_bfloat16 x) {
+    float x_float = __bfloat162float(x);
+
+    // Compute in f32; negative inputs correctly yield NaN
+    float result = sqrtf(x_float);
+
+    return __float2bfloat16(result);
+}
+
 UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
 UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
 UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
 UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
-UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
+UNARY_OP(__nv_bfloat16, ulog_bf16, logg_fallback(x))
 UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
 UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
 UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
@@ -113,29 +286,18 @@ UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x))
 UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
 UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
 UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
-UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
-UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
-UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
-UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
-UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
-UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
+UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrt_bfloat16_fallback(x))
+UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd_fallback(x))
+UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd_fallback(x))
+UNARY_OP(__nv_bfloat16, urelu_bf16, __float2bfloat16(relu_fwd(__bfloat162float(x))))
+UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd_fallback(x, param))
+UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
 UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
 UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
 UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
 #endif
 
 #if __CUDA_ARCH__ >= 530
-#include "cuda_bf16.h"
-template <typename T>
-__device__ __forceinline__ T silu_fwd_fallback(T x) {
-    const T one = T(1.0f);
-    const T neg_x = -x;
-    const T exp_neg_x = expg(neg_x);
-    return x / (one + exp_neg_x);
-}
-
-UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
-UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, urecip_f16, recipg(x))

diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 3a8a0bb915..3a98c8527e 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
+use candle::{test_device, test_utils::to_vec3_round, DType, Device, Result, Tensor};
 
 fn softmax(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@@ -249,6 +249,26 @@ fn sigmoid(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn sigmoid_f16(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?.to_dtype(DType::F16)?;
+    let s1 = candle_nn::ops::sigmoid(&tensor)?;
+    let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
+    let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<half::f16>()?;
+    assert_eq!(diff, half::f16::from_f32(0.));
+    Ok(())
+}
+
+fn sigmoid_bf16(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?.to_dtype(DType::BF16)?;
+    let s1 = candle_nn::ops::sigmoid(&tensor)?;
+    let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
+    let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<half::bf16>()?;
+    assert_eq!(diff, half::bf16::from_f32(0.));
+    Ok(())
+}
+
 test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
 test_device!(rope, rope_cpu, rope_gpu, rope_metal);
 test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
@@ -258,3 +278,5 @@ test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
 test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
 test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
 test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
+test_device!(sigmoid_f16, sigmoid_f16_cpu, sigmoid_f16_gpu, sigmoid_f16_metal);
+test_device!(sigmoid_bf16, sigmoid_bf16_cpu, sigmoid_bf16_gpu, sigmoid_bf16_metal);
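
Reviewer note: outside the test suite, a quick end-to-end check of the new
half-precision sigmoid path is a sketch along the following lines. This is a
hypothetical snippet, not part of the patch; it assumes a CUDA-enabled build
of the candle and candle-nn crates, and it compares each reduced-precision
kernel against the f32 kernel rounded through the same dtype:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        // Picks the first CUDA device, or falls back to the CPU implementation.
        let device = Device::cuda_if_available(0)?;
        let x = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &device)?;

        for dtype in [DType::F16, DType::BF16] {
            // Sigmoid computed by the reduced-precision kernel added in this patch.
            let y = candle_nn::ops::sigmoid(&x.to_dtype(dtype)?)?;
            // Reference: f32 sigmoid, rounded through the same dtype.
            let y_ref = candle_nn::ops::sigmoid(&x)?.to_dtype(dtype)?;
            let diff = (y - y_ref)?.abs()?.to_dtype(DType::F32)?.sum_all()?.to_vec0::<f32>()?;
            println!("{dtype:?}: sum |diff| = {diff}");
        }
        Ok(())
    }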