Skip to content

Commit

Permalink
Add more QMMV cuda kernels. (huggingface#2077)
Browse files Browse the repository at this point in the history
* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
  • Loading branch information
LaurentMazare authored Apr 18, 2024
1 parent ce6d08d commit 8de0ce6
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 15 deletions.
18 changes: 10 additions & 8 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ fn mul_mat_vec_via_q8_1(
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 4 {
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
Expand All @@ -204,14 +204,16 @@ fn mul_mat_vec_via_q8_1(
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
let nblocks = if b_size == 1 {
nrows as u32
} else {
(nrows as u32 + 1) / 2
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
shared_mem_bytes: 0,
};

Expand Down Expand Up @@ -398,7 +400,7 @@ impl QCudaStorage {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
4
8
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
Expand Down
22 changes: 15 additions & 7 deletions candle-core/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,25 @@ fn qmm_batch(dev: &Device) -> Result<()> {
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
assert_eq!(diff4, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}

Expand Down
Loading

0 comments on commit 8de0ce6

Please sign in to comment.