diff --git a/Cargo.toml b/Cargo.toml
index a19ae00a4..ac1960242 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,7 +47,7 @@ rawpointer = { version = "0.2" }
 defmac = "0.2"
 quickcheck = { workspace = true }
 approx = { workspace = true, default-features = true }
-itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
+itertools = { workspace = true }
 
 [features]
 default = ["std"]
@@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]
 portable-atomic-critical-section = ["portable-atomic/critical-section"]
 
+
 [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
 portable-atomic = { version = "1.6.0" }
 portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
@@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false }
 quickcheck = { version = "1.0", default-features = false }
 rand = { version = "0.8.0", features = ["small_rng"] }
 rand_distr = { version = "0.4.0" }
+itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
 
 [profile.bench]
 debug = true
diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml
index 299364172..fb7eba85e 100644
--- a/crates/blas-tests/Cargo.toml
+++ b/crates/blas-tests/Cargo.toml
@@ -10,6 +10,7 @@ test = false
 
 [dependencies]
 ndarray = { workspace = true, features = ["approx", "blas"] }
+ndarray-gen = { workspace = true }
 
 blas-src = { version = "0.10", optional = true }
 openblas-src = { version = "0.10", optional = true }
@@ -21,6 +22,7 @@ defmac = "0.2"
 approx = { workspace = true }
 num-traits = { workspace = true }
 num-complex = { workspace = true }
+itertools = { workspace = true }
 
 [features]
 # Just for making an example and to help testing, , multiple different possible
diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs
index 3ed81915e..dfb46c3ee 100644
--- a/crates/blas-tests/tests/oper.rs
+++ b/crates/blas-tests/tests/oper.rs
@@ -9,10 +9,14 @@ use ndarray::prelude::*;
 
 use ndarray::linalg::general_mat_mul;
 use ndarray::linalg::general_mat_vec_mul;
+use ndarray::Order;
 use ndarray::{Data, Ix, LinalgScalar};
+use ndarray_gen::array_builder::ArrayBuilder;
 
 use approx::assert_relative_eq;
 use defmac::defmac;
+use itertools::iproduct;
+use ndarray_gen::array_builder::ElementGenerator;
 use num_complex::Complex32;
 use num_complex::Complex64;
@@ -243,7 +247,10 @@ fn gen_mat_mul()
     let sizes = vec![
         (4, 4, 4),
         (8, 8, 8),
-        (17, 15, 16),
+        (10, 10, 10),
+        (3, 4, 3),
+        (3, 4, 5),
+        (5, 4, 3),
         (4, 17, 3),
         (17, 3, 22),
         (19, 18, 2),
@@ -251,13 +258,21 @@ fn gen_mat_mul()
         (15, 16, 17),
         (67, 63, 62),
     ];
-    // test different strides
-    for &s1 in &[1, 2, -1, -2] {
-        for &s2 in &[1, 2, -1, -2] {
-            for &(m, k, n) in &sizes {
-                let a = range_mat64(m, k);
-                let b = range_mat64(k, n);
-                let mut c = range_mat64(m, n);
+    let strides = &[1, 2, -1, -2];
+    let cf_order = [Order::C, Order::F];
+
+    // test different strides and memory orders
+    for (&s1, &s2) in iproduct!(strides, strides) {
+        for &(m, k, n) in &sizes {
+            for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
+                println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
+                let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
+                let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
+                let mut c = ArrayBuilder::new((m, n))
+                    .memory_order(ord3)
+                    .generator(ElementGenerator::Zero)
+                    .build();
+                let mut answer = c.clone();
 
                 {
diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs
index f3bedae71..e1c46da6e 100644
--- a/src/linalg/impl_linalg.rs
+++ b/src/linalg/impl_linalg.rs
@@ -25,8 +25,6 @@ use num_complex::{Complex32 as c32, Complex64 as c64};
 #[cfg(feature = "blas")]
 use libc::c_int;
 #[cfg(feature = "blas")]
-use std::cmp;
-#[cfg(feature = "blas")]
 use std::mem::swap;
 
 #[cfg(feature = "blas")]
@@ -39,7 +37,7 @@ use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
 const DOT_BLAS_CUTOFF: usize = 32;
 /// side of matrix before we use blas
 #[cfg(feature = "blas")]
-const GEMM_BLAS_CUTOFF: usize = 7;
+const GEMM_BLAS_CUTOFF: usize = 3; //WIP
 #[cfg(feature = "blas")]
 #[allow(non_camel_case_types)]
 type blas_index = c_int; // blas index type
@@ -388,8 +386,9 @@ fn mat_mul_impl<A>(
 {
     // size cutoff for using BLAS
     let cut = GEMM_BLAS_CUTOFF;
-    let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
-    if !(m > cut || n > cut || a > cut)
+    let ((mut m, k), (k2, mut n)) = (lhs.dim(), rhs.dim());
+    debug_assert_eq!(k, k2);
+    if !(m > cut || n > cut || k > cut)
         || !(same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
     {
         return mat_mul_general(alpha, lhs, rhs, beta, c);
     }
@@ -397,31 +396,70 @@ fn mat_mul_impl<A>(
-    {
-        // Use `c` for c-order and `f` for an f-order matrix
-        // We can handle c * c, f * f generally and
-        // c * f and f * c if the `f` matrix is square.
-        let mut lhs_ = lhs.view();
-        let mut rhs_ = rhs.view();
-        let mut c_ = c.view_mut();
-        let lhs_s0 = lhs_.strides()[0];
-        let rhs_s0 = rhs_.strides()[0];
-        let both_f = lhs_s0 == 1 && rhs_s0 == 1;
-        let mut lhs_trans = CblasNoTrans;
-        let mut rhs_trans = CblasNoTrans;
-        if both_f {
+
+    #[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block
+    'blas_block: loop {
+        let mut a = lhs.view();
+        let mut b = rhs.view();
+        let mut c = c.view_mut();
+
+        let mut a_layout = get_blas_compatible_layout(&a);
+        let mut b_layout = get_blas_compatible_layout(&b);
+        let c_layout = get_blas_compatible_layout(&c);
+        let c_layout_is_c = matches!(c_layout, Some(MemoryOrder::C));
+        let c_layout_is_f = matches!(c_layout, Some(MemoryOrder::F));
+
+        // Compute A B -> C
+        // we require for BLAS compatibility that:
+        // A, B are contiguous (stride=1) in their fastest dimension.
+        // C is c-contiguous in one dimension (stride=1 in Axis(1))
+        //
+        // If C is f-contiguous, use transpose equivalency
+        // to translate to the C-contiguous case:
+        // A^t B^t = C^t => B A = C
+
+        if a_layout.is_some() && b_layout.is_some() && c_layout_is_c {
+            // normal case
+        } else if a_layout.is_some() && b_layout.is_some() && c_layout_is_f {
+            // Transpose equivalency
             // A^t B^t = C^t => B A = C
-            let lhs_t = lhs_.reversed_axes();
-            lhs_ = rhs_.reversed_axes();
-            rhs_ = lhs_t;
-            c_ = c_.reversed_axes();
+            //
+            // A^t becomes the new B
+            // B^t becomes the new A
+            let lhs_t = a.reversed_axes();
+            a = b.reversed_axes();
+            b = lhs_t;
+            c = c.reversed_axes();
+            // Assign (n, k, m) -> (m, k, n) effectively
             swap(&mut m, &mut n);
-        } else if lhs_s0 == 1 && m == a {
-            lhs_ = lhs_.reversed_axes();
-            lhs_trans = CblasTrans;
-        } else if rhs_s0 == 1 && a == n {
-            rhs_ = rhs_.reversed_axes();
-            rhs_trans = CblasTrans;
+
+            // Continue using the already computed memory layouts
+            let tmp = a_layout.map(MemoryOrder::opposite);
+            a_layout = b_layout.map(MemoryOrder::opposite);
+            b_layout = tmp;
+        } else {
+            break 'blas_block;
+        }
+
+        let a_trans;
+        let b_trans;
+        let lda; // Stride of a
+        let ldb; // Stride of b
+
+        if let Some(MemoryOrder::C) = a_layout {
+            lda = a.strides()[0];
+            a_trans = CblasNoTrans;
+        } else {
+            lda = a.strides()[1];
+            a_trans = CblasTrans;
+        }
+
+        if let Some(MemoryOrder::C) = b_layout {
+            ldb = b.strides()[0];
+            b_trans = CblasNoTrans;
+        } else {
+            ldb = b.strides()[1];
+            b_trans = CblasTrans;
         }
 
         macro_rules! gemm_scalar_cast {
@@ -441,44 +479,27 @@ fn mat_mul_impl<A>(
 
         macro_rules! gemm {
             ($ty:tt, $gemm:ident) => {
-                if blas_row_major_2d::<$ty, _>(&lhs_)
-                    && blas_row_major_2d::<$ty, _>(&rhs_)
-                    && blas_row_major_2d::<$ty, _>(&c_)
-                {
-                    let (m, k) = match lhs_trans {
-                        CblasNoTrans => lhs_.dim(),
-                        _ => {
-                            let (rows, cols) = lhs_.dim();
-                            (cols, rows)
-                        }
-                    };
-                    let n = match rhs_trans {
-                        CblasNoTrans => rhs_.raw_dim()[1],
-                        _ => rhs_.raw_dim()[0],
-                    };
-                    // adjust strides, these may [1, 1] for column matrices
-                    let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
-                    let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
-                    let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
+                if same_type::<A, $ty>() {
+                    let ldc = c.strides()[0] as blas_index;
 
                     // gemm is C ← αA^Op B^Op + βC
                     // Where Op is notrans/trans/conjtrans
                     unsafe {
                         blas_sys::$gemm(
                             CblasRowMajor,
-                            lhs_trans,
-                            rhs_trans,
+                            a_trans,
+                            b_trans,
                             m as blas_index,               // m, rows of Op(a)
                             n as blas_index,               // n, cols of Op(b)
                             k as blas_index,               // k, cols of Op(a)
                             gemm_scalar_cast!($ty, alpha), // alpha
-                            lhs_.ptr.as_ptr() as *const _, // a
-                            lhs_stride,                    // lda
-                            rhs_.ptr.as_ptr() as *const _, // b
-                            rhs_stride,                    // ldb
+                            a.ptr.as_ptr() as *const _,    // a
+                            lda as blas_index,             // lda
+                            b.ptr.as_ptr() as *const _,    // b
+                            ldb as blas_index,             // ldb
                             gemm_scalar_cast!($ty, beta),  // beta
-                            c_.ptr.as_ptr() as *mut _,     // c
-                            c_stride,                      // ldc
+                            c.ptr.as_ptr() as *mut _,      // c
+                            ldc as blas_index,             // ldc
                         );
                     }
                     return;
@@ -490,6 +511,7 @@ fn mat_mul_impl<A>(
         gemm!(c32, cblas_cgemm);
         gemm!(c64, cblas_zgemm);
 
+        break 'blas_block;
     }
     mat_mul_general(alpha, lhs, rhs, beta, c)
 }
@@ -693,46 +715,51 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
     #[cfg(feature = "blas")]
     macro_rules! gemv {
         ($ty:ty, $gemv:ident) => {
-            if let Some(layout) = blas_layout::<$ty, _>(&a) {
-                if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
-                    // Determine stride between rows or columns. Note that the stride is
-                    // adjusted to at least `k` or `m` to handle the case of a matrix with a
-                    // trivial (length 1) dimension, since the stride for the trivial dimension
-                    // may be arbitrary.
-                    let a_trans = CblasNoTrans;
-                    let a_stride = match layout {
-                        CBLAS_LAYOUT::CblasRowMajor => {
-                            a.strides()[0].max(k as isize) as blas_index
-                        }
-                        CBLAS_LAYOUT::CblasColMajor => {
-                            a.strides()[1].max(m as isize) as blas_index
-                        }
-                    };
-
-                    // Low addr in memory pointers required for x, y
-                    let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
-                    let x_ptr = x.ptr.as_ptr().sub(x_offset);
-                    let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
-                    let y_ptr = y.ptr.as_ptr().sub(y_offset);
-
-                    let x_stride = x.strides()[0] as blas_index;
-                    let y_stride = y.strides()[0] as blas_index;
-
-                    blas_sys::$gemv(
-                        layout,
-                        a_trans,
-                        m as blas_index,            // m, rows of Op(a)
-                        k as blas_index,            // n, cols of Op(a)
-                        cast_as(&alpha),            // alpha
-                        a.ptr.as_ptr() as *const _, // a
-                        a_stride,                   // lda
-                        x_ptr as *const _,          // x
-                        x_stride,
-                        cast_as(&beta),             // beta
-                        y_ptr as *mut _,            // y
-                        y_stride,
-                    );
-                    return;
+            if same_type::<A, $ty>() {
+                if let Some(layout) = get_blas_compatible_layout(&a) {
+                    if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
+                        // Determine stride between rows or columns. Note that the stride is
+                        // adjusted to at least `k` or `m` to handle the case of a matrix with a
+                        // trivial (length 1) dimension, since the stride for the trivial dimension
+                        // may be arbitrary.
+                        let a_trans = CblasNoTrans;
+
+                        let (a_stride, cblas_layout) = match layout {
+                            MemoryOrder::C => {
+                                (a.strides()[0].max(k as isize) as blas_index,
+                                 CBLAS_LAYOUT::CblasRowMajor)
+                            }
+                            MemoryOrder::F => {
+                                (a.strides()[1].max(m as isize) as blas_index,
+                                 CBLAS_LAYOUT::CblasColMajor)
+                            }
+                        };
+
+                        // Low addr in memory pointers required for x, y
+                        let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
+                        let x_ptr = x.ptr.as_ptr().sub(x_offset);
+                        let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
+                        let y_ptr = y.ptr.as_ptr().sub(y_offset);
+
+                        let x_stride = x.strides()[0] as blas_index;
+                        let y_stride = y.strides()[0] as blas_index;
+
+                        blas_sys::$gemv(
+                            cblas_layout,
+                            a_trans,
+                            m as blas_index,            // m, rows of Op(a)
+                            k as blas_index,            // n, cols of Op(a)
+                            cast_as(&alpha),            // alpha
+                            a.ptr.as_ptr() as *const _, // a
+                            a_stride,                   // lda
+                            x_ptr as *const _,          // x
+                            x_stride,
+                            cast_as(&beta),             // beta
+                            y_ptr as *mut _,            // y
+                            y_stride,
+                        );
+                        return;
+                    }
                 }
             }
         };
@@ -834,6 +861,7 @@ where
 }
 
 #[cfg(feature = "blas")]
+#[derive(Copy, Clone)]
 enum MemoryOrder
 {
     C,
@@ -841,29 +869,15 @@ enum MemoryOrder
 }
 
 #[cfg(feature = "blas")]
-fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
-where
-    S: Data,
-    A: 'static,
-    S::Elem: 'static,
-{
-    if !same_type::<A, S::Elem>() {
-        return false;
-    }
-    is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
-}
-
-#[cfg(feature = "blas")]
-fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
-where
-    S: Data,
-    A: 'static,
-    S::Elem: 'static,
+impl MemoryOrder
 {
-    if !same_type::<A, S::Elem>() {
-        return false;
+    fn opposite(self) -> Self
+    {
+        match self {
+            MemoryOrder::C => MemoryOrder::F,
+            MemoryOrder::F => MemoryOrder::C,
+        }
     }
-    is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
 }
 
 #[cfg(feature = "blas")]
@@ -893,20 +907,46 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
     true
 }
 
+/// Get BLAS compatible layout if any (C or F, preferring the former)
 #[cfg(feature = "blas")]
-fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
+fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<MemoryOrder>
+where S: Data
+{
+    if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) {
+        Some(MemoryOrder::C)
+    } else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) {
+        Some(MemoryOrder::F)
+    } else {
+        None
+    }
+}
+
+#[cfg(test)]
+#[cfg(feature = "blas")]
+fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
 where
     S: Data,
    A: 'static,
    S::Elem: 'static,
 {
-    if blas_row_major_2d::<A, _>(a) {
-        Some(CBLAS_LAYOUT::CblasRowMajor)
-    } else if blas_column_major_2d::<A, _>(a) {
-        Some(CBLAS_LAYOUT::CblasColMajor)
-    } else {
-        None
+    if !same_type::<A, S::Elem>() {
+        return false;
+    }
+    is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
+}
+
+#[cfg(test)]
+#[cfg(feature = "blas")]
+fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
+where
+    S: Data,
+    A: 'static,
+    S::Elem: 'static,
+{
+    if !same_type::<A, S::Elem>() {
+        return false;
     }
+    is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
 }
 
 #[cfg(test)]
@@ -964,4 +1004,46 @@ mod blas_tests
         assert!(!blas_row_major_2d::<f32, _>(&m));
         assert!(blas_column_major_2d::<f32, _>(&m));
     }
+
+    #[test]
+    fn blas_row_major_2d_skip_rows_ok()
+    {
+        let m: Array2<f32> = Array2::zeros((5, 5));
+        let mv = m.slice(s![..;2, ..]);
+        assert!(blas_row_major_2d::<f32, _>(&mv));
+        assert!(!blas_column_major_2d::<f32, _>(&mv));
+    }
+
+    #[test]
+    fn blas_row_major_2d_skip_columns_fail()
+    {
+        let m: Array2<f32> = Array2::zeros((5, 5));
+        let mv = m.slice(s![.., ..;2]);
+        assert!(!blas_row_major_2d::<f32, _>(&mv));
+        assert!(!blas_column_major_2d::<f32, _>(&mv));
+    }
+
+    #[test]
+    fn blas_col_major_2d_skip_columns_ok()
+    {
+        let m: Array2<f32> = Array2::zeros((5, 5).f());
+        let mv = m.slice(s![.., ..;2]);
+        assert!(blas_column_major_2d::<f32, _>(&mv));
+        assert!(!blas_row_major_2d::<f32, _>(&mv));
+    }
+
+    #[test]
+    fn blas_col_major_2d_skip_rows_fail()
+    {
+        let m: Array2<f32> = Array2::zeros((5, 5).f());
+        let mv = m.slice(s![..;2, ..]);
+        assert!(!blas_column_major_2d::<f32, _>(&mv));
+        assert!(!blas_row_major_2d::<f32, _>(&mv));
+    }
+
+    #[test]
+    fn test()
+    {
+        //WIP test that stride is larger than other dimension
+    }
 }
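
Reviewer note (not part of the patch): in the gemm path above, an F-ordered C does not switch cblas to column-major; the code rewrites the product through the transpose identity (A B)^t = B^t A^t, so the transposed operands swap roles and the existing row-major call is reused with m and n exchanged. A minimal sketch of that identity using plain ndarray `dot`, with made-up values:

    use ndarray::array;

    fn main()
    {
        let a = array![[1., 2., 3.], [4., 5., 6.]]; // 2 x 3
        let b = array![[1., 0.], [0., 1.], [2., 2.]]; // 3 x 2

        // C = A B, computed directly
        let c = a.dot(&b);
        // B^t A^t = (A B)^t = C^t, the equivalency the BLAS path relies on
        let c_t = b.t().dot(&a.t());

        assert_eq!(c.t(), c_t);
        println!("transpose equivalency holds");
    }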
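
A second sketch of the idea behind get_blas_compatible_layout and the lda/trans selection: roughly, an operand can be handed to cblas as row-major when its last axis has stride 1 and the row stride is a usable leading dimension, as column-major in the mirrored case, and otherwise the code breaks out of 'blas_block and falls back to mat_mul_general. This is a simplified, standalone approximation under those assumptions, not the crate's API; the unchanged is_blas_2d in this file remains the authoritative check.

    #[derive(Copy, Clone, Debug, PartialEq)]
    enum Layout
    {
        C,
        F,
    }

    /// Simplified probe: unit stride on the fast axis and a slow-axis stride
    /// no smaller than the fast-axis length (a usable BLAS leading dimension).
    fn classify(dim: [usize; 2], strides: [isize; 2]) -> Option<Layout>
    {
        let (rows, cols) = (dim[0] as isize, dim[1] as isize);
        if strides[1] == 1 && strides[0] >= cols {
            Some(Layout::C)
        } else if strides[0] == 1 && strides[1] >= rows {
            Some(Layout::F)
        } else {
            None
        }
    }

    fn main()
    {
        // Every other row of a 5x5 C-order array: still row-major compatible,
        // as in the new blas_row_major_2d_skip_rows_ok test.
        assert_eq!(classify([3, 5], [10, 1]), Some(Layout::C));
        // Every other column breaks the unit inner stride: no BLAS call,
        // as in blas_row_major_2d_skip_columns_fail.
        assert_eq!(classify([5, 3], [5, 2]), None);
        // An F-order operand is passed with CblasTrans and lda = strides[1].
        assert_eq!(classify([4, 6], [1, 4]), Some(Layout::F));
    }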