Skip to content

Commit 1175edb

Browse files
emmatypingbluss
authored andcommitted
Add complex GEMM and simple test
1 parent 1c685ef commit 1175edb

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

src/linalg/impl_linalg.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ use cblas_sys as blas_sys;
2929
#[cfg(feature = "blas")]
3030
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
3131

32+
#[cfg(feature = "blas")]
33+
use num_complex::{Complex32 as c32, Complex64 as c64};
34+
3235
/// len of vector before we use blas
3336
#[cfg(feature = "blas")]
3437
const DOT_BLAS_CUTOFF: usize = 32;
@@ -374,6 +377,7 @@ fn mat_mul_impl<A>(
374377
) where
375378
A: LinalgScalar,
376379
{
380+
377381
// size cutoff for using BLAS
378382
let cut = GEMM_BLAS_CUTOFF;
379383
let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
@@ -455,6 +459,8 @@ fn mat_mul_impl<A>(
455459
}
456460
gemm!(f32, cblas_sgemm);
457461
gemm!(f64, cblas_dgemm);
462+
gemm!(c32, cblas_cgemm);
463+
gemm!(c64, cblas_zgemm);
458464
}
459465
mat_mul_general(alpha, lhs, rhs, beta, c)
460466
}

xtest-blas/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ test = false
1111
approx = "0.4"
1212
defmac = "0.2"
1313
num-traits = "0.2"
14+
num-complex = { version = "0.4", default-features = false }
1415

1516
[dependencies]
1617
ndarray = { path = "../", features = ["approx", "blas"] }

xtest-blas/tests/oper.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ extern crate defmac;
33
extern crate ndarray;
44
extern crate num_traits;
55
extern crate blas_src;
6+
extern crate num_complex;
67

78
use ndarray::prelude::*;
89

@@ -12,6 +13,8 @@ use ndarray::{Data, Ix, LinalgScalar};
1213

1314
use approx::assert_relative_eq;
1415
use defmac::defmac;
16+
use num_complex::Complex32;
17+
use num_complex::Complex64;
1518

1619
#[test]
1720
fn mat_vec_product_1d() {
@@ -52,6 +55,20 @@ fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
5255
.unwrap()
5356
}
5457

58+
fn range_mat_complex(m: Ix, n: Ix) -> Array2<Complex32> {
59+
Array::linspace(0., (m * n) as f32 - 1., m * n)
60+
.into_shape((m, n))
61+
.unwrap()
62+
.map_mut(|&mut f| Complex32::new(f, 0.))
63+
}
64+
65+
fn range_mat_complex64(m: Ix, n: Ix) -> Array2<Complex64> {
66+
Array::linspace(0., (m * n) as f64 - 1., m * n)
67+
.into_shape((m, n))
68+
.unwrap()
69+
.map_mut(|&mut f| Complex64::new(f, 0.))
70+
}
71+
5572
fn range1_mat64(m: Ix) -> Array1<f64> {
5673
Array::linspace(0., m as f64 - 1., m)
5774
}
@@ -250,6 +267,30 @@ fn gemm_64_1_f() {
250267
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
251268
}
252269

270+
#[test]
271+
fn gemm_c64_1_f() {
272+
let a = range_mat_complex64(64, 64).reversed_axes();
273+
let (m, n) = a.dim();
274+
// m x n times n x 1 == m x 1
275+
let x = range_mat_complex64(n, 1);
276+
let mut y = range_mat_complex64(m, 1);
277+
let answer = reference_mat_mul(&a, &x) + &y;
278+
general_mat_mul(Complex64::new(1.0, 0.), &a, &x, Complex64::new(1.0, 0.), &mut y);
279+
assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7);
280+
}
281+
282+
#[test]
283+
fn gemm_c32_1_f() {
284+
let a = range_mat_complex(64, 64).reversed_axes();
285+
let (m, n) = a.dim();
286+
// m x n times n x 1 == m x 1
287+
let x = range_mat_complex(n, 1);
288+
let mut y = range_mat_complex(m, 1);
289+
let answer = reference_mat_mul(&a, &x) + &y;
290+
general_mat_mul(Complex32::new(1.0, 0.), &a, &x, Complex32::new(1.0, 0.), &mut y);
291+
assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7);
292+
}
293+
253294
#[test]
254295
fn gen_mat_vec_mul() {
255296
let alpha = -2.3;

0 commit comments

Comments
 (0)