Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yet another cuda qmm padding fix. #2509

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 55 additions & 25 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@ use half::f16;

use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};

/// A CUDA byte buffer whose allocation is larger than its logical length.
///
/// The extra capacity past `len` (padding) lets quantized matmul kernels read
/// past the end of the logical data without going out of bounds — TODO confirm
/// against the allocation sites, which size `inner` using `MATRIX_ROW_PADDING`.
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
    // Full device allocation, including padding bytes beyond `len`.
    inner: CudaSlice<u8>,
    // Logical size in bytes of the quantized data; `inner` may be larger.
    len: usize,
}

#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
data: PaddedCudaSlice,
dtype: GgmlDType,
device: CudaDevice,
}
Expand Down Expand Up @@ -61,7 +67,7 @@ fn quantize_q8_1(
}

fn dequantize_f32(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
Expand Down Expand Up @@ -104,21 +110,21 @@ fn dequantize_f32(
};

if is_k {
let params = (data, &dst);
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_f16(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
Expand Down Expand Up @@ -161,21 +167,21 @@ fn dequantize_f16(
};

if is_k {
let params = (data, &dst);
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
Expand All @@ -184,7 +190,7 @@ fn dequantize_mul_mat_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let data_elems = data.len() / dtype.type_size() * dtype.block_size();
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
Expand Down Expand Up @@ -213,13 +219,13 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};

let params = (data, y, &dst, ncols as i32, nrows as i32);
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
Expand All @@ -229,7 +235,7 @@ fn mul_mat_vec_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let data_elems = data.len() / dtype.type_size() * dtype.block_size();
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
Expand Down Expand Up @@ -276,7 +282,7 @@ fn mul_mat_vec_via_q8_1(
};

let params = (
data,
&data.inner,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
Expand All @@ -290,7 +296,7 @@ fn mul_mat_vec_via_q8_1(

#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
Expand All @@ -301,7 +307,7 @@ fn mul_mat_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let data_elems = data.len() / dtype.type_size() * dtype.block_size();
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
Expand Down Expand Up @@ -345,7 +351,7 @@ fn mul_mat_via_q8_1(
};

let params = (
/* vx */ data,
/* vx */ &data.inner,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
Expand All @@ -361,9 +367,14 @@ fn mul_mat_via_q8_1(
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
Ok(QCudaStorage {
data,
data: PaddedCudaSlice {
inner,
len: size_in_bytes,
},
device: device.clone(),
dtype,
})
Expand Down Expand Up @@ -403,7 +414,10 @@ impl QCudaStorage {
}
// Run the dequantization on cpu.

let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let buffer = self
.device
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
Expand Down Expand Up @@ -444,13 +458,21 @@ impl QCudaStorage {
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
};
Ok(())
}

pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
self.data.len
}

pub fn fwd(
Expand Down Expand Up @@ -573,11 +595,19 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let data = device.htod_sync_copy(data).w()?;
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.w()?;
Ok(QStorage::Cuda(QCudaStorage {
data,
data: PaddedCudaSlice {
inner,
len: data.len(),
},
device: device.clone(),
dtype: T::DTYPE,
dtype,
}))
}

Expand Down
Loading