Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask

* Dispatch to the 2pass kernel (see the dispatch sketch after this list)

* Format
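
The dispatch mentioned in the second item happens in the caller of these kernels, not in this file. A minimal sketch of the kind of check involved, assuming the decode-time shapes these vector kernels target and an illustrative 1024-key threshold (neither value is fixed by this commit):

// Illustrative only: the real dispatch lives in the SDPA caller, and the 1024-key
// cutoff is an assumption made for this sketch rather than a value from this commit.
fn use_sdpa_vector_2pass(q_seq_len: usize, kv_seq_len: usize) -> bool {
    // The vector kernels target short (typically single-row, decode-time) queries;
    // the 2pass variant splits the keys into SDPA_2PASS_BLOCKS partitions and reduces
    // them in a second kernel, which pays off once the KV cache grows long.
    q_seq_len == 1 && kv_seq_len >= 1024
}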
EricLBuehler authored Jan 16, 2025
1 parent 6fd2f63 commit 17cbbe4
Showing 3 changed files with 486 additions and 49 deletions.
188 changes: 187 additions & 1 deletion candle-metal-kernels/src/lib.rs
@@ -1906,7 +1906,12 @@ pub fn call_sdpa_vector(
        alpha
    };

-    let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
+    let constants = Some(ConstantValues::new(vec![(
+        20,
+        Value::Bool(/* sdpa_vector_has_mask */ false),
+    )]));
+
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
@@ -1948,6 +1953,187 @@ pub fn call_sdpa_vector(
    Ok(())
}
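
The single-pass wrapper above and the first pass of the new 2pass wrapper below now compile their pipelines with `load_pipeline_with_constants`, pinning Metal function constant 20 (`sdpa_vector_has_mask`) to `false`, so the unmasked specialization of the synced kernels is selected. A rough sketch of that pattern, reusing the crate's `ConstantValues`/`Value` types from the diff (a masked call would presumably also bind a mask buffer, which these wrappers do not do):

// Sketch only; not a helper that exists in the crate.
fn sdpa_vector_constants(has_mask: bool) -> Option<ConstantValues> {
    // Function constant 20 picks the masked or unmasked specialization of the same
    // Metal source when the pipeline is compiled.
    Some(ConstantValues::new(vec![(
        20,
        Value::Bool(/* sdpa_vector_has_mask */ has_mask),
    )]))
}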

pub const SDPA_2PASS_BLOCKS: usize = 32;

/// SDPA vector 2pass is supported when:
/// - q head dim == 32, 64, 96, 128, 256
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector_2pass(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    q_offset: usize,
    q_shape: &[usize],
    q_buffer: &Buffer,
    k_offset: usize,
    k_shape: &[usize],
    k_stride: &[usize],
    k_buffer: &Buffer,
    v_offset: usize,
    v_stride: &[usize],
    v_buffer: &Buffer,
    output: &Buffer,
    intermediate: &Buffer,
    sums: &Buffer,
    maxs: &Buffer,
    alpha: f32,
    softcapping: f32,
    itype: SdpaDType,
) -> Result<(), MetalKernelError> {
    let bk = q_shape.last().unwrap();

    // First pass
    {
        let name_pass1 = match (bk, itype) {
            (32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
            (64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
            (96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
            (128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
            (256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
            (32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
            (64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
            (96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
            (128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
            (256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
            (32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
            (64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
            (96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
            (128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
            (256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
            (other, _) => {
                return Err(MetalKernelError::SdpaHeadSizeMismatch {
                    variation: "vector_2pass_1",
                    got: *other,
                    expected: vec![32, 64, 96, 128, 256],
                })
            }
        };

        let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
        let n = k_shape[2] as i32;
        let b = (q_shape[0] * q_shape[1]) as i32;
        let kstride = k_stride[1];
        let vstride = v_stride[1];

        let alpha = if softcapping != 1. {
            alpha / softcapping
        } else {
            alpha
        };

        let constants = Some(ConstantValues::new(vec![(
            20,
            Value::Bool(/* sdpa_vector_has_mask */ false),
        )]));

        let pipeline =
            kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
        let encoder = ep.encoder();
        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
        encoder.set_compute_pipeline_state(&pipeline);

        // q = (bs, qhead, seq, hidden)
        // k/v = (bs, kv_head, kv_seq, hidden)

        set_params!(
            encoder,
            (
                (q_buffer, q_offset),
                (k_buffer, k_offset),
                (v_buffer, v_offset),
                intermediate,
                sums,
                maxs,
                gqa_factor,
                n,
                kstride,
                vstride,
                alpha,
                softcapping
            )
        );

        let grid_dims = MTLSize {
            width: 1,
            height: b as u64,
            depth: SDPA_2PASS_BLOCKS as u64,
        };
        let group_dims = MTLSize {
            width: 8 * 32,
            height: 1,
            depth: 1,
        };
        encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);

        encoder.dispatch_thread_groups(grid_dims, group_dims);
    }

    // Final pass
    {
        let name_pass2 = match (bk, itype) {
            (32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
            (64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
            (96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
            (128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
            (256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
            (32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
            (64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
            (96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
            (128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
            (256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
            (32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
            (64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
            (96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
            (128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
            (256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
            (other, _) => {
                return Err(MetalKernelError::SdpaHeadSizeMismatch {
                    variation: "vector_2pass_2",
                    got: *other,
                    expected: vec![32, 64, 96, 128, 256],
                })
            }
        };

        let b = (q_shape[0] * q_shape[1]) as i32;

        let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
        let encoder = ep.encoder();
        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
        encoder.set_compute_pipeline_state(&pipeline);

        // q = (bs, qhead, seq, hidden)
        // k/v = (bs, kv_head, kv_seq, hidden)

        set_params!(encoder, (intermediate, sums, maxs, output));

        let grid_dims = MTLSize {
            width: 1,
            height: b as u64,
            depth: 1,
        };
        let group_dims = MTLSize {
            width: 1024,
            height: 1,
            depth: 1,
        };
        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
        encoder.use_resource(output, metal::MTLResourceUsage::Write);

        encoder.dispatch_thread_groups(grid_dims, group_dims);
    }
    Ok(())
}
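
The two passes communicate through the caller-provided `intermediate`, `sums`, and `maxs` buffers: pass 1 launches one threadgroup per (batch, head) row and per KV partition (grid depth `SDPA_2PASS_BLOCKS`) and writes per-partition partial outputs plus softmax statistics, while pass 2 launches one 1024-thread group per row and reduces the partitions into `output`. A sizing sketch, assuming the upstream MLX layout of one partial output row and one sum/max per (row, partition) pair and the single-query shapes these kernels target (the authoritative layout is defined by the Metal kernels themselves, not this wrapper):

// Assumed scratch-buffer sizes (in elements) for the two-pass path; illustrative only.
fn sdpa_2pass_scratch_elems(q_shape: &[usize]) -> (usize, usize, usize) {
    let head_dim = *q_shape.last().unwrap();
    let rows = q_shape[0] * q_shape[1]; // matches `b = q_shape[0] * q_shape[1]` above
    let intermediate = rows * SDPA_2PASS_BLOCKS * head_dim; // per-partition partial outputs
    let sums = rows * SDPA_2PASS_BLOCKS; // per-partition softmax denominators
    let maxs = rows * SDPA_2PASS_BLOCKS; // per-partition running maxima
    (intermediate, sums, maxs)
}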

#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
    device: &Device,