diff --git a/Cargo.toml b/Cargo.toml
index a19ae00a4..ac1960242 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,7 +47,7 @@ rawpointer = { version = "0.2" }
defmac = "0.2"
quickcheck = { workspace = true }
approx = { workspace = true, default-features = true }
-itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
+itertools = { workspace = true }
[features]
default = ["std"]
@@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]
portable-atomic-critical-section = ["portable-atomic/critical-section"]
+
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic = { version = "1.6.0" }
portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
@@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false }
quickcheck = { version = "1.0", default-features = false }
rand = { version = "0.8.0", features = ["small_rng"] }
rand_distr = { version = "0.4.0" }
+itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
[profile.bench]
debug = true
diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml
index 299364172..fb7eba85e 100644
--- a/crates/blas-tests/Cargo.toml
+++ b/crates/blas-tests/Cargo.toml
@@ -10,6 +10,7 @@ test = false
[dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
+ndarray-gen = { workspace = true }
blas-src = { version = "0.10", optional = true }
openblas-src = { version = "0.10", optional = true }
@@ -21,6 +22,7 @@ defmac = "0.2"
approx = { workspace = true }
num-traits = { workspace = true }
num-complex = { workspace = true }
+itertools = { workspace = true }
[features]
# Just for making an example and to help testing, multiple different possible
diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs
index 3ed81915e..dfb46c3ee 100644
--- a/crates/blas-tests/tests/oper.rs
+++ b/crates/blas-tests/tests/oper.rs
@@ -9,10 +9,14 @@ use ndarray::prelude::*;
use ndarray::linalg::general_mat_mul;
use ndarray::linalg::general_mat_vec_mul;
+use ndarray::Order;
use ndarray::{Data, Ix, LinalgScalar};
+use ndarray_gen::array_builder::ArrayBuilder;
use approx::assert_relative_eq;
use defmac::defmac;
+use itertools::iproduct;
+use ndarray_gen::array_builder::ElementGenerator;
use num_complex::Complex32;
use num_complex::Complex64;
@@ -243,7 +247,10 @@ fn gen_mat_mul()
let sizes = vec![
(4, 4, 4),
(8, 8, 8),
- (17, 15, 16),
+ (10, 10, 10),
+ (3, 4, 3),
+ (3, 4, 5),
+ (5, 4, 3),
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
@@ -251,13 +258,21 @@ fn gen_mat_mul()
(15, 16, 17),
(67, 63, 62),
];
- // test different strides
- for &s1 in &[1, 2, -1, -2] {
- for &s2 in &[1, 2, -1, -2] {
- for &(m, k, n) in &sizes {
- let a = range_mat64(m, k);
- let b = range_mat64(k, n);
- let mut c = range_mat64(m, n);
+ let strides = &[1, 2, -1, -2];
+ let cf_order = [Order::C, Order::F];
+
+ // test different strides and memory orders
+ for (&s1, &s2) in iproduct!(strides, strides) {
+ for &(m, k, n) in &sizes {
+ for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
+ println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
+ let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
+ let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
+ let mut c = ArrayBuilder::new((m, n))
+ .memory_order(ord3)
+ .generator(ElementGenerator::Zero)
+ .build();
+
let mut answer = c.clone();
{
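
The builder calls above come from the workspace-internal ndarray-gen crate (ArrayBuilder and ElementGenerator, as imported at the top of the file). As a rough standalone sketch of what "same shape, different memory order" means here, using only public ndarray APIs (an illustration, not part of the patch; the values ndarray-gen generates may differ):

    use ndarray::prelude::*;
    use ndarray::Order;

    /// Fill an (m, n) matrix with a running sequence, laid out in the given memory order.
    fn range_mat(m: usize, n: usize, order: Order) -> Array2<f64>
    {
        (0..(m * n) as u32)
            .map(f64::from)
            .collect::<Array1<f64>>()
            .into_shape_with_order(((m, n), order))
            .unwrap()
    }

    fn main()
    {
        let a = range_mat(3, 4, Order::C);
        let b = range_mat(3, 4, Order::F);
        assert!(a.is_standard_layout()); // row-major strides [4, 1]
        assert_eq!(b.strides(), &[1, 3]); // column-major strides
    }
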
diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs
index f3bedae71..e1c46da6e 100644
--- a/src/linalg/impl_linalg.rs
+++ b/src/linalg/impl_linalg.rs
@@ -25,8 +25,6 @@ use num_complex::{Complex32 as c32, Complex64 as c64};
#[cfg(feature = "blas")]
use libc::c_int;
#[cfg(feature = "blas")]
-use std::cmp;
-#[cfg(feature = "blas")]
use std::mem::swap;
#[cfg(feature = "blas")]
@@ -39,7 +37,7 @@ use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
const DOT_BLAS_CUTOFF: usize = 32;
/// side of matrix before we use blas
#[cfg(feature = "blas")]
-const GEMM_BLAS_CUTOFF: usize = 7;
+const GEMM_BLAS_CUTOFF: usize = 3; //WIP
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int; // blas index type
@@ -388,8 +386,9 @@ fn mat_mul_impl(
{
// size cutoff for using BLAS
let cut = GEMM_BLAS_CUTOFF;
- let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
- if !(m > cut || n > cut || a > cut)
+ let ((mut m, k), (k2, mut n)) = (lhs.dim(), rhs.dim());
+ debug_assert_eq!(k, k2);
+ if !(m > cut || n > cut || k > cut)
|| !(same_type::<A, f32>()
|| same_type::<A, f64>()
|| same_type::<A, c32>()
@@ -397,31 +396,70 @@ fn mat_mul_impl(
{
return mat_mul_general(alpha, lhs, rhs, beta, c);
}
- {
- // Use `c` for c-order and `f` for an f-order matrix
- // We can handle c * c, f * f generally and
- // c * f and f * c if the `f` matrix is square.
- let mut lhs_ = lhs.view();
- let mut rhs_ = rhs.view();
- let mut c_ = c.view_mut();
- let lhs_s0 = lhs_.strides()[0];
- let rhs_s0 = rhs_.strides()[0];
- let both_f = lhs_s0 == 1 && rhs_s0 == 1;
- let mut lhs_trans = CblasNoTrans;
- let mut rhs_trans = CblasNoTrans;
- if both_f {
+
+ #[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block
+ 'blas_block: loop {
+ let mut a = lhs.view();
+ let mut b = rhs.view();
+ let mut c = c.view_mut();
+
+ let mut a_layout = get_blas_compatible_layout(&a);
+ let mut b_layout = get_blas_compatible_layout(&b);
+ let c_layout = get_blas_compatible_layout(&c);
+ let c_layout_is_c = matches!(c_layout, Some(MemoryOrder::C));
+ let c_layout_is_f = matches!(c_layout, Some(MemoryOrder::F));
+
+ // Compute A B -> C
+ // we require for BLAS compatibility that:
+ // A, B are contiguous (stride=1) in their fastest dimension.
+ // C is c-contiguous in one dimension (stride=1 in Axis(1))
+ //
+ // If C is f-contiguous, use transpose equivalency
+ // to translate to the C-contiguous case:
+ // A^t B^t = C^t => B A = C
+
+ if a_layout.is_some() && b_layout.is_some() && c_layout_is_c {
+ // normal case
+ } else if a_layout.is_some() && b_layout.is_some() && c_layout_is_f {
+ // Transpose equivalency
// A^t B^t = C^t => B A = C
- let lhs_t = lhs_.reversed_axes();
- lhs_ = rhs_.reversed_axes();
- rhs_ = lhs_t;
- c_ = c_.reversed_axes();
+ //
+ // A^t becomes the new B
+ // B^t becomes the new A
+ let lhs_t = a.reversed_axes();
+ a = b.reversed_axes();
+ b = lhs_t;
+ c = c.reversed_axes();
+ // Assign (n, k, m) -> (m, k, n) effectively
swap(&mut m, &mut n);
- } else if lhs_s0 == 1 && m == a {
- lhs_ = lhs_.reversed_axes();
- lhs_trans = CblasTrans;
- } else if rhs_s0 == 1 && a == n {
- rhs_ = rhs_.reversed_axes();
- rhs_trans = CblasTrans;
+
+ // Continue using the already computed memory layouts
+ let tmp = a_layout.map(MemoryOrder::opposite);
+ a_layout = b_layout.map(MemoryOrder::opposite);
+ b_layout = tmp;
+ } else {
+ break 'blas_block;
+ }
+
+ let a_trans;
+ let b_trans;
+ let lda; // Stride of a
+ let ldb; // Stride of b
+
+ if let Some(MemoryOrder::C) = a_layout {
+ lda = a.strides()[0];
+ a_trans = CblasNoTrans;
+ } else {
+ lda = a.strides()[1];
+ a_trans = CblasTrans;
+ }
+
+ if let Some(MemoryOrder::C) = b_layout {
+ ldb = b.strides()[0];
+ b_trans = CblasNoTrans;
+ } else {
+ ldb = b.strides()[1];
+ b_trans = CblasTrans;
}
macro_rules! gemm_scalar_cast {
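
The transpose handling in the block above rests on the identity (A B)^t = B^t A^t: when C is F-ordered, its transpose is C-ordered, so swapping and transposing the operands reduces the problem to the C-ordered case. A standalone check of the identity with ndarray's public API (a sketch, not part of the patch):

    use ndarray::prelude::*;

    fn main()
    {
        let a = array![[1., 2., 3.], [4., 5., 6.]];   // 2 x 3
        let b = array![[1., 0.], [0., 1.], [1., 1.]]; // 3 x 2
        let c = a.dot(&b);                            // C = A B, 2 x 2
        // (A B)^t == B^t A^t, so an F-ordered C can be produced by computing
        // B A into C viewed as its (then C-ordered) transpose.
        assert_eq!(c.t(), b.t().dot(&a.t()));
    }
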
@@ -441,44 +479,27 @@ fn mat_mul_impl(
macro_rules! gemm {
($ty:tt, $gemm:ident) => {
- if blas_row_major_2d::<$ty, _>(&lhs_)
- && blas_row_major_2d::<$ty, _>(&rhs_)
- && blas_row_major_2d::<$ty, _>(&c_)
- {
- let (m, k) = match lhs_trans {
- CblasNoTrans => lhs_.dim(),
- _ => {
- let (rows, cols) = lhs_.dim();
- (cols, rows)
- }
- };
- let n = match rhs_trans {
- CblasNoTrans => rhs_.raw_dim()[1],
- _ => rhs_.raw_dim()[0],
- };
- // adjust strides, these may [1, 1] for column matrices
- let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
- let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
- let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
+ if same_type::<A, $ty>() {
+ let ldc = c.strides()[0] as blas_index;
// gemm is C ← αA^Op B^Op + βC
// Where Op is notrans/trans/conjtrans
unsafe {
blas_sys::$gemm(
CblasRowMajor,
- lhs_trans,
- rhs_trans,
+ a_trans,
+ b_trans,
m as blas_index, // m, rows of Op(a)
n as blas_index, // n, cols of Op(b)
k as blas_index, // k, cols of Op(a)
gemm_scalar_cast!($ty, alpha), // alpha
- lhs_.ptr.as_ptr() as *const _, // a
- lhs_stride, // lda
- rhs_.ptr.as_ptr() as *const _, // b
- rhs_stride, // ldb
+ a.ptr.as_ptr() as *const _, // a
+ lda as blas_index, // lda
+ b.ptr.as_ptr() as *const _, // b
+ ldb as blas_index, // ldb
gemm_scalar_cast!($ty, beta), // beta
- c_.ptr.as_ptr() as *mut _, // c
- c_stride, // ldc
+ c.ptr.as_ptr() as *mut _, // c
+ ldc as blas_index, // ldc
);
}
return;
@@ -490,6 +511,7 @@ fn mat_mul_impl(
gemm!(c32, cblas_cgemm);
gemm!(c64, cblas_zgemm);
+ break 'blas_block;
}
mat_mul_general(alpha, lhs, rhs, beta, c)
}
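
For context, the public entry point behind this path is ndarray::linalg::general_mat_mul, which computes C ← α A B + β C and gives the same result whether the BLAS route or mat_mul_general is taken. A minimal usage example:

    use ndarray::prelude::*;
    use ndarray::linalg::general_mat_mul;

    fn main()
    {
        let a = array![[1., 2.], [3., 4.]];
        let b = array![[5., 6.], [7., 8.]];
        let mut c = Array2::<f64>::zeros((2, 2));
        general_mat_mul(1.0, &a, &b, 0.0, &mut c); // C ← 1·A B + 0·C
        assert_eq!(c, array![[19., 22.], [43., 50.]]);
    }
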
@@ -693,46 +715,51 @@ unsafe fn general_mat_vec_mul_impl(
#[cfg(feature = "blas")]
macro_rules! gemv {
($ty:ty, $gemv:ident) => {
- if let Some(layout) = blas_layout::<$ty, _>(&a) {
- if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
- // Determine stride between rows or columns. Note that the stride is
- // adjusted to at least `k` or `m` to handle the case of a matrix with a
- // trivial (length 1) dimension, since the stride for the trivial dimension
- // may be arbitrary.
- let a_trans = CblasNoTrans;
- let a_stride = match layout {
- CBLAS_LAYOUT::CblasRowMajor => {
- a.strides()[0].max(k as isize) as blas_index
- }
- CBLAS_LAYOUT::CblasColMajor => {
- a.strides()[1].max(m as isize) as blas_index
- }
- };
-
- // Low addr in memory pointers required for x, y
- let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
- let x_ptr = x.ptr.as_ptr().sub(x_offset);
- let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
- let y_ptr = y.ptr.as_ptr().sub(y_offset);
-
- let x_stride = x.strides()[0] as blas_index;
- let y_stride = y.strides()[0] as blas_index;
-
- blas_sys::$gemv(
- layout,
- a_trans,
- m as blas_index, // m, rows of Op(a)
- k as blas_index, // n, cols of Op(a)
- cast_as(&alpha), // alpha
- a.ptr.as_ptr() as *const _, // a
- a_stride, // lda
- x_ptr as *const _, // x
- x_stride,
- cast_as(&beta), // beta
- y_ptr as *mut _, // y
- y_stride,
- );
- return;
+ if same_type::<A, $ty>() {
+ if let Some(layout) = get_blas_compatible_layout(&a) {
+ if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
+ // Determine stride between rows or columns. Note that the stride is
+ // adjusted to at least `k` or `m` to handle the case of a matrix with a
+ // trivial (length 1) dimension, since the stride for the trivial dimension
+ // may be arbitrary.
+ let a_trans = CblasNoTrans;
+
+ let (a_stride, cblas_layout) = match layout {
+ MemoryOrder::C => {
+ (a.strides()[0].max(k as isize) as blas_index,
+ CBLAS_LAYOUT::CblasRowMajor)
+ }
+ MemoryOrder::F => {
+ (a.strides()[1].max(m as isize) as blas_index,
+ CBLAS_LAYOUT::CblasColMajor)
+ }
+ };
+
+ // Low addr in memory pointers required for x, y
+ let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
+ let x_ptr = x.ptr.as_ptr().sub(x_offset);
+ let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
+ let y_ptr = y.ptr.as_ptr().sub(y_offset);
+
+ let x_stride = x.strides()[0] as blas_index;
+ let y_stride = y.strides()[0] as blas_index;
+
+ blas_sys::$gemv(
+ cblas_layout,
+ a_trans,
+ m as blas_index, // m, rows of Op(a)
+ k as blas_index, // n, cols of Op(a)
+ cast_as(&alpha), // alpha
+ a.ptr.as_ptr() as *const _, // a
+ a_stride, // lda
+ x_ptr as *const _, // x
+ x_stride,
+ cast_as(&beta), // beta
+ y_ptr as *mut _, // y
+ y_stride,
+ );
+ return;
+ }
}
}
};
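
The `.max(k)` / `.max(m)` clamping preserved above exists because an axis of length 1 never uses its stride to step anywhere, so the stored stride carries no usable information, while CBLAS still validates the leading dimension against the other axis. A standalone illustration (not part of the patch):

    use ndarray::prelude::*;

    fn main()
    {
        // A 1 x 3 row whose axis-0 stride is 0 (a broadcast view): the stride
        // of a length-1 axis is arbitrary as far as element access goes.
        let v = array![1., 2., 3.];
        let row = v.broadcast((1, 3)).unwrap();
        assert_eq!(row.strides(), &[0, 1]);

        // Clamping mirrors what the gemv call passes as lda for a row-major a.
        let k = row.ncols() as isize;
        let lda = row.strides()[0].max(k);
        assert_eq!(lda, 3);
    }
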
@@ -834,6 +861,7 @@ where
}
#[cfg(feature = "blas")]
+#[derive(Copy, Clone)]
enum MemoryOrder
{
C,
@@ -841,29 +869,15 @@ enum MemoryOrder
}
#[cfg(feature = "blas")]
-fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
-where
- S: Data,
- A: 'static,
- S::Elem: 'static,
-{
- if !same_type::<A, S::Elem>() {
- return false;
- }
- is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
-}
-
-#[cfg(feature = "blas")]
-fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
-where
- S: Data,
- A: 'static,
- S::Elem: 'static,
+impl MemoryOrder
{
- if !same_type::<A, S::Elem>() {
- return false;
+ fn opposite(self) -> Self
+ {
+ match self {
+ MemoryOrder::C => MemoryOrder::F,
+ MemoryOrder::F => MemoryOrder::C,
+ }
}
- is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
#[cfg(feature = "blas")]
@@ -893,20 +907,46 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
true
}
+/// Get BLAS compatible layout if any (C or F, preferring the former)
#[cfg(feature = "blas")]
-fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
+fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<MemoryOrder>
+where S: Data
+{
+ if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) {
+ Some(MemoryOrder::C)
+ } else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) {
+ Some(MemoryOrder::F)
+ } else {
+ None
+ }
+}
+
+#[cfg(test)]
+#[cfg(feature = "blas")]
+fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
S: Data,
A: 'static,
S::Elem: 'static,
{
- if blas_row_major_2d::<A, S>(a) {
- Some(CBLAS_LAYOUT::CblasRowMajor)
- } else if blas_column_major_2d::<A, S>(a) {
- Some(CBLAS_LAYOUT::CblasColMajor)
- } else {
- None
+ if !same_type::<A, S::Elem>() {
+ return false;
+ }
+ is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
+}
+
+#[cfg(test)]
+#[cfg(feature = "blas")]
+fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
+where
+ S: Data,
+ A: 'static,
+ S::Elem: 'static,
+{
+ if !same_type::<A, S::Elem>() {
+ return false;
}
+ is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
#[cfg(test)]
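
As a rough picture of what "BLAS compatible layout" means here: an array counts as C-layout when it is contiguous along each row (with a leading stride at least as large as the row length, the condition BLAS imposes on lda), F-layout when contiguous along each column, and is rejected otherwise. A simplified stand-in sketch, not part of the patch; the crate's real check is is_blas_2d, which additionally validates that dimensions and strides fit in blas_index:

    use ndarray::prelude::*;

    #[derive(Debug, PartialEq)]
    enum Layout { C, F }

    // Sketch only; the crate's real check lives in is_blas_2d.
    fn guess_layout<A>(a: &ArrayView2<'_, A>) -> Option<Layout>
    {
        let (m, n) = a.dim();
        let (s0, s1) = (a.strides()[0], a.strides()[1]);
        if s1 == 1 && s0 >= n as isize {
            Some(Layout::C) // unit stride within each row
        } else if s0 == 1 && s1 >= m as isize {
            Some(Layout::F) // unit stride within each column
        } else {
            None
        }
    }

    fn main()
    {
        let m: Array2<f32> = Array2::zeros((5, 5));
        assert_eq!(guess_layout(&m.view()), Some(Layout::C));
        // Skipping rows keeps each row contiguous; skipping columns does not.
        assert_eq!(guess_layout(&m.slice(s![..;2, ..])), Some(Layout::C));
        assert_eq!(guess_layout(&m.slice(s![.., ..;2])), None);
    }
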
@@ -964,4 +1004,46 @@ mod blas_tests
assert!(!blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}
+
+ #[test]
+ fn blas_row_major_2d_skip_rows_ok()
+ {
+ let m: Array2<f32> = Array2::zeros((5, 5));
+ let mv = m.slice(s![..;2, ..]);
+ assert!(blas_row_major_2d::<f32, _>(&mv));
+ assert!(!blas_column_major_2d::<f32, _>(&mv));
+ }
+
+ #[test]
+ fn blas_row_major_2d_skip_columns_fail()
+ {
+ let m: Array2<f32> = Array2::zeros((5, 5));
+ let mv = m.slice(s![.., ..;2]);
+ assert!(!blas_row_major_2d::<f32, _>(&mv));
+ assert!(!blas_column_major_2d::<f32, _>(&mv));
+ }
+
+ #[test]
+ fn blas_col_major_2d_skip_columns_ok()
+ {
+ let m: Array2<f32> = Array2::zeros((5, 5).f());
+ let mv = m.slice(s![.., ..;2]);
+ assert!(blas_column_major_2d::<f32, _>(&mv));
+ assert!(!blas_row_major_2d::<f32, _>(&mv));
+ }
+
+ #[test]
+ fn blas_col_major_2d_skip_rows_fail()
+ {
+ let m: Array2<f32> = Array2::zeros((5, 5).f());
+ let mv = m.slice(s![..;2, ..]);
+ assert!(!blas_column_major_2d::<f32, _>(&mv));
+ assert!(!blas_row_major_2d::<f32, _>(&mv));
+ }
+
+ #[test]
+ fn test()
+ {
+ // WIP: test that the stride is larger than the other dimension
+ }
}