Skip to content

Commit

Permalink
blas: Update layout logic for gemm
Browse files Browse the repository at this point in the history
We compute A B -> C with matrices A, B, C

The blas (cblas) interface only supports matrices that adhere to
certain criteria: they should be contiguous on one dimension (stride=1).

We glance a little at how numpy does this to try to catch all cases.

In short, we accept A, B contiguous on either axis (row or column
major). We use the case where C is (weakly) row major, but if it is
column major we transpose A, B, C => A^t, B^t, C^t so that we are back
to the C row major case.

(Weakly = contiguous with stride=1 on that inner dimension, but stride
for the other dimension can be larger; to differentiate from strictly
whole array contiguous.)

Minor change to the gemv function: there is no functional change, it is
only updated to follow the refactoring of the blas layout functions.

Fixes #1278
  • Loading branch information
bluss committed Aug 7, 2024
1 parent d4fdccf commit 6127d62
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 132 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ rawpointer = { version = "0.2" }
defmac = "0.2"
quickcheck = { workspace = true }
approx = { workspace = true, default-features = true }
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
itertools = { workspace = true }

[features]
default = ["std"]
Expand All @@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]

portable-atomic-critical-section = ["portable-atomic/critical-section"]


[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic = { version = "1.6.0" }
portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
Expand Down Expand Up @@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false }
quickcheck = { version = "1.0", default-features = false }
rand = { version = "0.8.0", features = ["small_rng"] }
rand_distr = { version = "0.4.0" }
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }

[profile.bench]
debug = true
Expand Down
2 changes: 2 additions & 0 deletions crates/blas-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ test = false

[dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }

blas-src = { version = "0.10", optional = true }
openblas-src = { version = "0.10", optional = true }
Expand All @@ -21,6 +22,7 @@ defmac = "0.2"
approx = { workspace = true }
num-traits = { workspace = true }
num-complex = { workspace = true }
itertools = { workspace = true }

[features]
# Just for making an example and to help testing, multiple different possible
Expand Down
31 changes: 23 additions & 8 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ use ndarray::prelude::*;

use ndarray::linalg::general_mat_mul;
use ndarray::linalg::general_mat_vec_mul;
use ndarray::Order;
use ndarray::{Data, Ix, LinalgScalar};
use ndarray_gen::array_builder::ArrayBuilder;

use approx::assert_relative_eq;
use defmac::defmac;
use itertools::iproduct;
use ndarray_gen::array_builder::ElementGenerator;
use num_complex::Complex32;
use num_complex::Complex64;

Expand Down Expand Up @@ -243,21 +247,32 @@ fn gen_mat_mul()
let sizes = vec![
(4, 4, 4),
(8, 8, 8),
(17, 15, 16),
(10, 10, 10),
(3, 4, 3),
(3, 4, 5),
(5, 4, 3),
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
(16, 17, 15),
(15, 16, 17),
(67, 63, 62),
];
// test different strides
for &s1 in &[1, 2, -1, -2] {
for &s2 in &[1, 2, -1, -2] {
for &(m, k, n) in &sizes {
let a = range_mat64(m, k);
let b = range_mat64(k, n);
let mut c = range_mat64(m, n);
let strides = &[1, 2, -1, -2];
let cf_order = [Order::C, Order::F];

// test different strides and memory orders
for (&s1, &s2) in iproduct!(strides, strides) {
for &(m, k, n) in &sizes {
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
let mut c = ArrayBuilder::new((m, n))
.memory_order(ord3)
.generator(ElementGenerator::Zero)
.build();

let mut answer = c.clone();

{
Expand Down
Loading

0 comments on commit 6127d62

Please sign in to comment.