Skip to content

Commit

Permalink
blas: Fix to skip array with too short stride
Browse files Browse the repository at this point in the history
If we have a matrix of dimension say 5 x 5, BLAS requires the leading
stride to be >= 5. Smaller cases are possible for read-only array views
in ndarray(broadcasting and custom strides).

In this case we mark the array as not BLAS compatible
  • Loading branch information
bluss committed Aug 7, 2024
1 parent 6127d62 commit ffb192a
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ where

#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum MemoryOrder
{
C,
Expand All @@ -886,24 +887,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
let (m, n) = dim.into_pattern();
let s0 = stride[0] as isize;
let s1 = stride[1] as isize;
let (inner_stride, outer_dim) = match order {
MemoryOrder::C => (s1, n),
MemoryOrder::F => (s0, m),
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
MemoryOrder::C => (s1, s0, m, n),
MemoryOrder::F => (s0, s1, n, m),
};

if !(inner_stride == 1 || outer_dim == 1) {
return false;
}

if s0 < 1 || s1 < 1 {
return false;
}

if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
|| (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
{
return false;
}

// leading stride must >= the dimension (no broadcasting/aliasing)
if inner_dim > 1 && (outer_stride as usize) < outer_dim {
return false;
}

if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
return false;
}

true
}

Expand Down Expand Up @@ -1042,8 +1053,26 @@ mod blas_tests
}

#[test]
fn test()
fn blas_too_short_stride()
{
//WIP test that stride is larger than other dimension
// leading stride must be longer than the other dimension
// Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.

const N: usize = 5;
const MAXSTRIDE: usize = N + 2;
let mut data = [0; MAXSTRIDE * N];
let mut iter = 0..data.len();
data.fill_with(|| iter.next().unwrap());

for stride in 1..=MAXSTRIDE {
let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
eprintln!("{:?}", m);

if stride < N {
assert_eq!(get_blas_compatible_layout(&m), None);
} else {
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
}
}
}
}

0 comments on commit ffb192a

Please sign in to comment.