From 08f18af4e5c24812b0a4e459ae203c042aeed203 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 26 Sep 2024 22:55:21 +0200 Subject: [PATCH] Yet another cuda qmm padding fix. --- candle-core/src/quantized/cuda.rs | 80 +++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 25 deletions(-) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 8e4884b28d..b0df49978d 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -6,9 +6,15 @@ use half::f16; use cudarc::driver::{CudaSlice, CudaView, DeviceSlice}; +#[derive(Clone, Debug)] +struct PaddedCudaSlice { + inner: CudaSlice, + len: usize, +} + #[derive(Clone, Debug)] pub struct QCudaStorage { - data: CudaSlice, + data: PaddedCudaSlice, dtype: GgmlDType, device: CudaDevice, } @@ -61,7 +67,7 @@ fn quantize_q8_1( } fn dequantize_f32( - data: &CudaSlice, + data: &PaddedCudaSlice, dtype: GgmlDType, elem_count: usize, dev: &CudaDevice, @@ -104,21 +110,21 @@ fn dequantize_f32( }; if is_k { - let params = (data, &dst); + let params = (&data.inner, &dst); unsafe { func.launch(cfg, params) }.w()?; } else { let nb32 = match dtype { GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count, _ => elem_count / 32, }; - let params = (data, &dst, nb32 as i32); + let params = (&data.inner, &dst, nb32 as i32); unsafe { func.launch(cfg, params) }.w()?; } Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } fn dequantize_f16( - data: &CudaSlice, + data: &PaddedCudaSlice, dtype: GgmlDType, elem_count: usize, dev: &CudaDevice, @@ -161,21 +167,21 @@ fn dequantize_f16( }; if is_k { - let params = (data, &dst); + let params = (&data.inner, &dst); unsafe { func.launch(cfg, params) }.w()?; } else { let nb32 = match dtype { GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count, _ => elem_count / 32, }; - let params = (data, &dst, nb32 as i32); + let params = (&data.inner, &dst, nb32 as i32); unsafe { func.launch(cfg, params) }.w()?; } Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } fn dequantize_mul_mat_vec( - data: &CudaSlice, + data: &PaddedCudaSlice, y: &CudaView, dtype: GgmlDType, ncols: usize, @@ -184,7 +190,7 @@ fn dequantize_mul_mat_vec( ) -> Result { use cudarc::driver::LaunchAsync; - let data_elems = data.len() / dtype.type_size() * dtype.block_size(); + let data_elems = data.len / dtype.type_size() * dtype.block_size(); if data_elems < ncols * nrows { crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) } @@ -213,13 +219,13 @@ fn dequantize_mul_mat_vec( shared_mem_bytes: 0, }; - let params = (data, y, &dst, ncols as i32, nrows as i32); + let params = (&data.inner, y, &dst, ncols as i32, nrows as i32); unsafe { func.launch(cfg, params) }.w()?; Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } fn mul_mat_vec_via_q8_1( - data: &CudaSlice, + data: &PaddedCudaSlice, y: &CudaView, dtype: GgmlDType, ncols: usize, @@ -229,7 +235,7 @@ fn mul_mat_vec_via_q8_1( ) -> Result { use cudarc::driver::LaunchAsync; - let data_elems = data.len() / dtype.type_size() * dtype.block_size(); + let data_elems = data.len / dtype.type_size() * dtype.block_size(); if data_elems < ncols * nrows { crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) } @@ -276,7 +282,7 @@ fn mul_mat_vec_via_q8_1( }; let params = ( - data, + &data.inner, &y_q8_1, &dst, /* ncols_x */ ncols as i32, @@ -290,7 +296,7 @@ fn mul_mat_vec_via_q8_1( #[allow(clippy::too_many_arguments)] fn mul_mat_via_q8_1( - data: &CudaSlice, + data: &PaddedCudaSlice, y: &CudaView, dtype: GgmlDType, x_rows: usize, @@ -301,7 +307,7 @@ fn mul_mat_via_q8_1( ) -> Result { use cudarc::driver::LaunchAsync; - let data_elems = data.len() / dtype.type_size() * dtype.block_size(); + let data_elems = data.len / dtype.type_size() * dtype.block_size(); if data_elems < x_rows * x_cols { crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems) } @@ -345,7 +351,7 @@ fn mul_mat_via_q8_1( }; let params = ( - /* vx */ data, + /* vx */ &data.inner, /* vy */ &y_q8_1, /* dst */ &dst, /* ncols_x */ x_cols as i32, @@ -361,9 +367,14 @@ fn mul_mat_via_q8_1( impl QCudaStorage { pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result { let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size(); - let data = device.alloc_zeros::(size_in_bytes).w()?; + let padded_size_in_bytes = + ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size(); + let inner = device.alloc_zeros::(padded_size_in_bytes).w()?; Ok(QCudaStorage { - data, + data: PaddedCudaSlice { + inner, + len: size_in_bytes, + }, device: device.clone(), dtype, }) @@ -403,7 +414,10 @@ impl QCudaStorage { } // Run the dequantization on cpu. - let buffer = self.device.dtoh_sync_copy(&self.data).w()?; + let buffer = self + .device + .dtoh_sync_copy(&self.data.inner.slice(..self.data.len)) + .w()?; let mut out = vec![0.0; elem_count]; let block_len = elem_count / self.dtype.block_size(); match self.dtype { @@ -444,13 +458,21 @@ impl QCudaStorage { let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; qcpu_storage.quantize(&src)?; let data = qcpu_storage.data()?; - let data = self.device.htod_sync_copy(data.as_ref()).w()?; - self.data = data; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len).w()? }; + self.device + .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len())) + .w()?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; Ok(()) } pub fn storage_size_in_bytes(&self) -> usize { - self.data.len() + self.data.len } pub fn fwd( @@ -573,11 +595,19 @@ pub fn load_quantized( let data = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data)) }; - let data = device.htod_sync_copy(data).w()?; + let dtype = T::DTYPE; + let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size(); + let mut inner = unsafe { device.alloc::(padded_len).w()? }; + device + .htod_sync_copy_into(data, &mut inner.slice_mut(..data.len())) + .w()?; Ok(QStorage::Cuda(QCudaStorage { - data, + data: PaddedCudaSlice { + inner, + len: data.len(), + }, device: device.clone(), - dtype: T::DTYPE, + dtype, })) }