Skip to content

Commit

Permalink
blas: Update layout logic for gemm
Browse files Browse the repository at this point in the history
We compute A B -> C with matrices A, B, C

The blas (cblas) interface supports matrices that adhere to certain
criteria: they should be contiguous on one dimension (stride=1).

We glance a little at how numpy does this to try to catch all cases.

In short, we accept A, B contiguous on either axis (row or column
major). We use the case where C is (weakly) row major, but if it is
column major we transpose A, B, C => A^t, B^t, C^t so that we are back
to the C row major case.

(Weakly = contiguous with stride=1 on the inner dimension, while the
stride of the other dimension can be larger; this distinguishes it from
the strict case where the whole array is contiguous.)

Minor change to the gemv function: no functional change, only an update
required by the refactoring of the blas layout functions.

Fixes rust-ndarray#1278
  • Loading branch information
bluss committed Aug 8, 2024
1 parent 2ca801c commit 27e347c
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 139 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ rawpointer = { version = "0.2" }
defmac = "0.2"
quickcheck = { workspace = true }
approx = { workspace = true, default-features = true }
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
itertools = { workspace = true }

[features]
default = ["std"]
Expand All @@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]

portable-atomic-critical-section = ["portable-atomic/critical-section"]


[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic = { version = "1.6.0" }
portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
Expand Down Expand Up @@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false }
quickcheck = { version = "1.0", default-features = false }
rand = { version = "0.8.0", features = ["small_rng"] }
rand_distr = { version = "0.4.0" }
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }

[profile.bench]
debug = true
Expand Down
2 changes: 2 additions & 0 deletions crates/blas-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ doctest = false

[dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }

blas-src = { version = "0.10", optional = true }
openblas-src = { version = "0.10", optional = true }
Expand All @@ -23,6 +24,7 @@ defmac = "0.2"
approx = { workspace = true }
num-traits = { workspace = true }
num-complex = { workspace = true }
itertools = { workspace = true }

[features]
# Just for making an example and to help testing, multiple different possible
Expand Down
53 changes: 40 additions & 13 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ use ndarray::prelude::*;

use ndarray::linalg::general_mat_mul;
use ndarray::linalg::general_mat_vec_mul;
use ndarray::Order;
use ndarray::{Data, Ix, LinalgScalar};
use ndarray_gen::array_builder::ArrayBuilder;

use approx::assert_relative_eq;
use defmac::defmac;
use itertools::iproduct;
use num_complex::Complex32;
use num_complex::Complex64;

Expand Down Expand Up @@ -243,32 +246,56 @@ fn gen_mat_mul()
let sizes = vec![
(4, 4, 4),
(8, 8, 8),
(17, 15, 16),
(10, 10, 10),
(8, 8, 1),
(1, 10, 10),
(10, 1, 10),
(10, 10, 1),
(1, 10, 1),
(10, 1, 1),
(1, 1, 10),
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
(16, 17, 15),
(15, 16, 17),
(67, 63, 62),
];
// test different strides
for &s1 in &[1, 2, -1, -2] {
for &s2 in &[1, 2, -1, -2] {
for &(m, k, n) in &sizes {
let a = range_mat64(m, k);
let b = range_mat64(k, n);
let mut c = range_mat64(m, n);
let strides = &[1, 2, -1, -2];
let cf_order = [Order::C, Order::F];

// test different strides and memory orders
for (&s1, &s2) in iproduct!(strides, strides) {
for &(m, k, n) in &sizes {
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();

let mut answer = c.clone();

{
let a = a.slice(s![..;s1, ..;s2]);
let b = b.slice(s![..;s2, ..;s2]);
let mut cv = c.slice_mut(s![..;s1, ..;s2]);
let av;
let bv;
let mut cv;

if s1 != 1 || s2 != 1 {
av = a.slice(s![..;s1, ..;s2]);
bv = b.slice(s![..;s2, ..;s2]);
cv = c.slice_mut(s![..;s1, ..;s2]);
} else {
// different stride cases for slicing versus not sliced (for axes of
// len=1); so test not sliced here.
av = a.view();
bv = b.view();
cv = c.view_mut();
}

let answer_part = alpha * reference_mat_mul(&a, &b) + beta * &cv;
let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv;
answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part);

general_mat_mul(alpha, &a, &b, beta, &mut cv);
general_mat_mul(alpha, &av, &bv, beta, &mut cv);
}
assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
}
Expand Down
Loading

0 comments on commit 27e347c

Please sign in to comment.