Matmul (no batch, no strided, f32, f32 only) sort of done.
Narsil committed Nov 1, 2023
1 parent 492d164 commit 1980094
Showing 9 changed files with 205 additions and 96 deletions.
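In practical terms, the new code path lets a single, contiguous f32 × f32 matmul run on the Metal backend. The sketch below (not part of this commit) shows roughly how that would be exercised from candle's public API; it assumes a `Device::new_metal` constructor for picking the Metal device, and it builds the inputs on the CPU so the transfer goes through the `storage_from_cpu_storage` path shown further down. Batched, non-contiguous, and non-f32 cases still hit the `bail!`/`todo!` arms.

```rust
// Hypothetical usage sketch, not part of this commit.
// Assumes `Device::new_metal(ordinal)` exists as the Metal device constructor.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let metal = Device::new_metal(0)?; // assumed constructor for the Metal device
    // Build inputs on the CPU, then move them over via storage_from_cpu_storage.
    let a = Tensor::randn(0f32, 1.0, (4, 8), &Device::Cpu)?.to_device(&metal)?;
    let b = Tensor::randn(0f32, 1.0, (8, 3), &Device::Cpu)?.to_device(&metal)?;
    // Dispatches to MetalStorage::matmul, i.e. the MPSMatrixMultiplication path added here.
    let c = a.matmul(&b)?;
    println!("{:?}", c.shape()); // [4, 3]
    Ok(())
}
```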
4 changes: 2 additions & 2 deletions candle-core/src/lib.rs
@@ -49,11 +49,11 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
#[cfg(feature = "metal")]
pub mod metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")]
178 changes: 119 additions & 59 deletions candle-core/src/metal_backend.rs
@@ -3,9 +3,13 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub use candle_metal;
use metal;
use core::mem;
use half::{f16, bf16};
use half::{bf16, f16};
use metal;
use metal::mps::matrix::{MatrixMultiplication, Matrix, MatrixDescriptor};
use metal::mps::{Float32, MPSDataType};
use metal::MTLResourceOptions;
use crate::bail;

/// Metal related errors
#[derive(thiserror::Error, Debug)]
@@ -17,6 +21,7 @@ pub enum MetalError {
#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
}

impl std::fmt::Debug for MetalDevice {
@@ -47,8 +52,7 @@ impl MetalDevice {
pub struct MetalStorage {
buffer: metal::Buffer,
device: MetalDevice,
dtype: DType

dtype: DType,
}

impl BackendStorage for MetalStorage {
@@ -192,12 +196,77 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
let elem_count = b * m * n;
let dev = &self.device;
match (self.dtype, rhs.dtype){
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
todo!("MATMUL {b} {m} {n} {k}");
if b != 1 {
bail!("Batched matmul is not implemented yet");
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
bail!("Non-contiguous matmul is not implemented yet");
}
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
let m: u64 = m.try_into().expect("usize should fit u64");
let n: u64 = n.try_into().expect("usize should fit u64");
let k: u64 = k.try_into().expect("usize should fit u64");
// Create descriptors
let left_descriptor =
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
let right_descriptor =
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
let result_descriptor =
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);

// Create matrix objects
let left_matrix =
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
.expect("Failed to create left matrix");
let right_matrix =
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
.expect("Failed to create left matrix");

let result_matrix =
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
.expect("Failed to create left matrix");

let transpose_left = false;
let transpose_right = false;
let alpha = 1.0;
let beta = 0.0;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.expect("Failed to create matrix multiplication kernel");

let buffer = self.device.command_queue.new_command_buffer();
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
buffer.commit();
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})

}
_ => todo!("Unimplemented matmul for this pair")
_ => todo!("Unimplemented matmul for this pair"),
}
}

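For reference, here is the same MPS pattern as a standalone sketch against the `metal` crate (assuming the crate version used by this commit), mirroring the calls above: matrix descriptors, `Matrix` wrappers over buffers, and a `MatrixMultiplication` kernel encoded on a command buffer. The explicit `wait_until_completed` and the read-back of the result buffer are additions for the example's sake; the commit itself only commits the command buffer.

```rust
// Standalone sketch: multiply a 2x3 by a 3x2 f32 matrix with MPSMatrixMultiplication.
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device");
    let command_queue = device.new_command_queue();

    let (m, n, k) = (2u64, 2u64, 3u64);
    let lhs: [f32; 6] = [1., 2., 3., 4., 5., 6.]; // 2x3, row major
    let rhs: [f32; 6] = [1., 0., 0., 1., 1., 1.]; // 3x2, row major

    // Shared storage keeps the result CPU-visible on unified memory.
    let options = MTLResourceOptions::StorageModeShared;
    let lhs_buf = device.new_buffer_with_data(
        lhs.as_ptr() as *const core::ffi::c_void,
        (lhs.len() * std::mem::size_of::<f32>()) as u64,
        options,
    );
    let rhs_buf = device.new_buffer_with_data(
        rhs.as_ptr() as *const core::ffi::c_void,
        (rhs.len() * std::mem::size_of::<f32>()) as u64,
        options,
    );
    let out_buf = device.new_buffer(m * n * std::mem::size_of::<f32>() as u64, options);

    // Descriptors: rows, columns, row stride in bytes, element type.
    let lhs_desc = MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
    let rhs_desc = MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
    let out_desc = MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);

    let lhs_mat = Matrix::init_with_buffer_descriptor(&lhs_buf, &lhs_desc).unwrap();
    let rhs_mat = Matrix::init_with_buffer_descriptor(&rhs_buf, &rhs_desc).unwrap();
    let out_mat = Matrix::init_with_buffer_descriptor(&out_buf, &out_desc).unwrap();

    // C = alpha * A * B + beta * C, with no transposes.
    let gemm = MatrixMultiplication::init(&device, false, false, m, n, k, 1.0, 0.0).unwrap();

    let command_buffer = command_queue.new_command_buffer();
    gemm.encode_to_command_buffer(&command_buffer, &lhs_mat, &rhs_mat, &out_mat);
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Read the result back from the shared buffer.
    let result = unsafe {
        std::slice::from_raw_parts(out_buf.contents() as *const f32, (m * n) as usize)
    };
    println!("{result:?}"); // [4.0, 5.0, 10.0, 11.0]
}
```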
@@ -211,7 +280,8 @@ impl BackendDevice for MetalDevice {

fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
Ok(Self{device })
let command_queue = device.new_command_queue();
Ok(Self { device, command_queue })
}

fn set_seed(&self, _seed: u64) -> Result<()> {
@@ -237,57 +307,47 @@
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
let buffer = match storage {
CpuStorage::U8(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option
)
}
CpuStorage::U32(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option
)
}
CpuStorage::I64(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option
)
}
CpuStorage::BF16(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option
)
}
CpuStorage::F16(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option
)
}
CpuStorage::F32(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option
)
}
CpuStorage::F64(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option
)
}
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option,
),
};
Ok(Self::Storage{buffer, device: self.clone(), dtype: storage.dtype()})
Ok(Self::Storage {
buffer,
device: self.clone(),
dtype: storage.dtype(),
})
}

fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
8 changes: 6 additions & 2 deletions candle-core/src/op.rs
@@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, MetalStorage, Layout, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;

@@ -176,7 +176,11 @@ pub trait CustomOp1 {

/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(&self, _storage: &MetalStorage, _layout: &Layout) -> Result<(MetalStorage, Shape)> {
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no cuda implementation for {}", self.name()).into(),
))
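Since `metal_fwd` now has a default body, user-defined ops keep compiling even before they grow a Metal kernel. Below is a minimal sketch of such an op, assuming the `CustomOp1` / `Tensor::apply_op1` API as described in the candle book: only `name` and `cpu_fwd` are implemented, while `cuda_fwd` and `metal_fwd` fall back to the defaults above.

```rust
// Hedged sketch of a custom op that doubles every f32 element.
use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

struct Double;

impl CustomOp1 for Double {
    fn name(&self) -> &'static str {
        "double"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // For brevity this assumes a contiguous f32 input.
        let slice = storage.as_slice::<f32>()?;
        let src = match layout.contiguous_offsets() {
            Some((start, end)) => &slice[start..end],
            None => candle_core::bail!("double: input must be contiguous"),
        };
        let doubled: Vec<f32> = src.iter().map(|v| v * 2.0).collect();
        Ok((CpuStorage::F32(doubled), layout.shape().clone()))
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // Runs cpu_fwd here; on a Metal tensor it would currently hit the default error.
    let doubled = t.apply_op1(Double)?;
    println!("{doubled}");
    Ok(())
}
```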
47 changes: 35 additions & 12 deletions candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,7 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
use crate::{Result, Device};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;

@@ -148,16 +148,36 @@ pub fn qtensor_from_ggml(
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@@ -204,7 +224,10 @@ pub struct Content {
}

impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R, device: &Device) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
9 changes: 7 additions & 2 deletions candle-core/src/quantized/gguf_file.rs
@@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Result, Device};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;

@@ -70,7 +70,12 @@ impl TensorInfo {
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec(), device)
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
}
}

4 changes: 2 additions & 2 deletions candle-core/src/quantized/mod.rs
@@ -178,7 +178,7 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape,
device: device.clone()
device: device.clone(),
})
}

@@ -201,7 +201,7 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape: shape.clone(),
device: device.clone()
device: device.clone(),
})
}

6 changes: 4 additions & 2 deletions candle-core/src/storage.rs
@@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, MetalStorage, DType, Device, Error, Layout, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@@ -659,7 +659,9 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
(diffs for the remaining changed files not shown)
