Commit 8d26067: Dispatch to the 2pass kernel
EricLBuehler committed Jan 16, 2025
1 parent 8cf2565 commit 8d26067
Showing 1 changed file with 74 additions and 21 deletions.
candle-nn/src/ops.rs: 74 additions & 21 deletions
@@ -1074,27 +1074,80 @@ impl candle::CustomOp3 for Sdpa {

         let command_buffer = q.device().command_buffer()?;
         if supports_sdpa_vector {
-            command_buffer.set_label("vector_attention");
-            candle_metal_kernels::call_sdpa_vector(
-                q.device().device(),
-                &command_buffer,
-                q.device().kernels(),
-                q_l.start_offset(),
-                q_l.dims(),
-                q.buffer(),
-                k_l.start_offset(),
-                k_l.dims(),
-                k_l.stride(),
-                k.buffer(),
-                v_l.start_offset(),
-                v_l.stride(),
-                v.buffer(),
-                &output,
-                self.scale,
-                self.softcapping,
-                itype,
-            )
-            .map_err(candle::Error::wrap)?;
+            // Route to the 2 pass fused attention if the k seqlen is large.
+            // https://github.com/ml-explore/mlx/pull/1597
+            const TWO_PASS_K_THRESHOLD: usize = 1024;
+            if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
+                let mut intermediate_shape = [
+                    &out_dims[0..out_dims.len() - 2],
+                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
+                    &[out_dims[out_dims.len() - 1]],
+                ]
+                .concat();
+                let intermediate = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_intermediate",
+                )?;
+                let _ = intermediate_shape.pop().unwrap();
+                let sums = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_sums",
+                )?;
+                let maxs = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_maxs",
+                )?;
+
+                command_buffer.set_label("vector_attention");
+                candle_metal_kernels::call_sdpa_vector_2pass(
+                    q.device().device(),
+                    &command_buffer,
+                    q.device().kernels(),
+                    q_l.start_offset(),
+                    q_l.dims(),
+                    q.buffer(),
+                    k_l.start_offset(),
+                    k_l.dims(),
+                    k_l.stride(),
+                    k.buffer(),
+                    v_l.start_offset(),
+                    v_l.stride(),
+                    v.buffer(),
+                    &output,
+                    &intermediate,
+                    &sums,
+                    &maxs,
+                    self.scale,
+                    self.softcapping,
+                    itype,
+                )
+                .map_err(candle::Error::wrap)?;
+            } else {
+                command_buffer.set_label("vector_attention");
+                candle_metal_kernels::call_sdpa_vector(
+                    q.device().device(),
+                    &command_buffer,
+                    q.device().kernels(),
+                    q_l.start_offset(),
+                    q_l.dims(),
+                    q.buffer(),
+                    k_l.start_offset(),
+                    k_l.dims(),
+                    k_l.stride(),
+                    k.buffer(),
+                    v_l.start_offset(),
+                    v_l.stride(),
+                    v.buffer(),
+                    &output,
+                    self.scale,
+                    self.softcapping,
+                    itype,
+                )
+                .map_err(candle::Error::wrap)?;
+            }
         } else if supports_sdpa_full {
             if q_l.dim(2)? != k_l.dim(2)? {
                 candle::bail!(
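The three scratch buffers give pass 1 somewhere to park per-block results: `intermediate` holds one partial output row per block, while `sums` and `maxs` hold one scalar per block. A minimal CPU-side sketch of that shape math, assuming the vector kernel's output shape is `[batch, heads, 1, head_dim]`; the helper name, the `blocks` parameter, and the example dims are illustrative, not part of the crate:

```rust
/// Shapes of the 2-pass scratch buffers, mirroring the indexing in the diff:
/// drop the q-sequence dim (1 for the vector kernel), then insert a
/// per-block dim in front of head_dim.
fn two_pass_scratch_shapes(out_dims: &[usize], blocks: usize) -> (Vec<usize>, Vec<usize>) {
    let head_dim = out_dims[out_dims.len() - 1];
    // intermediate: one partial output row per block -> [batch, heads, blocks, head_dim]
    let intermediate = [&out_dims[..out_dims.len() - 2], &[blocks], &[head_dim]].concat();
    // sums / maxs: one scalar per block -> [batch, heads, blocks]
    let sums_maxs = intermediate[..intermediate.len() - 1].to_vec();
    (intermediate, sums_maxs)
}

fn main() {
    // e.g. batch = 1, heads = 32, q_len = 1, head_dim = 128, 32 blocks (assumed count)
    let (intermediate, sums_maxs) = two_pass_scratch_shapes(&[1, 32, 1, 128], 32);
    assert_eq!(intermediate, [1, 32, 32, 128]);
    assert_eq!(sums_maxs, [1, 32, 32]);
}
```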

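Why the split helps at long k seqlens: each of the `SDPA_2PASS_BLOCKS` blocks reduces its slice of the key sequence with an online softmax, emitting a locally normalized partial output plus its local max and sum of exponents; pass 2 rescales every partial to the global max and takes the weighted average. A scalar sketch of that combine step, assuming each block's partial row is already normalized by its local sum (the real reduction runs inside the Metal kernel from the linked MLX PR; this only illustrates the math):

```rust
/// Pass-2 combine: merge per-block partial attention outputs.
/// `partials[b]` is block b's locally softmax-normalized output row;
/// `maxs[b]` / `sums[b]` are its local running max and sum of exponents.
fn combine_blocks(partials: &[Vec<f32>], maxs: &[f32], sums: &[f32]) -> Vec<f32> {
    // Rescale against the global max so no exponent overflows.
    let m = maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Block b's true weight in the global softmax is s_b * exp(m_b - m).
    let weights: Vec<f32> = maxs
        .iter()
        .zip(sums)
        .map(|(&mb, &sb)| sb * (mb - m).exp())
        .collect();
    let total: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; partials[0].len()];
    for (p, &w) in partials.iter().zip(&weights) {
        for (o, &x) in out.iter_mut().zip(p) {
            *o += (w / total) * x;
        }
    }
    out
}
```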