From 4ccd157e0323c29980498c3f5d5815fe54a3c547 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 11:18:44 +0200 Subject: [PATCH] blas: Fix to skip array with too short stride If we have a matrix of dimension say 5 x 5, BLAS requires the leading stride to be >= 5. Smaller cases are possible for read-only array views in ndarray(broadcasting and custom strides). In this case we mark the array as not BLAS compatible --- src/linalg/impl_linalg.rs | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index e7813455d..778fcaabd 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -863,6 +863,7 @@ where #[cfg(feature = "blas")] #[derive(Copy, Clone)] +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] enum MemoryOrder { C, @@ -887,24 +888,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; - let (inner_stride, outer_dim) = match order { - MemoryOrder::C => (s1, n), - MemoryOrder::F => (s0, m), + let (inner_stride, outer_stride, inner_dim, outer_dim) = match order { + MemoryOrder::C => (s1, s0, m, n), + MemoryOrder::F => (s0, s1, n, m), }; + if !(inner_stride == 1 || outer_dim == 1) { return false; } + if s0 < 1 || s1 < 1 { return false; } + if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize) || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize) { return false; } + + // leading stride must >= the dimension (no broadcasting/aliasing) + if inner_dim > 1 && (outer_stride as usize) < outer_dim { + return false; + } + if m > blas_index::MAX as usize || n > blas_index::MAX as usize { return false; } + true } @@ -1068,8 +1079,26 @@ mod blas_tests } #[test] - fn test() + fn blas_too_short_stride() { - //WIP test that stride is larger than other dimension + // leading stride must be longer than the other dimension + // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. + + const N: usize = 5; + const MAXSTRIDE: usize = N + 2; + let mut data = [0; MAXSTRIDE * N]; + let mut iter = 0..data.len(); + data.fill_with(|| iter.next().unwrap()); + + for stride in 1..=MAXSTRIDE { + let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap(); + eprintln!("{:?}", m); + + if stride < N { + assert_eq!(get_blas_compatible_layout(&m), None); + } else { + assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C)); + } + } } }