Skip to content

Commit

Permalink
Update kernels for metal bf16 (#19)
Browse files Browse the repository at this point in the history
* Update kernels for metal bf16

* Fix typo

* Check if have bfloat
  • Loading branch information
EricLBuehler committed Nov 26, 2024
1 parent b482af4 commit ac95466
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 3 deletions.
5 changes: 5 additions & 0 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ impl QMetalStorage {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::BF16 => {
let vec: Vec<half::bf16> = read_to_vec(&buffer, block_len);
half::bf16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
Expand Down Expand Up @@ -225,6 +229,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16,
}
}
}
4 changes: 3 additions & 1 deletion candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,7 @@ pub enum GgmlDType {
Q8K,
F16,
F32,
BF16,
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -2241,7 +2242,7 @@ pub fn call_quantized_matmul_mv_t(
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => {
// Original implem uses rows
let nth0 = 32;
let nth1 = 1;
Expand Down Expand Up @@ -2279,6 +2280,7 @@ pub fn call_quantized_matmul_mv_t(
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::BF16 => "kernel_mul_mv_bf16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
};

Expand Down
200 changes: 198 additions & 2 deletions candle-metal-kernels/src/quantized.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1495,8 +1495,203 @@ kernel void kernel_mul_mv_f16_f32(
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}

#if defined(__HAVE_BFLOAT__)
void kernel_mul_mv_bf16_f32_1row_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;

const uint i12 = im%ne12;
const uint i13 = im/ne12;

const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

device const bfloat* x = (device const bfloat*) (src0 + offset0);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

float sumf = 0;
if (ne00 < 128) {
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} else {
device const bfloat4* x4 = (device const bfloat4*) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}

[[host_name("kernel_mul_mv_bf16_f32_1row")]]
kernel void kernel_mul_mv_bf16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_bf16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#endif

#define N_BF16_F32 4

#if defined(__HAVE_BFLOAT__)
void kernel_mul_mv_bf16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {

const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_BF16_F32;
const int64_t im = tgpig.z;

const uint i12 = im%ne12;
const uint i13 = im/ne12;

const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

device const bfloat * x = (device const bfloat *) (src0 + offset0);

if (ne00 < 128) {
for (int row = 0; row < N_BF16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}

device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}

float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const bfloat4 * x4 = (device const bfloat4 *)x;
for (int row = 0; row < N_BF16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}

device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;

float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
}

float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}

[[host_name("kernel_mul_mv_bf16_f32")]]
kernel void kernel_mul_mv_bf16_f32(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_bf16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#endif

#if defined(__HAVE_BFLOAT__)
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_f16_f32_l4(
kernel void kernel_mul_mv_bf16_f32_l4(
device const char * src0,
device const char * src1,
device float * dst,
Expand Down Expand Up @@ -1528,7 +1723,7 @@ kernel void kernel_mul_mv_f16_f32_l4(

const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

device const half4 * x4 = (device const half4 *) (src0 + offset0);
device const bfloat4 * x4 = (device const bfloat4 *) (src0 + offset0);

for (int r1 = 0; r1 < nrows; ++r1) {
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
Expand All @@ -1544,6 +1739,7 @@ kernel void kernel_mul_mv_f16_f32_l4(
}
}
}
#endif

kernel void kernel_alibi_f32(
device const float * src0,
Expand Down

0 comments on commit ac95466

Please sign in to comment.